// otbTensorflowModelTrain.cxx
/*=========================================================================
     Copyright (c) 2018-2019 IRSTEA
     Copyright (c) 2020-2021 INRAE
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Tensorflow SavedModel
#include "tensorflow/cc/saved_model/loader.h"
// Tensorflow model train
#include "otbTensorflowMultisourceModelTrain.h"
#include "otbTensorflowMultisourceModelValidate.h"
// Tensorflow graph load
#include "otbTensorflowGraphOperations.h"
// Layerstack
#include "otbTensorflowSource.h"
// Metrics
#include "otbConfusionMatrixMeasurements.h"
namespace otb
namespace Wrapper
class TensorflowModelTrain : public Application
public:
  /** Standard class typedefs. */
  typedef TensorflowModelTrain                       Self;
  typedef Application                                Superclass;
  typedef itk::SmartPointer<Self>                    Pointer;
  typedef itk::SmartPointer<const Self>              ConstPointer;
  /** Standard macro */
  itkNewMacro(Self);
  itkTypeMacro(TensorflowModelTrain, Application);
  /** Typedefs for TensorFlow */
  typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType>    TrainModelFilterType;
  typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType;
  typedef otb::TensorflowSource<FloatVectorImageType>                   TFSource;
  /* Typedefs for evaluation metrics */
  typedef ValidateModelFilterType::ConfMatType                          ConfMatType;
  typedef ValidateModelFilterType::MapOfClassesType                     MapOfClassesType;
  typedef ValidateModelFilterType::LabelValueType                       LabelValueType;
  typedef otb::ConfusionMatrixMeasurements<ConfMatType, LabelValueType> ConfusionMatrixCalculatorType;
  // Store stuff related to one source
  struct ProcessObjectsBundle
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
{ TFSource tfSource; TFSource tfSourceForValidation; // Parameters keys std::string m_KeyInForTrain; // Key of input image list (training) std::string m_KeyInForValid; // Key of input image list (validation) std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training) std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation) std::string m_KeyPszX; // Key for samples sizes X std::string m_KeyPszY; // Key for samples sizes Y }; /** Typedefs for the app */ typedef std::vector<ProcessObjectsBundle> BundleList; typedef std::vector<FloatVectorImageType::SizeType> SizeList; typedef std::vector<std::string> StringList; void DoUpdateParameters() { } // // Add an input source, which includes: // -an input image list (for training) // -an input image placeholder (for training) // -an input image list (for validation) // -an input image placeholder (for validation) // -an input patchsize, which is the dimensions of samples. Same for training and validation. 
// void AddAnInputImage() { // Number of source unsigned int inputNumber = m_Bundles.size() + 1; // Create keys and descriptions std::stringstream ss_key_tr_group, ss_desc_tr_group, ss_key_val_group, ss_desc_val_group, ss_key_tr_in, ss_desc_tr_in, ss_key_val_in, ss_desc_val_in, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y, ss_key_tr_ph, ss_desc_tr_ph, ss_key_val_ph, ss_desc_val_ph; // Parameter group key/description ss_key_tr_group << "training.source" << inputNumber; ss_key_val_group << "validation.source" << inputNumber; ss_desc_tr_group << "Parameters for source #" << inputNumber << " (training)"; ss_desc_val_group << "Parameters for source #" << inputNumber << " (validation)"; // Parameter group keys ss_key_tr_in << ss_key_tr_group.str() << ".il"; ss_key_val_in << ss_key_val_group.str() << ".il"; ss_key_dims_x << ss_key_tr_group.str() << ".patchsizex"; ss_key_dims_y << ss_key_tr_group.str() << ".patchsizey"; ss_key_tr_ph << ss_key_tr_group.str() << ".placeholder"; ss_key_val_ph << ss_key_val_group.str() << ".name"; // Parameter group descriptions ss_desc_tr_in << "Input image (or list to stack) for source #" << inputNumber << " (training)"; ss_desc_val_in << "Input image (or list to stack) for source #" << inputNumber << " (validation)"; ss_desc_dims_x << "Patch size (x) for source #" << inputNumber; ss_desc_dims_y << "Patch size (y) for source #" << inputNumber; ss_desc_tr_ph << "Name of the input placeholder for source #" << inputNumber << " (training)"; ss_desc_val_ph << "Name of the input placeholder " "or output tensor for source #" << inputNumber << " (validation)"; // Populate group AddParameter(ParameterType_Group, ss_key_tr_group.str(), ss_desc_tr_group.str());
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
AddParameter(ParameterType_InputImageList, ss_key_tr_in.str(), ss_desc_tr_in.str() ); AddParameter(ParameterType_Int, ss_key_dims_x.str(), ss_desc_dims_x.str()); SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); AddParameter(ParameterType_String, ss_key_tr_ph.str(), ss_desc_tr_ph.str()); AddParameter(ParameterType_Group, ss_key_val_group.str(), ss_desc_val_group.str()); AddParameter(ParameterType_InputImageList, ss_key_val_in.str(), ss_desc_val_in.str() ); AddParameter(ParameterType_String, ss_key_val_ph.str(), ss_desc_val_ph.str()); // Add a new bundle ProcessObjectsBundle bundle; bundle.m_KeyInForTrain = ss_key_tr_in.str(); bundle.m_KeyInForValid = ss_key_val_in.str(); bundle.m_KeyPHNameForTrain = ss_key_tr_ph.str(); bundle.m_KeyPHNameForValid = ss_key_val_ph.str(); bundle.m_KeyPszX = ss_key_dims_x.str(); bundle.m_KeyPszY = ss_key_dims_y.str(); m_Bundles.push_back(bundle); } void DoInit() { // Documentation SetName("TensorflowModelTrain"); SetDescription("Train a multisource deep learning net using Tensorflow. Change " "the " + tf::ENV_VAR_NAME_NSOURCES + " environment variable to set the number of " "sources."); SetDocLongDescription("The application trains a Tensorflow model over multiple data sources. " "The number of input sources can be changed at runtime by setting the " "system environment variable " + tf::ENV_VAR_NAME_NSOURCES + ". " "For each source, you have to set (1) the tensor placeholder name, as named in " "the tensorflow model, (2) the patch size and (3) the image(s) source. 
"); SetDocAuthors("Remi Cresson"); AddDocTag(Tags::Learning); // Input model AddParameter(ParameterType_Group, "model", "Model parameters"); AddParameter(ParameterType_Directory, "model.dir", "Tensorflow model_save directory"); MandatoryOn ("model.dir"); AddParameter(ParameterType_String, "model.restorefrom", "Restore model from path"); MandatoryOff ("model.restorefrom"); AddParameter(ParameterType_String, "model.saveto", "Save model to path"); MandatoryOff ("model.saveto"); AddParameter(ParameterType_StringList, "model.tagsets", "Which tags (i.e. v1.MetaGraphDefs) to load from the saved model. Currently, only one tag is supported. Can be retrieved by running `saved_model_cli show --dir your_model_dir --all`"); MandatoryOff ("model.tagsets"); // Training parameters group AddParameter(ParameterType_Group, "training", "Training parameters"); AddParameter(ParameterType_Int, "training.batchsize", "Batch size"); SetMinimumParameterIntValue ("training.batchsize", 1); SetDefaultParameterInt ("training.batchsize", 100); AddParameter(ParameterType_Int, "training.epochs", "Number of epochs"); SetMinimumParameterIntValue ("training.epochs", 1); SetDefaultParameterInt ("training.epochs", 100); AddParameter(ParameterType_StringList, "training.userplaceholders", "Additional single-valued placeholders for training. Supported types: int, float, bool."); MandatoryOff ("training.userplaceholders"); AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes"); MandatoryOn ("training.targetnodes"); AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display"); MandatoryOff ("training.outputtensors"); AddParameter(ParameterType_Bool, "training.usestreaming", "Use the streaming through patches (slower but can process big dataset)"); MandatoryOff ("training.usestreaming"); // Metrics AddParameter(ParameterType_Group, "validation", "Validation parameters");
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
MandatoryOff ("validation"); AddParameter(ParameterType_Int, "validation.step", "Perform the validation every Nth epochs"); SetMinimumParameterIntValue ("validation.step", 1); SetDefaultParameterInt ("validation.step", 10); AddParameter(ParameterType_Choice, "validation.mode", "Metrics to compute"); AddChoice ("validation.mode.none", "No validation step"); AddChoice ("validation.mode.class", "Classification metrics"); AddChoice ("validation.mode.rmse", "Root mean square error"); AddParameter(ParameterType_StringList, "validation.userplaceholders", "Additional single-valued placeholders for validation. Supported types: int, float, bool."); MandatoryOff ("validation.userplaceholders"); AddParameter(ParameterType_Bool, "validation.usestreaming", "Use the streaming through patches (slower but can process big dataset)"); MandatoryOff ("validation.usestreaming"); // Input/output images AddAnInputImage(); for (int i = 1; i < tf::GetNumberOfSources() + 1 ; i++) // +1 because we have at least 1 source more for training { AddAnInputImage(); } // Example SetDocExampleParameterValue("source1.il", "spot6pms.tif"); SetDocExampleParameterValue("source1.placeholder", "x1"); SetDocExampleParameterValue("source1.patchsizex", "16"); SetDocExampleParameterValue("source1.patchsizey", "16"); SetDocExampleParameterValue("source2.il", "labels.tif"); SetDocExampleParameterValue("source2.placeholder", "y1"); SetDocExampleParameterValue("source2.patchsizex", "1"); SetDocExampleParameterValue("source2.patchsizex", "1"); SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/"); SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2"); SetDocExampleParameterValue("training.targetnodes", "optimizer"); SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model/variables/variables"); } // // Prepare bundles // Here, we populate the two following groups: // 1.Training : // -Placeholders // -PatchSize // -ImageSource // 2.Learning/Validation // 
-Placeholders (if input) or Tensor name (if target) // -PatchSize (which is the same as for training) // -ImageSource (depending if it's for learning or validation) // // TODO: a bit of refactoring. We could simply rely on m_Bundles // if we can keep trace of indices of sources for // training / test / validation // void PrepareInputs() { // Clear placeholder names m_InputPlaceholdersForTraining.clear(); m_InputPlaceholdersForValidation.clear(); // Clear patches sizes m_InputPatchesSizeForTraining.clear(); m_InputPatchesSizeForValidation.clear(); m_TargetPatchesSize.clear(); // Clear bundles m_InputSourcesForTraining.clear(); m_InputSourcesForEvaluationAgainstLearningData.clear(); m_InputSourcesForEvaluationAgainstValidationData.clear(); m_TargetTensorsNames.clear();
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
m_InputTargetsForEvaluationAgainstValidationData.clear(); m_InputTargetsForEvaluationAgainstLearningData.clear(); // Prepare the bundles for (auto& bundle: m_Bundles) { // Source FloatVectorImageListType::Pointer trainStack = GetParameterImageList(bundle.m_KeyInForTrain); bundle.tfSource.Set(trainStack); m_InputSourcesForTraining.push_back(bundle.tfSource.Get()); // Placeholder std::string placeholderForTraining = GetParameterAsString(bundle.m_KeyPHNameForTrain); m_InputPlaceholdersForTraining.push_back(placeholderForTraining); // Patch size FloatVectorImageType::SizeType patchSize; patchSize[0] = GetParameterInt(bundle.m_KeyPszX); patchSize[1] = GetParameterInt(bundle.m_KeyPszY); m_InputPatchesSizeForTraining.push_back(patchSize); otbAppLogINFO("New source:"); otbAppLogINFO("Patch size : "<< patchSize); otbAppLogINFO("Placeholder (training) : "<< placeholderForTraining); // Prepare validation sources if (GetParameterInt("validation.mode") != 0) { // Get the stack if (!HasValue(bundle.m_KeyInForValid)) { otbAppLogFATAL("No validation input is set for this source"); } FloatVectorImageListType::Pointer validStack = GetParameterImageList(bundle.m_KeyInForValid); bundle.tfSourceForValidation.Set(validStack); // We check if the placeholder is the same for training and for validation // If yes, it means that its not an output tensor on which perform the validation std::string placeholderForValidation = GetParameterAsString(bundle.m_KeyPHNameForValid); if (placeholderForValidation.empty()) { placeholderForValidation = placeholderForTraining; } // Same placeholder name ==> is a source for validation if (placeholderForValidation.compare(placeholderForTraining) == 0) { // Source m_InputSourcesForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get()); m_InputSourcesForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get()); // Placeholder m_InputPlaceholdersForValidation.push_back(placeholderForValidation); // Patch size 
m_InputPatchesSizeForValidation.push_back(patchSize); otbAppLogINFO("Placeholder (validation) : "<< placeholderForValidation); } // Different placeholder ==> is a target to validate else { // Source m_InputTargetsForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get()); m_InputTargetsForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get()); // Placeholder m_TargetTensorsNames.push_back(placeholderForValidation);
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
// Patch size m_TargetPatchesSize.push_back(patchSize); otbAppLogINFO("Tensor name (validation) : "<< placeholderForValidation); } } } } // // Get user placeholders // TrainModelFilterType::DictType GetUserPlaceholders(const std::string & key) { TrainModelFilterType::DictType dict; TrainModelFilterType::StringList expressions = GetParameterStringList(key); for (auto& exp: expressions) { TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp); dict.push_back(entry); otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second)); } return dict; } // // Print some classification metrics // void PrintClassificationMetrics(const ConfMatType & confMat, const MapOfClassesType & mapOfClassesRef) { ConfusionMatrixCalculatorType::Pointer confMatMeasurements = ConfusionMatrixCalculatorType::New(); confMatMeasurements->SetConfusionMatrix(confMat); confMatMeasurements->SetMapOfClasses(mapOfClassesRef); confMatMeasurements->Compute(); for (auto const& itMapOfClassesRef : mapOfClassesRef) { LabelValueType labelRef = itMapOfClassesRef.first; LabelValueType indexLabelRef = itMapOfClassesRef.second; otbAppLogINFO("Precision of class [" << labelRef << "] vs all: " << confMatMeasurements->GetPrecisions()[indexLabelRef]); otbAppLogINFO("Recall of class [" << labelRef << "] vs all: " << confMatMeasurements->GetRecalls()[indexLabelRef]); otbAppLogINFO("F-score of class [" << labelRef << "] vs all: " << confMatMeasurements->GetFScores()[indexLabelRef]); otbAppLogINFO("\t"); } otbAppLogINFO("Precision of the different classes: " << confMatMeasurements->GetPrecisions()); otbAppLogINFO("Recall of the different classes: " << confMatMeasurements->GetRecalls()); otbAppLogINFO("F-score of the different classes: " << confMatMeasurements->GetFScores()); otbAppLogINFO("\t"); otbAppLogINFO("Kappa index: " << confMatMeasurements->GetKappaIndex()); otbAppLogINFO("Overall accuracy index: " << confMatMeasurements->GetOverallAccuracy()); 
otbAppLogINFO("Confusion matrix:\n" << confMat); } void DoExecute() { // Load the Tensorflow bundle tf::LoadModel(GetParameterAsString("model.dir"), m_SavedModel, GetParameterStringList("model.tagsets")); // Check if we have to restore variables from somewhere else if (HasValue("model.restorefrom")) { const std::string path = GetParameterAsString("model.restorefrom"); otbAppLogINFO("Restoring model from " + path); // Load SavedModel variables
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
tf::RestoreModel(path, m_SavedModel); } // Prepare inputs PrepareInputs(); // Setup training filter m_TrainModelFilter = TrainModelFilterType::New(); m_TrainModelFilter->SetSavedModel(&m_SavedModel); m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors")); m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes")); m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders")); m_TrainModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming")); // Set inputs for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++) { m_TrainModelFilter->PushBackInputTensorBundle( m_InputPlaceholdersForTraining[i], m_InputPatchesSizeForTraining[i], m_InputSourcesForTraining[i]); } // Setup the validation filter const bool do_validation = HasUserValue("validation.mode"); if (GetParameterInt("validation.mode")==1) // class { otbAppLogINFO("Set validation mode to classification validation"); m_ValidateModelFilter = ValidateModelFilterType::New(); m_ValidateModelFilter->SetSavedModel(&m_SavedModel); m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); m_ValidateModelFilter->SetInputPlaceholders(m_InputPlaceholdersForValidation); m_ValidateModelFilter->SetInputReceptiveFields(m_InputPatchesSizeForValidation); m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames); m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize); } else if (GetParameterInt("validation.mode")==2) // rmse) { otbAppLogINFO("Set validation mode to classification RMSE evaluation"); otbAppLogFATAL("Not implemented yet !"); // XD // TODO } // Epoch for (int epoch = 1 ; epoch <= GetParameterInt("training.epochs") ; epoch++) { // Train the model AddProcess(m_TrainModelFilter, "Training epoch #" + 
std::to_string(epoch)); m_TrainModelFilter->Update(); if (do_validation) { // Validate the model if (epoch % GetParameterInt("validation.step") == 0) { // 1. Evaluate the metrics against the learning data for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) { m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]); } m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData); // As we use the learning data here, it's rational to use the same option as streaming during training m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
// Update AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); m_ValidateModelFilter->Update(); for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); } // 2. Evaluate the metrics against the validation data // Here we just change the input sources and references for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++) { m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]); } m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData); m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming")); // Update AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)"); m_ValidateModelFilter->Update(); for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); } } // Step is OK to perform validation } // Do the validation against the validation data } // Next epoch // Check if we have to save variables to somewhere if (HasValue("model.saveto")) { const std::string path = GetParameterAsString("model.saveto"); otbAppLogINFO("Saving model to " + path); tf::SaveModel(path, m_SavedModel); } } private: tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! 
// Filters TrainModelFilterType::Pointer m_TrainModelFilter; ValidateModelFilterType::Pointer m_ValidateModelFilter; // Inputs BundleList m_Bundles; // Patches size SizeList m_InputPatchesSizeForTraining; SizeList m_InputPatchesSizeForValidation; SizeList m_TargetPatchesSize; // Placeholders and Tensors names StringList m_InputPlaceholdersForTraining; StringList m_InputPlaceholdersForValidation; StringList m_TargetTensorsNames; // Image sources std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTraining; std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstLearningData; std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstValidationData; std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstLearningData;
561562563564565566567568569
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstValidationData; }; // end of class } // namespace wrapper } // namespace otb OTB_APPLICATION_EXPORT( otb::Wrapper::TensorflowModelTrain )