diff --git a/app/otbTensorflowModelServe.cxx b/app/otbTensorflowModelServe.cxx
index cfed5516f534f8240ed86e0bbf3a309ce39ab3e7..cc5e073494d10931bc04a039619f4411adc968bd 100644
--- a/app/otbTensorflowModelServe.cxx
+++ b/app/otbTensorflowModelServe.cxx
@@ -30,8 +30,7 @@
 #include "otbTensorflowSource.h"
 
 // Streaming
-#include "otbImageRegionSquareTileSplitter.h"
-#include "itkStreamingImageFilter.h"
+#include "otbTensorflowStreamerFilter.h"
 
 namespace otb
 {
@@ -58,7 +57,7 @@ public:
 
   /** Typedef for streaming */
   typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
-  typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType>    StreamingFilterType;
+  typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
 
   /** Typedefs for images */
   typedef FloatVectorImageType::SizeType SizeType;
@@ -198,9 +197,12 @@ public:
     AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
     MandatoryOff            ("optim.disabletiling");
     SetParameterDescription ("optim.disabletiling", "Tiling avoids processing too large an image subset at once, but it can sometimes be useful to disable it");
-    AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
-    SetMinimumParameterIntValue ("optim.tilesize", 1);
-    SetDefaultParameterInt      ("optim.tilesize", 16);
+    AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output");
+    SetMinimumParameterIntValue ("optim.tilesizex", 1);
+    SetDefaultParameterInt      ("optim.tilesizex", 16);
+    AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output");
+    SetMinimumParameterIntValue ("optim.tilesizey", 1);
+    SetDefaultParameterInt      ("optim.tilesizey", 16);
 
     // Output image
     AddParameter(ParameterType_OutputImage, "out", "output image");
@@ -292,22 +294,28 @@ public:
     if (GetParameterInt("optim.disabletiling") != 1)
     {
       // Get the tile size
-      const unsigned int tileSize = GetParameterInt("optim.tilesize");
-      otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
+      SizeType tileSize;
+      tileSize[0] = GetParameterInt("optim.tilesizex");
+      tileSize[1] = GetParameterInt("optim.tilesizey");
+
+      // Check that the tile size is aligned with the output expression field
+      for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++)
+        if (tileSize[i] % foe[i] != 0)
+          {
+          // Round the tile size up to the next multiple of the expression field
+          SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]);
+          newSize *= foe[i];
 
-      // Update the TensorFlow filter output information to get the output image size
-      m_TFFilter->UpdateOutputInformation();
+          otbAppLogWARNING("Aligning the tiling with the output expression field "
+              << "for better performance (dim " << i << "). New value set to " << newSize)
 
-      // Splitting using square tiles
-      TileSplitterType::Pointer splitter = TileSplitterType::New();
-      splitter->SetTileSizeAlignment(tileSize);
-      unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
-          double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels()) / (tileSize * tileSize) );
+          tileSize[i] = newSize;
+          }
+
+      otbAppLogINFO("Force tiling with tiles of " << tileSize)
 
-      // Use an itk::StreamingImageFilter to force the computation tile by tile
+      // Force the computation tile by tile
       m_StreamFilter = StreamingFilterType::New();
-      m_StreamFilter->SetRegionSplitter(splitter);
-      m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
+      m_StreamFilter->SetOutputGridSize(tileSize);
       m_StreamFilter->SetInput(m_TFFilter->GetOutput());
 
       SetParameterOutputImage("out", m_StreamFilter->GetOutput());
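For reference, the rounding performed in the hunk above (aligning each tile dimension to the next multiple of the output expression field, named foe in the surrounding code) boils down to the following arithmetic. This is a minimal standalone sketch, not the application code itself:

    # Sketch of the tile-size alignment above: round each tile dimension
    # up to the next multiple of the output expression field ("foe").
    def align_tile_size(tile_size, foe):
        aligned = []
        for t, f in zip(tile_size, foe):
            if t % f != 0:
                t = (t // f + 1) * f  # same as (1 + floor(t / f)) * f
            aligned.append(t)
        return aligned

    print(align_tile_size([16, 16], [10, 10]))  # -> [20, 20]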
diff --git a/app/otbTensorflowModelTrain.cxx b/app/otbTensorflowModelTrain.cxx
index c20466408d9b686d7243fc3cb550d7dd1e65344d..5315cbb7a5d72a396787c0fa71360796265890e2 100644
--- a/app/otbTensorflowModelTrain.cxx
+++ b/app/otbTensorflowModelTrain.cxx
@@ -438,6 +438,7 @@ public:
   }
 
   // Setup the validation filter
+  const bool do_validation = HasUserValue("validation.mode");
   if (GetParameterInt("validation.mode")==1) // class
   {
     otbAppLogINFO("Set validation mode to classification validation");
@@ -467,50 +468,53 @@ public:
       AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
       m_TrainModelFilter->Update();
 
-      // Validate the model
-      if (epoch % GetParameterInt("validation.step") == 0)
+      if (do_validation)
+      {
+        // Validate the model
+        if (epoch % GetParameterInt("validation.step") == 0)
         {
-        // 1. Evaluate the metrics against the learning data
+          // 1. Evaluate the metrics against the learning data
 
-        for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
+          for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
           {
-          m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
+            m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
           }
-        m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
+          m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
 
-        // As we use the learning data here, it's rational to use the same streaming option as during training
-        m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
+          // As we use the learning data here, it's rational to use the same streaming option as during training
+          m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
 
-        // Update
-        AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
-        m_ValidateModelFilter->Update();
+          // Update
+          AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
+          m_ValidateModelFilter->Update();
 
-        for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+          for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
           {
-          otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
-          PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
+            otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
+            PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
           }
 
-        // 2. Evaluate the metrics against the validation data
+          // 2. Evaluate the metrics against the validation data
 
-        // Here we just change the input sources and references
-        for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
+          // Here we just change the input sources and references
+          for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
           {
-          m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
+            m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
           }
-        m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
-        m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
+          m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
+          m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
 
-        // Update
-        AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
-        m_ValidateModelFilter->Update();
+          // Update
+          AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
+          m_ValidateModelFilter->Update();
 
-        for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
+          for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
           {
-          otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
-          PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
+            otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
+            PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
           }
         } // Step is OK to perform validation
+      } // Validation is enabled
 
     } // Next epoch
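Schematically, the control flow introduced above reduces to the loop below (a stubbed sketch with illustrative names, not the application code): validation only runs when the user set "validation.mode", and only every "validation.step" epochs.

    # Stubbed sketch of the epoch loop above.
    def train_one_epoch(epoch):
        print("Training epoch #%d" % epoch)

    def evaluate(data_name):
        print("Evaluate model (%s)" % data_name)

    n_epochs, validation_step, do_validation = 4, 2, True
    for epoch in range(1, n_epochs + 1):
        train_one_epoch(epoch)
        if do_validation and epoch % validation_step == 0:
            evaluate("Learning data")     # 1. metrics against the learning data
            evaluate("Validation data")   # 2. metrics against the validation data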
diff --git a/app/otbTrainClassifierFromDeepFeatures.cxx b/app/otbTrainClassifierFromDeepFeatures.cxx
index 343b290fd446b92aa75cac0061f1ca2462721348..5fa979b36ed6a16dec8dc6c0d81a34c8aaeb39a0 100644
--- a/app/otbTrainClassifierFromDeepFeatures.cxx
+++ b/app/otbTrainClassifierFromDeepFeatures.cxx
@@ -93,6 +93,7 @@ private:
     ShareParameter("optim", "tfmodel.optim", "Processing time optimization", "This group of parameters allows optimization of processing time");
 
     // Train shared parameters
+    ShareParameter("ram", "train.ram", "Available RAM (Mb)", "Available RAM (Mb)");
     ShareParameter("vd", "train.io.vd", "Vector data for training", "Input vector data for training");
     ShareParameter("valid", "train.io.valid", "Vector data for validation", "Input vector data for validation");
     ShareParameter("out", "train.io.out", "Output classification model", "Output classification model");
diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx
index 9aaeeda91856eb83e2532cf6eefc3fbc8a18843b..77902bbdf6c045fe8fa2a9f2136919dd781ff54c 100644
--- a/include/otbTensorflowMultisourceModelFilter.hxx
+++ b/include/otbTensorflowMultisourceModelFilter.hxx
@@ -368,9 +368,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
   RegionType outputAlignedReqRegion(outputReqRegion);
   EnlargeToAlignedRegion(outputAlignedReqRegion);
 
-  // Add a progress reporter
-  itk::ProgressReporter progress(this, 0, outputReqRegion.GetNumberOfPixels());
-
   const unsigned int nInputs = this->GetNumberOfInputs();
 
   // Create input tensors list
diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx
index 6e3a98988d53bf1f5f78d6876995278082eafea9..520c4f3db645a9823402485490930618b23384ee 100644
--- a/include/otbTensorflowMultisourceModelTrain.hxx
+++ b/include/otbTensorflowMultisourceModelTrain.hxx
@@ -55,6 +55,12 @@ TensorflowMultisourceModelTrain<TInputImage>
   TensorListType outputs;
   this->RunSession(inputs, outputs);
 
+  // Display the output tensors
+  for (auto& o: outputs)
+  {
+    tf::PrintTensorInfos(o);
+  }
+
 }
diff --git a/include/otbTensorflowStreamerFilter.h b/include/otbTensorflowStreamerFilter.h
new file mode 100644
index 0000000000000000000000000000000000000000..b58e92923b5b5b4ad78d5cc3d48c487c97594bfe
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.h
@@ -0,0 +1,79 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+  This software is distributed WITHOUT ANY WARRANTY; without even
+  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+  PURPOSE. See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_h
+#define otbTensorflowStreamerFilter_h
+
+// Image2image
+#include "itkImageToImageFilter.h"
+
+namespace otb
+{
+
+/**
+ * \class TensorflowStreamerFilter
+ * \brief This filter generates an output image with an internal
+ * explicit streaming mechanism.
+ *
+ * \ingroup OTBTensorflow
+ */
+template <class TInputImage, class TOutputImage>
+class ITK_EXPORT TensorflowStreamerFilter :
+public itk::ImageToImageFilter<TInputImage, TOutputImage>
+{
+
+public:
+
+  /** Standard class typedefs. */
+  typedef TensorflowStreamerFilter                           Self;
+  typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
+  typedef itk::SmartPointer<Self>                            Pointer;
+  typedef itk::SmartPointer<const Self>                      ConstPointer;
+
+  /** Method for creation through the object factory. */
+  itkNewMacro(Self);
+
+  /** Run-time type information (and related methods). */
+  itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter);
+
+  /** Images typedefs */
+  typedef typename Superclass::InputImageType       ImageType;
+  typedef typename ImageType::IndexType             IndexType;
+  typedef typename ImageType::IndexValueType        IndexValueType;
+  typedef typename ImageType::SizeType              SizeType;
+  typedef typename Superclass::InputImageRegionType RegionType;
+
+  typedef TOutputImage OutputImageType;
+
+  itkSetMacro(OutputGridSize, SizeType);
+  itkGetMacro(OutputGridSize, SizeType);
+
+protected:
+  TensorflowStreamerFilter();
+  virtual ~TensorflowStreamerFilter() {};
+
+  // Bypass the pipeline update mechanism: the whole output is
+  // generated explicitly, tile by tile, in GenerateData()
+  virtual void UpdateOutputData(itk::DataObject *output){(void) output; this->GenerateData();}
+
+  virtual void GenerateData();
+
+private:
+  TensorflowStreamerFilter(const Self&); //purposely not implemented
+  void operator=(const Self&);           //purposely not implemented
+
+  SizeType m_OutputGridSize; // Output grid size
+
+}; // end class
+
+
+} // end namespace otb
+
+#include "otbTensorflowStreamerFilter.hxx"
+
+#endif
diff --git a/include/otbTensorflowStreamerFilter.hxx b/include/otbTensorflowStreamerFilter.hxx
new file mode 100644
index 0000000000000000000000000000000000000000..54b8563f7ee49f87477697471458c314c92dd0e9
--- /dev/null
+++ b/include/otbTensorflowStreamerFilter.hxx
@@ -0,0 +1,107 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+  This software is distributed WITHOUT ANY WARRANTY; without even
+  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+  PURPOSE. See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowStreamerFilter_hxx
+#define otbTensorflowStreamerFilter_hxx
+
+#include "otbTensorflowStreamerFilter.h"
+#include "itkImageAlgorithm.h"
+
+namespace otb
+{
+
+template <class TInputImage, class TOutputImage>
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::TensorflowStreamerFilter()
+ {
+  m_OutputGridSize.Fill(1);
+ }
+
+/**
+ * Compute the output image
+ */
+template <class TInputImage, class TOutputImage>
+void
+TensorflowStreamerFilter<TInputImage, TOutputImage>
+::GenerateData()
+ {
+  // Output pointer and requested region
+  OutputImageType * outputPtr = this->GetOutput();
+  const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
+  outputPtr->SetBufferedRegion(outputReqRegion);
+  outputPtr->Allocate();
+
+  // Compute the aligned region
+  RegionType region;
+  for(unsigned int dim = 0; dim < ImageType::ImageDimension; dim++)
+    {
+    // Corners of the requested region
+    IndexValueType lower = outputReqRegion.GetIndex(dim);
+    IndexValueType upper = lower + outputReqRegion.GetSize(dim);
+
+    // Move the corners to the nearest grid-aligned positions
+    lower -= lower % m_OutputGridSize[dim];
+    const IndexValueType deltaUp = upper % m_OutputGridSize[dim];
+    if (deltaUp > 0)
+      {
+      upper += m_OutputGridSize[dim] - deltaUp;
+      }
+
+    // Update region
+    region.SetIndex(dim, lower);
+    region.SetSize(dim, upper - lower);
+
+    }
+
+  // Compute the number of subregions to process
+  const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0];
+  const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1];
+
+  // Progress
+  itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY);
+
+  // For each tile, propagate the input region and recopy the output
+  ImageType * inputImage = static_cast<ImageType *>( Superclass::ProcessObject::GetInput(0) );
+  unsigned int tx, ty;
+  RegionType subRegion;
+  subRegion.SetSize(m_OutputGridSize);
+  for (ty = 0; ty < nbTilesY; ty++)
+    {
+    subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1));
+    for (tx = 0; tx < nbTilesX; tx++)
+      {
+      // Update the input subregion
+      subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0));
+
+      // The actual region to copy
+      RegionType cpyRegion(subRegion);
+      cpyRegion.Crop(outputReqRegion);
+
+      // Propagate the region to the input and trigger its update
+      inputImage->SetRequestedRegion(cpyRegion);
+      inputImage->PropagateRequestedRegion();
+      inputImage->UpdateOutputData();
+
+      // Copy the subregion to output
+      itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion );
+
+      progress.CompletedPixel();
+      }
+    }
+ }
+
+
+} // end namespace otb
+
+
+#endif
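To make the tile walk in GenerateData() concrete, here is a minimal standalone sketch of the same logic (indices simplified to a zero origin; plain tuples stand in for ITK regions):

    # Sketch of GenerateData(): enlarge the requested region to the tile
    # grid, then visit each tile and crop it back to the requested region.
    def stream_tiles(requested_size, grid):
        aligned = [(s + g - 1) // g * g for s, g in zip(requested_size, grid)]
        nb_tiles_x, nb_tiles_y = aligned[0] // grid[0], aligned[1] // grid[1]
        for ty in range(nb_tiles_y):
            for tx in range(nb_tiles_x):
                x0, y0 = tx * grid[0], ty * grid[1]
                # Crop to the requested region (the cpyRegion.Crop() step)
                x1 = min(x0 + grid[0], requested_size[0])
                y1 = min(y0 + grid[1], requested_size[1])
                yield (x0, y0, x1, y1)

    for tile in stream_tiles((50, 30), (16, 16)):
        print(tile)  # 8 tiles; the right/bottom ones are cropped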
+# +#==========================================================================*/ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from tricks import * + +# Logging +tf.logging.set_verbosity(tf.logging.INFO) + +# Parser +parser = argparse.ArgumentParser() +parser.add_argument("--ckpt", help="checkpoint file prefix", required=True) +parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+') +parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+') +parser.add_argument("--model", help="output SavedModel", required=True) +params = parser.parse_args() + +if __name__ == "__main__": + + CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model) + + quit() diff --git a/python/create_model_ienco-m3_patchbased.py b/python/create_model_ienco-m3_patchbased.py index 0f1b8c4b31f470be12bbf2fea810f085927e05a5..be1bac453b8cfe6388e8ee7de8acf25f526be137 100644 --- a/python/create_model_ienco-m3_patchbased.py +++ b/python/create_model_ienco-m3_patchbased.py @@ -1,3 +1,21 @@ +# -*- coding: utf-8 -*- +#========================================================================== +# +# Copyright Remi Cresson, Dino Ienco (IRSTEA) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/create_model_ienco-m3_patchbased.py b/python/create_model_ienco-m3_patchbased.py
index 0f1b8c4b31f470be12bbf2fea810f085927e05a5..be1bac453b8cfe6388e8ee7de8acf25f526be137 100644
--- a/python/create_model_ienco-m3_patchbased.py
+++ b/python/create_model_ienco-m3_patchbased.py
@@ -1,3 +1,21 @@
+# -*- coding: utf-8 -*-
+#==========================================================================
+#
+# Copyright Remi Cresson, Dino Ienco (IRSTEA)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#==========================================================================
 import sys
 import os
 import numpy as np
@@ -13,35 +31,6 @@ from sklearn.ensemble import RandomForestClassifier
 from sklearn.utils import shuffle
 from sklearn.metrics import confusion_matrix
 from tricks import *
-
-def export_model(sess,
-                 export_dir,
-                 x_cnn_placeholder,
-                 x_rnn_placeholder,
-                 is_training_placeholder,
-                 testPrediction):
-  """ export a SavedModel
-  """
-
-  # Update the export dir
-  model_dir = export_dir + "/saved_model/"
-  if os.path.exists(model_dir):
-    shutil.rmtree(model_dir)
-
-  print("Export model in " + model_dir)
-
-  # Add a builder (for LoadSavedModel)
-  builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
-  signature_def_map= {
-    "model": tf.saved_model.signature_def_utils.predict_signature_def(
-      inputs = {"x_cnn" : x_cnn_placeholder,
-                "x_rnn" : x_rnn_placeholder,
-                "is_training" : is_training_placeholder},
-      outputs = {"prediction" : testPrediction})
-  }
-  builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
-  builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
-  builder.save()
 
 def checkTest(ts_data, vhsr_data, batchsz, label_test):
     tot_pred = []
@@ -60,11 +49,7 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
                                                     is_training_ph:True,
                                                     dropout:0.0,
                                                     x_cnn:batch_cnn_x})
-
-        del batch_rnn_x
-        del batch_cnn_x
-        del batch_y
-
+
         for el in pred_temp:
             tot_pred.append( el )
 
@@ -241,8 +226,8 @@ n_channels = 4
 nclasses = 8
 
 # check number of arguments
-if len(sys.argv) != 7:
-  print("Usage : <ts_train> <vhsr_train> <label_train> <ts_test> <vhsr_test> <label_test>")
+if len(sys.argv) != 8:
+  print("Usage : <ts_train> <vhsr_train> <label_train> <ts_test> <vhsr_test> <label_test> <export_dir>")
   sys.exit(1)
 
 ts_train = read_samples(sys.argv[1])
@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
 label_test = np.int32(label_test)
 print_histo(label_test, "label_test")
 
+export_dir = sys.argv[7]
+
 x_rnn = tf.placeholder(tf.float32,[None, 1, 1, n_dims*n_timestamps],name="x_rnn")
 x_cnn = tf.placeholder(tf.float32,[None, patch_window, patch_window, n_channels],name="x_cnn")
 y = tf.placeholder(tf.int32,[None, 1, 1, 1],name="y")
@@ -324,19 +311,12 @@ for e in range(hm_epochs):
             lossi+=loss
             accS+=acc
 
-            del batch_rnn_x
-            del batch_cnn_x
-            del batch_y
-
-
         print "Epoch:",e,"Train loss:",lossi/iterations,"| accuracy:",accS/iterations
 
         c_loss = lossi/iterations
         if c_loss < best_loss:
-            save_path = saver.save(sess, "models/model")
-            print("Model saved in path: %s" % save_path)
             best_loss = c_loss
-            export_model(sess, "/tmp/m3_export", x_cnn, x_rnn, is_training_ph, testPrediction)
+            CreateSavedModel(sess, ["x_cnn:0","x_rnn:0","is_training:0"], ["prediction:0"], export_dir)
 
 test_acc = checkTest(ts_test, vhsr_test, 1024, label_test)
\ No newline at end of file
diff --git a/python/create_model_maggiori17_fullyconv.py b/python/create_model_maggiori17_fullyconv.py
index efd1dd51fc0f5ff006383b71400d2faf5eee1e64..b4ba770dce41d7b49b57a5404042dd92c4704330 100644
--- a/python/create_model_maggiori17_fullyconv.py
+++ b/python/create_model_maggiori17_fullyconv.py
@@ -1,3 +1,21 @@
+# -*- coding: utf-8 -*-
+#==========================================================================
+#
+# Copyright Remi Cresson (IRSTEA)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#==========================================================================
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -89,12 +107,11 @@ def main(unused_argv):
 
   # check number of arguments
   if len(sys.argv) != 4:
-    print("Usage : <patches> <labels> <model_dir>")
+    print("Usage : <patches> <labels> <export_dir>")
     sys.exit(1)
 
   # Export dir
-  log_dir = sys.argv[3] + '/model_checkpoints/'
-  export_dir = sys.argv[3] + '/model_export/'
+  export_dir = sys.argv[3]
 
   print("loading dataset")
 
@@ -272,13 +289,9 @@ def main(unused_argv):
         if step % 10 == 0:
           # Print status to stdout.
           print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
-          #print('Step %d: (%.3f sec)' % (step, duration))
 
-        # Save a checkpoint and evaluate the model periodically.
+        # Evaluate the model periodically.
         if (curr_epoch + 1) % 1 == 0:
-          checkpoint_file = os.path.join(log_dir, 'model.ckpt')
-          saver.save(sess, checkpoint_file, global_step=step)
           # Evaluate against the training set.
           print('Training Data Eval:')
           do_eval2(sess,
@@ -301,16 +314,7 @@ def main(unused_argv):
                    batch_size)
 
       # Let's export a SavedModel
-      shutil.rmtree(export_dir)
-      builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
-      signature_def_map= {
-        "model": tf.saved_model.signature_def_utils.predict_signature_def(
-          inputs= {"x1": xs_placeholder},
-          outputs= {"prediction": testPrediction})
-      }
-      builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map)
-      builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
-      builder.save()
+      CreateSavedModel(sess, ["x1:0"], ["prediction:0"], export_dir)
 
   quit()
diff --git a/python/tricks.py b/python/tricks.py
index 911cac20e2df25229194aaf9ae8e8a889644bfdd..0b1ae614c7947fec9f99b08f78d8674c92ee35c9 100644
--- a/python/tricks.py
+++ b/python/tricks.py
@@ -1,9 +1,29 @@
+# -*- coding: utf-8 -*-
+#==========================================================================
+#
+# Copyright Remi Cresson (IRSTEA)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#==========================================================================
 import sys
 import os
 import numpy as np
 import math
 import time
 import otbApplication
+import tensorflow as tf
+import shutil
 
 def flatten_nparray(np_arr):
   """ Returns a 1D numpy array resulting from flattening the input
@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor):
 
   print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
 
-def read_samples(fn):
+def read_samples(fn, single=False):
   """ Read an image of patches and return a 4D numpy array
+  TODO: Add an optional argument for the y-patchsize
   Args:
     fn: file name
+    single: a boolean telling whether the file contains a single image.
+      In this case, the image can be rectangular (not square)
   """
 
   # Get input image size
@@ -63,12 +86,15 @@ def read_samples(fn):
   print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
 
   # reshape
+  if (single):
+    return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
+
   n_samples = int(size_y / size_x)
   outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
 
   print("Returned numpy array shape: "+str(outimg.shape))
   return np.copy(outimg)
-
+
 def getBatch(X, Y, i, batch_size):
   start_id = i*batch_size
   end_id = min( (i+1) * batch_size, X.shape[0])
@@ -76,3 +102,55 @@ def getBatch(X, Y, i, batch_size):
   batch_y = Y[start_id:end_id]
 
   return batch_x, batch_y
+
+def CreateSavedModel(sess, inputs, outputs, directory):
+  """
+  Create a SavedModel
+
+  Args:
+    sess: the session
+    inputs: the list of input names
+    outputs: the list of output names
+    directory: the output path for the SavedModel
+  """
+
+  print("Create a SavedModel in " + directory)
+
+  # Get the graph
+  graph = tf.get_default_graph()
+
+  # Get the input and output tensors
+  input_dict  = { i : graph.get_tensor_by_name(i) for i in inputs }
+  output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
+
+  # Build the SavedModel
+  builder = tf.saved_model.builder.SavedModelBuilder(directory)
+  signature_def_map = {
+    "model": tf.saved_model.signature_def_utils.predict_signature_def(
+      input_dict,
+      output_dict)
+  }
+  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map)
+  builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
+  builder.save()
+
+def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path):
+  """
+  Read a Checkpoint and build a SavedModel
+
+  Args:
+    ckpt_path: path to the checkpoint file (without the ".meta" extension)
+    inputs: input list of placeholder names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
+    outputs: output list of tensor output names (e.g. ["prediction:0", "features:0"])
+    savedmodel_path: path to the SavedModel
+  """
+  tf.reset_default_graph()
+  with tf.Session() as sess:
+
+    # Restore variables from disk
+    model_saver = tf.train.import_meta_graph(ckpt_path+".meta")
+    model_saver.restore(sess, ckpt_path)
+
+    # Create a SavedModel
+    CreateSavedModel(sess, inputs, outputs, savedmodel_path)
\ No newline at end of file
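To illustrate the new helper, here is a hypothetical end-to-end use of CreateSavedModel on a trivial graph (the tensor names and output directory are made up for the example; the target directory must not already exist, since SavedModelBuilder refuses to overwrite):

    # Hypothetical example: export a trivial TF1 graph with CreateSavedModel.
    import tensorflow as tf
    from tricks import CreateSavedModel

    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [None, 4], name="x")
    y = tf.identity(2.0 * x, name="y")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        CreateSavedModel(sess, ["x:0"], ["y:0"], "/tmp/trivial_savedmodel")

The exported directory can then be loaded by the TensorflowModelServe application.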