Commit d5ffaaab authored by remi cresson's avatar remi cresson
Browse files

REFAC: a bit of changes to clarify input roles

parent 2062149c
......@@ -240,10 +240,10 @@ public:
// -Placeholders
// -PatchSize
// -ImageSource
// 2.Validation/Test
// 2.Learning/Validation
// -Placeholders (if input) or Tensor name (if target)
// -PatchSize (which is the same as for training)
// -ImageSource (depending if it's for test or validation)
// -ImageSource (depending if it's for learning or validation)
//
// TODO: a bit of refactoring. We could simply rely on m_Bundles
// if we can keep trace of indices of sources for
......@@ -262,12 +262,12 @@ public:
// Clear bundles
m_InputSourcesForTraining.clear();
m_InputSourcesForTest.clear();
m_InputSourcesForValidation.clear();
m_InputSourcesForEvaluationAgainstLearningData.clear();
m_InputSourcesForEvaluationAgainstValidationData.clear();
m_TargetTensorsNames.clear();
m_InputTargetsForValidation.clear();
m_InputTargetsForTest.clear();
m_InputTargetsForEvaluationAgainstValidationData.clear();
m_InputTargetsForEvaluationAgainstLearningData.clear();
// Prepare the bundles
......@@ -314,8 +314,8 @@ public:
if (placeholderForValidation.compare(placeholderForTraining) == 0)
{
// Source
m_InputSourcesForValidation.push_back(bundle.tfSourceForValidation.Get());
m_InputSourcesForTest.push_back(bundle.tfSource.Get());
m_InputSourcesForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
m_InputSourcesForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
// Placeholder
m_InputPlaceholdersForValidation.push_back(placeholderForValidation);
......@@ -330,8 +330,8 @@ public:
else
{
// Source
m_InputTargetsForValidation.push_back(bundle.tfSourceForValidation.Get());
m_InputTargetsForTest.push_back(bundle.tfSource.Get());
m_InputTargetsForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
m_InputTargetsForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
// Placeholder
m_TargetTensorsNames.push_back(placeholderForValidation);
......@@ -422,8 +422,10 @@ public:
// Set input bundles
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{
m_TrainModelFilter->PushBackInputBundle(m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i], m_InputSourcesForTraining[i]);
m_TrainModelFilter->PushBackInputBundle(
m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i],
m_InputSourcesForTraining[i]);
}
// Train the model
......@@ -449,26 +451,24 @@ public:
m_ValidateModelFilter = ValidateModelFilterType::New();
m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_ValidateModelFilter->SetSession(m_SavedModel.session.get());
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// Evaluate the metrics against the learning data (test)
for (unsigned int i = 0 ; i < m_InputSourcesForTest.size() ; i++)
{
m_ValidateModelFilter->PushBackInputBundle(m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i], m_InputSourcesForTest[i]);
}
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
// Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForTest[i], m_TargetPatchesSize[i]);
m_ValidateModelFilter->PushBackInputBundle(
m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i],
m_InputSourcesForEvaluationAgainstLearningData[i]);
}
// Evaluate the model (test)
AddProcess(m_ValidateModelFilter, "Evaluate model (Test)");
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize);
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update();
// Print some metrics
// Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
......@@ -476,28 +476,21 @@ public:
}
// Evaluate the metrics against the validation data
for (unsigned int i = 0 ; i < m_InputSourcesForValidation.size() ; i++)
{
m_ValidateModelFilter->SetInput(i, m_InputSourcesForValidation[i]);
}
m_ValidateModelFilter->ClearInputReferences();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForValidation[i], m_TargetPatchesSize[i]);
m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
}
// Evaluate the model (validation)
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation)");
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
m_ValidateModelFilter->Update();
// Print some metrics
// Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
}
else if (GetParameterInt("validation.mode")==2) // rmse)
{
......@@ -516,17 +509,23 @@ private:
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
BundleList m_Bundles;
// Patches size
SizeList m_InputPatchesSizeForTraining;
SizeList m_InputPatchesSizeForValidation;
SizeList m_TargetPatchesSize;
// Placeholders and Tensors names
StringList m_InputPlaceholdersForTraining;
StringList m_InputPlaceholdersForValidation;
StringList m_TargetTensorsNames;
// Image sources
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTraining;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTest;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForTest;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForValidation;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForValidation;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstLearningData;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstValidationData;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstLearningData;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstValidationData;
}; // end of class
......
......@@ -84,15 +84,29 @@ public:
tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image);
itkSetMacro(InputPlaceholdersNames, StringList);
void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image);
// /** Input placeholders names */
// itkSetMacro(InputPlaceholdersNames, StringList);
itkGetMacro(InputPlaceholdersNames, StringList);
itkSetMacro(InputFOVSizes, SizeListType);
//
// /** Receptive field */
// itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType);
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Output tensors names */
itkSetMacro(OutputTensorsNames, StringList);
itkGetMacro(OutputTensorsNames, StringList);
/** Expression field */
itkSetMacro(OutputFOESizes, SizeListType);
itkGetMacro(OutputFOESizes, SizeListType);
/** User placeholders */
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Target nodes names */
itkSetMacro(TargetNodesNames, StringList);
itkGetMacro(TargetNodesNames, StringList);
......@@ -123,6 +137,7 @@ private:
// Model parameters
StringList m_InputPlaceholdersNames; // Input placeholders names
SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes
DictListType m_UserPlaceholders; // User placeholders
StringList m_OutputTensorsNames; // User tensors
StringList m_TargetNodesNames; // User target tensors
......
......@@ -25,10 +25,10 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image)
::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
Superclass::PushBackInput(image);
m_InputFOVSizes.push_back(fieldOfView);
m_InputFOVSizes.push_back(receptiveField);
m_InputPlaceholdersNames.push_back(placeholder);
}
......
......@@ -63,6 +63,7 @@ public:
typedef typename Superclass::RegionType RegionType;
typedef typename Superclass::SizeType SizeType;
typedef typename Superclass::IndexType IndexType;
typedef std::vector<ImagePointerType> ImageListType;
/* Typedefs for parameters */
typedef typename Superclass::DictType DictType;
......@@ -90,9 +91,8 @@ public:
virtual void GenerateInputRequestedRegion();
virtual void PushBackInputReference(const ImageType *input, SizeType foe);
const TInputImage* GetInputReference(unsigned int index);
void ClearInputReferences();
virtual void SetInputReferences(ImageListType input);
ImagePointerType GetInputReference(unsigned int index);
virtual void GenerateData();
......@@ -108,8 +108,7 @@ private:
void operator=(const Self&); //purposely not implemented
unsigned int m_BatchSize; // Batch size
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes
std::vector<ImageType *> m_References; // The references images
ImageListType m_References; // The references images
// Read only
unsigned int m_NumberOfSamples; // Number of samples
......
......@@ -97,16 +97,17 @@ TensorflowMultisourceModelValidate<TInputImage>
{
itkExceptionMacro("No reference is set");
}
if (nbOfRefs != m_OutputFOESizes.size())
SizeListType outputEFSizes = this->GetOutputFOESizes();
if (nbOfRefs != outputEFSizes.size())
{
itkExceptionMacro("There is " << nbOfRefs << " but only " <<
m_OutputFOESizes.size() << " field of expression sizes");
outputEFSizes.size() << " field of expression sizes");
}
// Check reference image infos
for (unsigned int i = 0 ;i < nbOfRefs ; i++)
{
const SizeType outputFOESize = m_OutputFOESizes[i];
const SizeType outputFOESize = outputEFSizes[i];
const RegionType refRegion = m_References[i]->GetLargestPossibleRegion();
if (refRegion.GetSize(0) != outputFOESize[0])
{
......@@ -143,14 +144,13 @@ TensorflowMultisourceModelValidate<TInputImage>
template<class TInputImage>
void
TensorflowMultisourceModelValidate<TInputImage>
::PushBackInputReference(const ImageType *input, SizeType fieldOfExpression)
::SetInputReferences(ImageListType input)
{
m_References.push_back(const_cast<ImageType*>(input));
m_OutputFOESizes.push_back(fieldOfExpression);
m_References = input;
}
template<class TInputImage>
const TInputImage*
typename TensorflowMultisourceModelValidate<TInputImage>::ImagePointerType
TensorflowMultisourceModelValidate<TInputImage>
::GetInputReference(unsigned int index)
{
......@@ -159,16 +159,7 @@ TensorflowMultisourceModelValidate<TInputImage>
itkExceptionMacro("There is no input reference #" << index);
}
return static_cast<const ImageType*>(m_References[index]);
}
template <class TInputImage>
void
TensorflowMultisourceModelValidate<TInputImage>
::ClearInputReferences()
{
m_References.clear();
m_OutputFOESizes.clear();
return m_References[index];
}
/**
......@@ -257,11 +248,11 @@ TensorflowMultisourceModelValidate<TInputImage>
itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " <<
"but only " << m_References.size() << " reference(s) set");
}
SizeListType outputEFSizes = this->GetOutputFOESizes();
for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++)
{
// Recopy the chunk
const SizeType outputFOESize = m_OutputFOESizes[refIdx];
const SizeType outputFOESize = outputEFSizes[refIdx];
IndexType cpyStart;
cpyStart.Fill(0);
IndexType refRegStart;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment