diff --git a/app/otbTensorflowModelTrain.cxx b/app/otbTensorflowModelTrain.cxx index 2fe0ca2d887ab204df7c4c72b8e4a389721a8d12..8cdc2c4cc6852ef19a2e6dd1583ddc6167b46bf2 100644 --- a/app/otbTensorflowModelTrain.cxx +++ b/app/otbTensorflowModelTrain.cxx @@ -419,7 +419,7 @@ public: m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders")); - // Set input bundles + // Set inputs for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++) { m_TrainModelFilter->PushBackInputBundle( @@ -454,7 +454,8 @@ public: m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); - // Evaluate the metrics against the learning data + // 1. Evaluate the metrics against the learning data + for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++) { m_ValidateModelFilter->PushBackInputBundle( @@ -465,26 +466,30 @@ public: m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames); m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData); m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize); + + // Update AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)"); m_ValidateModelFilter->Update(); - // Print metrics for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); } - // Evaluate the metrics against the validation data + // 2. Evaluate the metrics against the validation data + + // Here we just change the input sources and references for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++) { m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]); } m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData); + + // Update AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)"); m_ValidateModelFilter->Update(); - // Print metrics for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) { otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index 74a7d3ed17af2488a070cb9b58d570eb72d38081..e0c23846fd45110efe30704acde79c35060e5310 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -36,6 +36,9 @@ namespace otb * Names of input placeholders must be specified using the * SetInputPlaceholdersNames method * + * TODO: + * Replace FOV (Field Of View) --> RF (Receptive Field) + * Replace FEO (Field Of Expr) --> EF (Expression Field) * * \ingroup OTBTensorflow */ diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h index 51fb894c5e4d9fc54f22e6c32224213ca6b9f69a..8308f2ed5bbbb64b949163f33bea3a3b5eb1fa80 100644 --- a/include/otbTensorflowMultisourceModelTrain.h +++ b/include/otbTensorflowMultisourceModelTrain.h @@ -34,6 +34,7 @@ namespace otb * Names of input placeholders must be specified using the * SetInputPlaceholdersNames method * + * TODO: Add an option to disable streaming * * \ingroup OTBTensorflow */ diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx index 92e6c68fe59a17952dbd5a48ba10056c89da2f14..aa47237c30ba886c5bc7b9effa5028d9cdefde91 100644 --- a/include/otbTensorflowMultisourceModelTrain.hxx +++ b/include/otbTensorflowMultisourceModelTrain.hxx @@ -21,6 +21,7 @@ TensorflowMultisourceModelTrain<TInputImage> ::TensorflowMultisourceModelTrain() { m_BatchSize = 100; + m_NumberOfSamples = 0; } @@ -41,7 +42,6 @@ TensorflowMultisourceModelTrain<TInputImage> // Check the number of samples ////////////////////////////////////////////////////////////////////////////////////////// - m_NumberOfSamples = 0; for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) { diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h index 84acb5b946c7e91b387430dde114b0ee5e12876f..868317f9344ba1567deaf71aecd7b18a2ccd635f 100644 --- a/include/otbTensorflowMultisourceModelValidate.h +++ b/include/otbTensorflowMultisourceModelValidate.h @@ -36,6 +36,7 @@ namespace otb * Names of input placeholders must be specified using the * SetInputPlaceholdersNames method * + * TODO: Add an option to disable streaming * * \ingroup OTBTensorflow */ @@ -83,20 +84,27 @@ public: typedef std::vector<ConfMatType> ConfMatListType; typedef itk::ImageRegionConstIterator<ImageType> IteratorType; + /* Set and Get the batch size itkSetMacro(BatchSize, unsigned int); itkGetMacro(BatchSize, unsigned int); + + /** Get the number of samples */ itkGetMacro(NumberOfSamples, unsigned int); virtual void GenerateOutputInformation(void); virtual void GenerateInputRequestedRegion(); + /** Set and Get the input references */ virtual void SetInputReferences(ImageListType input); ImagePointerType GetInputReference(unsigned int index); virtual void GenerateData(); + /** Get the confusion matrix */ const ConfMatType GetConfusionMatrix(unsigned int target); + + /** Get the map of classes matrix */ const MapOfClassesType GetMapOfClasses(unsigned int target); protected: diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx index cad1f3d57396d3bbdc22c2f65b5ff34774daa60a..b0ec5e2dd264aa5369cf6f7e4d364a7cbc9fbbdb 100644 --- a/include/otbTensorflowMultisourceModelValidate.hxx +++ b/include/otbTensorflowMultisourceModelValidate.hxx @@ -21,6 +21,7 @@ TensorflowMultisourceModelValidate<TInputImage> ::TensorflowMultisourceModelValidate() { m_BatchSize = 100; + m_NumberOfSamples = 0; } @@ -141,6 +142,9 @@ TensorflowMultisourceModelValidate<TInputImage> } +/* + * Set the references images + */ template<class TInputImage> void TensorflowMultisourceModelValidate<TInputImage> @@ -149,6 +153,10 @@ TensorflowMultisourceModelValidate<TInputImage> m_References = input; } +/* + * Retrieve the i-th reference image + * An exception is thrown if it doesn't exist. + */ template<class TInputImage> typename TensorflowMultisourceModelValidate<TInputImage>::ImagePointerType TensorflowMultisourceModelValidate<TInputImage> @@ -164,6 +172,9 @@ TensorflowMultisourceModelValidate<TInputImage> /** * Perform the validation + * The session is ran over the entire set of batches. + * Output is then validated agains the references images, + * and a confusion matrix is built. */ template <class TInputImage> void @@ -346,6 +357,10 @@ TensorflowMultisourceModelValidate<TInputImage> } +/* + * Get the confusion matrix + * If the target is not in the map, an exception is thrown. + */ template <class TInputImage> const typename TensorflowMultisourceModelValidate<TInputImage>::ConfMatType TensorflowMultisourceModelValidate<TInputImage> @@ -360,6 +375,10 @@ TensorflowMultisourceModelValidate<TInputImage> return m_ConfusionMatrices[target]; } +/* + * Get the map of classes + * If the target is not in the map, an exception is thrown. + */ template <class TInputImage> const typename TensorflowMultisourceModelValidate<TInputImage>::MapOfClassesType TensorflowMultisourceModelValidate<TInputImage>