Commit d5ffaaab authored by remi cresson's avatar remi cresson
Browse files

REFAC: a bit of changes to clarify input roles

parent 2062149c
...@@ -240,10 +240,10 @@ public: ...@@ -240,10 +240,10 @@ public:
// -Placeholders // -Placeholders
// -PatchSize // -PatchSize
// -ImageSource // -ImageSource
// 2.Validation/Test // 2.Learning/Validation
// -Placeholders (if input) or Tensor name (if target) // -Placeholders (if input) or Tensor name (if target)
// -PatchSize (which is the same as for training) // -PatchSize (which is the same as for training)
// -ImageSource (depending if it's for test or validation) // -ImageSource (depending if it's for learning or validation)
// //
// TODO: a bit of refactoring. We could simply rely on m_Bundles // TODO: a bit of refactoring. We could simply rely on m_Bundles
// if we can keep trace of indices of sources for // if we can keep trace of indices of sources for
...@@ -262,12 +262,12 @@ public: ...@@ -262,12 +262,12 @@ public:
// Clear bundles // Clear bundles
m_InputSourcesForTraining.clear(); m_InputSourcesForTraining.clear();
m_InputSourcesForTest.clear(); m_InputSourcesForEvaluationAgainstLearningData.clear();
m_InputSourcesForValidation.clear(); m_InputSourcesForEvaluationAgainstValidationData.clear();
m_TargetTensorsNames.clear(); m_TargetTensorsNames.clear();
m_InputTargetsForValidation.clear(); m_InputTargetsForEvaluationAgainstValidationData.clear();
m_InputTargetsForTest.clear(); m_InputTargetsForEvaluationAgainstLearningData.clear();
// Prepare the bundles // Prepare the bundles
...@@ -314,8 +314,8 @@ public: ...@@ -314,8 +314,8 @@ public:
if (placeholderForValidation.compare(placeholderForTraining) == 0) if (placeholderForValidation.compare(placeholderForTraining) == 0)
{ {
// Source // Source
m_InputSourcesForValidation.push_back(bundle.tfSourceForValidation.Get()); m_InputSourcesForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
m_InputSourcesForTest.push_back(bundle.tfSource.Get()); m_InputSourcesForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
// Placeholder // Placeholder
m_InputPlaceholdersForValidation.push_back(placeholderForValidation); m_InputPlaceholdersForValidation.push_back(placeholderForValidation);
...@@ -330,8 +330,8 @@ public: ...@@ -330,8 +330,8 @@ public:
else else
{ {
// Source // Source
m_InputTargetsForValidation.push_back(bundle.tfSourceForValidation.Get()); m_InputTargetsForEvaluationAgainstValidationData.push_back(bundle.tfSourceForValidation.Get());
m_InputTargetsForTest.push_back(bundle.tfSource.Get()); m_InputTargetsForEvaluationAgainstLearningData.push_back(bundle.tfSource.Get());
// Placeholder // Placeholder
m_TargetTensorsNames.push_back(placeholderForValidation); m_TargetTensorsNames.push_back(placeholderForValidation);
...@@ -422,8 +422,10 @@ public: ...@@ -422,8 +422,10 @@ public:
// Set input bundles // Set input bundles
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{ {
m_TrainModelFilter->PushBackInputBundle(m_InputPlaceholdersForTraining[i], m_TrainModelFilter->PushBackInputBundle(
m_InputPatchesSizeForTraining[i], m_InputSourcesForTraining[i]); m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i],
m_InputSourcesForTraining[i]);
} }
// Train the model // Train the model
...@@ -449,26 +451,24 @@ public: ...@@ -449,26 +451,24 @@ public:
m_ValidateModelFilter = ValidateModelFilterType::New(); m_ValidateModelFilter = ValidateModelFilterType::New();
m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def()); m_ValidateModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_ValidateModelFilter->SetSession(m_SavedModel.session.get()); m_ValidateModelFilter->SetSession(m_SavedModel.session.get());
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize")); m_ValidateModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders")); m_ValidateModelFilter->SetUserPlaceholders(GetUserPlaceholders("validation.userplaceholders"));
// Evaluate the metrics against the learning data (test) // Evaluate the metrics against the learning data
for (unsigned int i = 0 ; i < m_InputSourcesForTest.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputBundle(m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i], m_InputSourcesForTest[i]);
}
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForTest[i], m_TargetPatchesSize[i]); m_ValidateModelFilter->PushBackInputBundle(
m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i],
m_InputSourcesForEvaluationAgainstLearningData[i]);
} }
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
// Evaluate the model (test) m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
AddProcess(m_ValidateModelFilter, "Evaluate model (Test)"); m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize);
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
m_ValidateModelFilter->Update(); m_ValidateModelFilter->Update();
// Print some metrics // Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
...@@ -476,28 +476,21 @@ public: ...@@ -476,28 +476,21 @@ public:
} }
// Evaluate the metrics against the validation data // Evaluate the metrics against the validation data
for (unsigned int i = 0 ; i < m_InputSourcesForValidation.size() ; i++) for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
{
m_ValidateModelFilter->SetInput(i, m_InputSourcesForValidation[i]);
}
m_ValidateModelFilter->ClearInputReferences();
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
m_ValidateModelFilter->PushBackInputReference(m_InputTargetsForValidation[i], m_TargetPatchesSize[i]); m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
} }
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
// Evaluate the model (validation) AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
AddProcess(m_ValidateModelFilter, "Evaluate model (Validation)");
m_ValidateModelFilter->Update(); m_ValidateModelFilter->Update();
// Print some metrics // Print metrics
for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++) for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
{ {
otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":"); otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i)); PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
} }
} }
else if (GetParameterInt("validation.mode")==2) // rmse) else if (GetParameterInt("validation.mode")==2) // rmse)
{ {
...@@ -516,17 +509,23 @@ private: ...@@ -516,17 +509,23 @@ private:
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application ! tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
BundleList m_Bundles; BundleList m_Bundles;
// Patches size
SizeList m_InputPatchesSizeForTraining; SizeList m_InputPatchesSizeForTraining;
SizeList m_InputPatchesSizeForValidation; SizeList m_InputPatchesSizeForValidation;
SizeList m_TargetPatchesSize; SizeList m_TargetPatchesSize;
// Placeholders and Tensors names
StringList m_InputPlaceholdersForTraining; StringList m_InputPlaceholdersForTraining;
StringList m_InputPlaceholdersForValidation; StringList m_InputPlaceholdersForValidation;
StringList m_TargetTensorsNames; StringList m_TargetTensorsNames;
// Image sources
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTraining; std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTraining;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForTest; std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstLearningData;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForTest; std::vector<FloatVectorImageType::Pointer> m_InputSourcesForEvaluationAgainstValidationData;
std::vector<FloatVectorImageType::Pointer> m_InputSourcesForValidation; std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstLearningData;
std::vector<FloatVectorImageType::Pointer> m_InputTargetsForValidation; std::vector<FloatVectorImageType::Pointer> m_InputTargetsForEvaluationAgainstValidationData;
}; // end of class }; // end of class
......
...@@ -84,15 +84,29 @@ public: ...@@ -84,15 +84,29 @@ public:
tensorflow::Session * GetSession() { return m_Session; } tensorflow::Session * GetSession() { return m_Session; }
/** Model parameters */ /** Model parameters */
void PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image); void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image);
itkSetMacro(InputPlaceholdersNames, StringList);
// /** Input placeholders names */
// itkSetMacro(InputPlaceholdersNames, StringList);
itkGetMacro(InputPlaceholdersNames, StringList); itkGetMacro(InputPlaceholdersNames, StringList);
itkSetMacro(InputFOVSizes, SizeListType); //
// /** Receptive field */
// itkSetMacro(InputFOVSizes, SizeListType);
itkGetMacro(InputFOVSizes, SizeListType); itkGetMacro(InputFOVSizes, SizeListType);
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; } /** Output tensors names */
itkSetMacro(OutputTensorsNames, StringList); itkSetMacro(OutputTensorsNames, StringList);
itkGetMacro(OutputTensorsNames, StringList); itkGetMacro(OutputTensorsNames, StringList);
/** Expression field */
itkSetMacro(OutputFOESizes, SizeListType);
itkGetMacro(OutputFOESizes, SizeListType);
/** User placeholders */
void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
DictListType GetUserPlaceholders() { return m_UserPlaceholders; }
/** Target nodes names */
itkSetMacro(TargetNodesNames, StringList); itkSetMacro(TargetNodesNames, StringList);
itkGetMacro(TargetNodesNames, StringList); itkGetMacro(TargetNodesNames, StringList);
...@@ -123,6 +137,7 @@ private: ...@@ -123,6 +137,7 @@ private:
// Model parameters // Model parameters
StringList m_InputPlaceholdersNames; // Input placeholders names StringList m_InputPlaceholdersNames; // Input placeholders names
SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes
DictListType m_UserPlaceholders; // User placeholders DictListType m_UserPlaceholders; // User placeholders
StringList m_OutputTensorsNames; // User tensors StringList m_OutputTensorsNames; // User tensors
StringList m_TargetNodesNames; // User target tensors StringList m_TargetNodesNames; // User target tensors
......
...@@ -25,10 +25,10 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> ...@@ -25,10 +25,10 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
template <class TInputImage, class TOutputImage> template <class TInputImage, class TOutputImage>
void void
TensorflowMultisourceModelBase<TInputImage, TOutputImage> TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputBundle(std::string placeholder, SizeType fieldOfView, ImagePointerType image) ::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{ {
Superclass::PushBackInput(image); Superclass::PushBackInput(image);
m_InputFOVSizes.push_back(fieldOfView); m_InputFOVSizes.push_back(receptiveField);
m_InputPlaceholdersNames.push_back(placeholder); m_InputPlaceholdersNames.push_back(placeholder);
} }
......
...@@ -63,6 +63,7 @@ public: ...@@ -63,6 +63,7 @@ public:
typedef typename Superclass::RegionType RegionType; typedef typename Superclass::RegionType RegionType;
typedef typename Superclass::SizeType SizeType; typedef typename Superclass::SizeType SizeType;
typedef typename Superclass::IndexType IndexType; typedef typename Superclass::IndexType IndexType;
typedef std::vector<ImagePointerType> ImageListType;
/* Typedefs for parameters */ /* Typedefs for parameters */
typedef typename Superclass::DictType DictType; typedef typename Superclass::DictType DictType;
...@@ -90,9 +91,8 @@ public: ...@@ -90,9 +91,8 @@ public:
virtual void GenerateInputRequestedRegion(); virtual void GenerateInputRequestedRegion();
virtual void PushBackInputReference(const ImageType *input, SizeType foe); virtual void SetInputReferences(ImageListType input);
const TInputImage* GetInputReference(unsigned int index); ImagePointerType GetInputReference(unsigned int index);
void ClearInputReferences();
virtual void GenerateData(); virtual void GenerateData();
...@@ -108,8 +108,7 @@ private: ...@@ -108,8 +108,7 @@ private:
void operator=(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented
unsigned int m_BatchSize; // Batch size unsigned int m_BatchSize; // Batch size
SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes ImageListType m_References; // The references images
std::vector<ImageType *> m_References; // The references images
// Read only // Read only
unsigned int m_NumberOfSamples; // Number of samples unsigned int m_NumberOfSamples; // Number of samples
......
...@@ -97,16 +97,17 @@ TensorflowMultisourceModelValidate<TInputImage> ...@@ -97,16 +97,17 @@ TensorflowMultisourceModelValidate<TInputImage>
{ {
itkExceptionMacro("No reference is set"); itkExceptionMacro("No reference is set");
} }
if (nbOfRefs != m_OutputFOESizes.size()) SizeListType outputEFSizes = this->GetOutputFOESizes();
if (nbOfRefs != outputEFSizes.size())
{ {
itkExceptionMacro("There is " << nbOfRefs << " but only " << itkExceptionMacro("There is " << nbOfRefs << " but only " <<
m_OutputFOESizes.size() << " field of expression sizes"); outputEFSizes.size() << " field of expression sizes");
} }
// Check reference image infos // Check reference image infos
for (unsigned int i = 0 ;i < nbOfRefs ; i++) for (unsigned int i = 0 ;i < nbOfRefs ; i++)
{ {
const SizeType outputFOESize = m_OutputFOESizes[i]; const SizeType outputFOESize = outputEFSizes[i];
const RegionType refRegion = m_References[i]->GetLargestPossibleRegion(); const RegionType refRegion = m_References[i]->GetLargestPossibleRegion();
if (refRegion.GetSize(0) != outputFOESize[0]) if (refRegion.GetSize(0) != outputFOESize[0])
{ {
...@@ -143,14 +144,13 @@ TensorflowMultisourceModelValidate<TInputImage> ...@@ -143,14 +144,13 @@ TensorflowMultisourceModelValidate<TInputImage>
template<class TInputImage> template<class TInputImage>
void void
TensorflowMultisourceModelValidate<TInputImage> TensorflowMultisourceModelValidate<TInputImage>
::PushBackInputReference(const ImageType *input, SizeType fieldOfExpression) ::SetInputReferences(ImageListType input)
{ {
m_References.push_back(const_cast<ImageType*>(input)); m_References = input;
m_OutputFOESizes.push_back(fieldOfExpression);
} }
template<class TInputImage> template<class TInputImage>
const TInputImage* typename TensorflowMultisourceModelValidate<TInputImage>::ImagePointerType
TensorflowMultisourceModelValidate<TInputImage> TensorflowMultisourceModelValidate<TInputImage>
::GetInputReference(unsigned int index) ::GetInputReference(unsigned int index)
{ {
...@@ -159,16 +159,7 @@ TensorflowMultisourceModelValidate<TInputImage> ...@@ -159,16 +159,7 @@ TensorflowMultisourceModelValidate<TInputImage>
itkExceptionMacro("There is no input reference #" << index); itkExceptionMacro("There is no input reference #" << index);
} }
return static_cast<const ImageType*>(m_References[index]); return m_References[index];
}
template <class TInputImage>
void
TensorflowMultisourceModelValidate<TInputImage>
::ClearInputReferences()
{
m_References.clear();
m_OutputFOESizes.clear();
} }
/** /**
...@@ -257,11 +248,11 @@ TensorflowMultisourceModelValidate<TInputImage> ...@@ -257,11 +248,11 @@ TensorflowMultisourceModelValidate<TInputImage>
itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " << itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " <<
"but only " << m_References.size() << " reference(s) set"); "but only " << m_References.size() << " reference(s) set");
} }
SizeListType outputEFSizes = this->GetOutputFOESizes();
for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++) for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++)
{ {
// Recopy the chunk // Recopy the chunk
const SizeType outputFOESize = m_OutputFOESizes[refIdx]; const SizeType outputFOESize = outputEFSizes[refIdx];
IndexType cpyStart; IndexType cpyStart;
cpyStart.Fill(0); cpyStart.Fill(0);
IndexType refRegStart; IndexType refRegStart;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment