Commit 20bb0192 authored by remi cresson's avatar remi cresson
Browse files

DOC: add plenty of comments

parent d5ffaaab
......@@ -419,7 +419,7 @@ public:
m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
// Set input bundles
// Set inputs
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{
m_TrainModelFilter->PushBackInputBundle(
......@@ -454,7 +454,8 @@ public:
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// Evaluate the metrics against the learning data
// 1. Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputBundle(
......@@ -465,26 +466,30 @@ public:
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize);
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update();
// Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
// Evaluate the metrics against the validation data
// 2. Evaluate the metrics against the validation data
// Here we just change the input sources and references
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
{
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
}
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
m_ValidateModelFilter->Update();
// Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
......
......@@ -36,6 +36,9 @@ namespace otb
* Names of input placeholders must be specified using the
* SetInputPlaceholdersNames method
*
* TODO:
* Replace FOV (Field Of View) --> RF (Receptive Field)
* Replace FEO (Field Of Expr) --> EF (Expression Field)
*
* \ingroup OTBTensorflow
*/
......
......@@ -34,6 +34,7 @@ namespace otb
* Names of input placeholders must be specified using the
* SetInputPlaceholdersNames method
*
* TODO: Add an option to disable streaming
*
* \ingroup OTBTensorflow
*/
......
......@@ -21,6 +21,7 @@ TensorflowMultisourceModelTrain<TInputImage>
::TensorflowMultisourceModelTrain()
{
m_BatchSize = 100;
m_NumberOfSamples = 0;
}
......@@ -41,7 +42,6 @@ TensorflowMultisourceModelTrain<TInputImage>
// Check the number of samples
//////////////////////////////////////////////////////////////////////////////////////////
m_NumberOfSamples = 0;
for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
{
......
......@@ -36,6 +36,7 @@ namespace otb
* Names of input placeholders must be specified using the
* SetInputPlaceholdersNames method
*
* TODO: Add an option to disable streaming
*
* \ingroup OTBTensorflow
*/
......@@ -83,20 +84,27 @@ public:
typedef std::vector<ConfMatType> ConfMatListType;
typedef itk::ImageRegionConstIterator<ImageType> IteratorType;
/* Set and Get the batch size
itkSetMacro(BatchSize, unsigned int);
itkGetMacro(BatchSize, unsigned int);
/** Get the number of samples */
itkGetMacro(NumberOfSamples, unsigned int);
virtual void GenerateOutputInformation(void);
virtual void GenerateInputRequestedRegion();
/** Set and Get the input references */
virtual void SetInputReferences(ImageListType input);
ImagePointerType GetInputReference(unsigned int index);
virtual void GenerateData();
/** Get the confusion matrix */
const ConfMatType GetConfusionMatrix(unsigned int target);
/** Get the map of classes matrix */
const MapOfClassesType GetMapOfClasses(unsigned int target);
protected:
......
......@@ -21,6 +21,7 @@ TensorflowMultisourceModelValidate<TInputImage>
::TensorflowMultisourceModelValidate()
{
m_BatchSize = 100;
m_NumberOfSamples = 0;
}
......@@ -141,6 +142,9 @@ TensorflowMultisourceModelValidate<TInputImage>
}
/*
* Set the references images
*/
template<class TInputImage>
void
TensorflowMultisourceModelValidate<TInputImage>
......@@ -149,6 +153,10 @@ TensorflowMultisourceModelValidate<TInputImage>
m_References = input;
}
/*
* Retrieve the i-th reference image
* An exception is thrown if it doesn't exist.
*/
template<class TInputImage>
typename TensorflowMultisourceModelValidate<TInputImage>::ImagePointerType
TensorflowMultisourceModelValidate<TInputImage>
......@@ -164,6 +172,9 @@ TensorflowMultisourceModelValidate<TInputImage>
/**
* Perform the validation
* The session is ran over the entire set of batches.
* Output is then validated agains the references images,
* and a confusion matrix is built.
*/
template <class TInputImage>
void
......@@ -346,6 +357,10 @@ TensorflowMultisourceModelValidate<TInputImage>
}
/*
* Get the confusion matrix
* If the target is not in the map, an exception is thrown.
*/
template <class TInputImage>
const typename TensorflowMultisourceModelValidate<TInputImage>::ConfMatType
TensorflowMultisourceModelValidate<TInputImage>
......@@ -360,6 +375,10 @@ TensorflowMultisourceModelValidate<TInputImage>
return m_ConfusionMatrices[target];
}
/*
* Get the map of classes
* If the target is not in the map, an exception is thrown.
*/
template <class TInputImage>
const typename TensorflowMultisourceModelValidate<TInputImage>::MapOfClassesType
TensorflowMultisourceModelValidate<TInputImage>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment