diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h index e0c23846fd45110efe30704acde79c35060e5310..8f4440730946d947f83fb92d960e027485be5ec0 100644 --- a/include/otbTensorflowMultisourceModelBase.h +++ b/include/otbTensorflowMultisourceModelBase.h @@ -72,10 +72,10 @@ public: typedef typename TInputImage::RegionType RegionType; /** Typedefs for parameters */ - typedef std::pair<std::string, tensorflow::Tensor> DictType; + typedef std::pair<std::string, tensorflow::Tensor> DictElementType; typedef std::vector<std::string> StringList; typedef std::vector<SizeType> SizeListType; - typedef std::vector<DictType> DictListType; + typedef std::vector<DictElementType> DictType; typedef std::vector<tensorflow::DataType> DataTypeListType; typedef std::vector<tensorflow::TensorShapeProto> TensorShapeProtoList; typedef std::vector<tensorflow::Tensor> TensorListType; @@ -87,27 +87,28 @@ public: tensorflow::Session * GetSession() { return m_Session; } /** Model parameters */ - void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image); + void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image); + void PushBackOuputTensorBundle(std::string name, SizeType expressionField); -// /** Input placeholders names */ -// itkSetMacro(InputPlaceholdersNames, StringList); - itkGetMacro(InputPlaceholdersNames, StringList); -// -// /** Receptive field */ -// itkSetMacro(InputFOVSizes, SizeListType); - itkGetMacro(InputFOVSizes, SizeListType); + /** Input placeholders names */ + itkSetMacro(InputPlaceholders, StringList); + itkGetMacro(InputPlaceholders, StringList); + + /** Receptive field */ + itkSetMacro(InputReceptiveFields, SizeListType); + itkGetMacro(InputReceptiveFields, SizeListType); /** Output tensors names */ - itkSetMacro(OutputTensorsNames, StringList); - itkGetMacro(OutputTensorsNames, StringList); + itkSetMacro(OutputTensors, StringList); + itkGetMacro(OutputTensors, StringList); /** Expression field */ - itkSetMacro(OutputFOESizes, SizeListType); - itkGetMacro(OutputFOESizes, SizeListType); + itkSetMacro(OutputExpressionFields, SizeListType); + itkGetMacro(OutputExpressionFields, SizeListType); /** User placeholders */ - void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; } - DictListType GetUserPlaceholders() { return m_UserPlaceholders; } + void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; } + DictType GetUserPlaceholders() { return m_UserPlaceholders; } /** Target nodes names */ itkSetMacro(TargetNodesNames, StringList); @@ -125,9 +126,9 @@ protected: TensorflowMultisourceModelBase(); virtual ~TensorflowMultisourceModelBase() {}; - virtual std::stringstream GenerateDebugReport(DictListType & inputs, TensorListType & outputs); + virtual std::stringstream GenerateDebugReport(DictType & inputs); - virtual void RunSession(DictListType & inputs, TensorListType & outputs); + virtual void RunSession(DictType & inputs, TensorListType & outputs); private: TensorflowMultisourceModelBase(const Self&); //purposely not implemented @@ -138,11 +139,11 @@ private: tensorflow::Session * m_Session; // The tensorflow session // Model parameters - StringList m_InputPlaceholdersNames; // Input placeholders names - SizeListType m_InputFOVSizes; // Input tensors field of view (FOV) sizes - SizeListType m_OutputFOESizes; // Output tensors field of expression (FOE) sizes - DictListType m_UserPlaceholders; // User placeholders - StringList m_OutputTensorsNames; // User tensors + StringList m_InputPlaceholders; // Input placeholders names + SizeListType m_InputReceptiveFields; // Input receptive fields + StringList m_OutputTensors; // Output tensors names + SizeListType m_OutputExpressionFields; // Output expression fields + DictType m_UserPlaceholders; // User placeholders StringList m_TargetNodesNames; // User target tensors // Read-only diff --git a/include/otbTensorflowMultisourceModelBase.hxx b/include/otbTensorflowMultisourceModelBase.hxx index 07b2cf71d5afdab187b494d3573a44dcd74c07d3..938c903544d8d5eed0fddf6596f0224ae019ca1f 100644 --- a/include/otbTensorflowMultisourceModelBase.hxx +++ b/include/otbTensorflowMultisourceModelBase.hxx @@ -20,22 +20,23 @@ template <class TInputImage, class TOutputImage> TensorflowMultisourceModelBase<TInputImage, TOutputImage> ::TensorflowMultisourceModelBase() { + m_Session = nullptr; } template <class TInputImage, class TOutputImage> void TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image) +::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image) { Superclass::PushBackInput(image); - m_InputFOVSizes.push_back(receptiveField); - m_InputPlaceholdersNames.push_back(placeholder); + m_InputReceptiveFields.push_back(receptiveField); + m_InputPlaceholders.push_back(placeholder); } template <class TInputImage, class TOutputImage> std::stringstream TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::GenerateDebugReport(DictListType & inputs, TensorListType & outputs) +::GenerateDebugReport(DictType & inputs) { // Create a debug report std::stringstream debugReport; @@ -69,7 +70,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> template <class TInputImage, class TOutputImage> void TensorflowMultisourceModelBase<TInputImage, TOutputImage> -::RunSession(DictListType & inputs, TensorListType & outputs) +::RunSession(DictType & inputs, TensorListType & outputs) { // Add the user's placeholders @@ -82,11 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> // The session will initialize the outputs // Run the session, evaluating our output tensors from the graph - auto status = this->GetSession()->Run(inputs, m_OutputTensorsNames, m_TargetNodesNames, &outputs); + auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs); if (!status.ok()) { // Create a debug report - std::stringstream debugReport = GenerateDebugReport(inputs, outputs); + std::stringstream debugReport = GenerateDebugReport(inputs); // Throw an exception with the report itkExceptionMacro("Can't run the tensorflow session !\n" << @@ -108,11 +109,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> // - patches sizes // - input image const unsigned int nbInputs = this->GetNumberOfInputs(); - if (nbInputs != m_InputFOVSizes.size() || nbInputs != m_InputPlaceholdersNames.size()) + if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size()) { itkExceptionMacro("Number of input images is " << nbInputs << - " but the number of input patches size is " << m_InputFOVSizes.size() << - " and the number of input tensors names is " << m_InputPlaceholdersNames.size()); + " but the number of input patches size is " << m_InputReceptiveFields.size() << + " and the number of input tensors names is " << m_InputPlaceholders.size()); } ////////////////////////////////////////////////////////////////////////////////////////// @@ -120,8 +121,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage> ////////////////////////////////////////////////////////////////////////////////////////// // Get input and output tensors datatypes and shapes - tf::GetTensorAttributes(m_Graph, m_InputPlaceholdersNames, m_InputTensorsShapes, m_InputTensorsDataTypes); - tf::GetTensorAttributes(m_Graph, m_OutputTensorsNames, m_OutputTensorsShapes, m_OutputTensorsDataTypes); + tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes); + tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes); } diff --git a/include/otbTensorflowMultisourceModelFilter.h b/include/otbTensorflowMultisourceModelFilter.h index 833294ada5c05bf469c509aa58ea69ca04cfd1d4..fdc799a5cd69f86a77bc396f3f159f4ffde101e5 100644 --- a/include/otbTensorflowMultisourceModelFilter.h +++ b/include/otbTensorflowMultisourceModelFilter.h @@ -30,27 +30,26 @@ namespace otb * * The filter takes N input images and feed the TensorFlow model to produce * one output image of desired TF op results. - * Names of input/output placeholders/tensors must be specified using the - * SetInputPlaceholdersNames/SetOutputTensorNames. + * Names of input placeholders and output tensors must be specified using the + * SetPlaceholders() and SetTensors() methods. * - * Example: we have a tensorflow model which runs the input images "x1" and "x2" + * Example: we have a TensorFlow model which runs the input images "x1" and "x2" * and produces the output image "y". - * "x1" and "x2" are two TF placeholders, we set InputTensorNames={"x1","x2"} - * "y1" corresponds to one TF op output, we set OutputTensorNames={"y1"} + * "x1" and "x2" are two TF placeholders, we set InputPlaceholder={"x1","x2"} + * "y1" corresponds to one TF op output, we set OutputTensors={"y1"} * * The reference grid for the output image is the same as the first input image. * This grid can be scaled by setting the OutputSpacingScale value. * This can be used to run models which downsize the output image spacing - * (typically fully convolutional model with strides) or to produce the result + * (e.g. fully convolutional model with strides) or to produce the result * of a patch-based network at regular intervals. * - * For each input, input field of view (FOV) must be set. + * For each input (resp. output), receptive field (resp. expression field) must be set. * If the number of values in the output tensors (produced by the model) don't - * fit with the output image region, exception will be thrown. + * fit with the output image region, an exception will be thrown. * - * - * The tensorflow Graph is passed using the SetGraph() method - * The tensorflow Session is passed using the SetSession() method + * The TensorFlow Graph is passed using the SetGraph() method + * The TensorFlow Session is passed using the SetSession() method * * \ingroup OTBTensorflow */ @@ -94,6 +93,7 @@ public: typedef typename itk::ImageRegionConstIterator<TInputImage> InputConstIteratorType; /* Typedefs for parameters */ + typedef typename Superclass::DictElementType DictElementType; typedef typename Superclass::DictType DictType; typedef typename Superclass::StringList StringList; typedef typename Superclass::SizeListType SizeListType; @@ -101,8 +101,6 @@ public: typedef typename Superclass::TensorListType TensorListType; typedef std::vector<float> ScaleListType; - itkSetMacro(OutputFOESize, SizeType); - itkGetMacro(OutputFOESize, SizeType); itkSetMacro(OutputGridSize, SizeType); itkGetMacro(OutputGridSize, SizeType); itkSetMacro(ForceOutputGridSize, bool); @@ -132,7 +130,6 @@ private: TensorflowMultisourceModelFilter(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - SizeType m_OutputFOESize; // Output tensors field of expression (FOE) sizes SizeType m_OutputGridSize; // Output grid size bool m_ForceOutputGridSize; // Force output grid size bool m_FullyConvolutional; // Convolution mode diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx index c7a547e240b2ccc8e60a0ced3d5c734b501537bd..e5fa16b9b2947bc91bd7adf337a971473ea8a313 100644 --- a/include/otbTensorflowMultisourceModelFilter.hxx +++ b/include/otbTensorflowMultisourceModelFilter.hxx @@ -216,7 +216,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> // Update output image extent PointType currentInputImageExtentInf, currentInputImageExtentSup; - ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputFOVSizes()[imageIndex]); + ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]); for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim) { extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]); @@ -236,7 +236,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> if (!m_ForceOutputGridSize) { // Default is the output field of expression - m_OutputGridSize = m_OutputFOESize; + m_OutputGridSize = this->GetOutputExpressionFields().at(0); } // Resize the largestPossibleRegion to be a multiple of the grid size @@ -315,9 +315,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> } // Compute the FOV-scale*FOE radius to pad - SizeType toPad(this->GetInputFOVSizes().at(i)); - toPad[0] -= 1 + (m_OutputFOESize[0] - 1) * m_OutputSpacingScale; - toPad[1] -= 1 + (m_OutputFOESize[1] - 1) * m_OutputSpacingScale; + SizeType toPad(this->GetInputReceptiveFields().at(i)); + toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale; + toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale; // Pad with radius SmartPad(inRegion, toPad); @@ -365,7 +365,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> const unsigned int nInputs = this->GetNumberOfInputs(); // Create input tensors list - DictListType inputs; + DictType inputs; // Populate input tensors for (unsigned int i = 0 ; i < nInputs ; i++) @@ -374,7 +374,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i)); // Patch size of tensor #i - const SizeType inputPatchSize = this->GetInputFOVSizes().at(i); + const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i); // Input image requested region const RegionType reqRegion = inputPtr->GetRequestedRegion(); @@ -395,7 +395,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0); // Input #1 : the tensor of patches (aka the batch) - DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; + DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; inputs.push_back(input1); } else @@ -429,7 +429,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage> } // Input #1 : the tensor of patches (aka the batch) - DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; + DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; inputs.push_back(input1); } // mode is not full convolutional diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h new file mode 100644 index 0000000000000000000000000000000000000000..d6feb24e83ca08eabb1e2566337511b5df0bb614 --- /dev/null +++ b/include/otbTensorflowMultisourceModelLearningBase.h @@ -0,0 +1,112 @@ +/*========================================================================= + + Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbTensorflowMultisourceModelLearningBase_h +#define otbTensorflowMultisourceModelLearningBase_h + +#include "itkProcessObject.h" +#include "itkNumericTraits.h" +#include "itkSimpleDataObjectDecorator.h" + +// Base +#include "otbTensorflowMultisourceModelBase.h" + +// Shuffle +#include <random> +#include <algorithm> +#include <iterator> + +namespace otb +{ + +/** + * \class TensorflowMultisourceModelLearningBase + * \brief This filter is the base class for learning filters. + * + * \ingroup OTBTensorflow + */ +template <class TInputImage> +class ITK_EXPORT TensorflowMultisourceModelLearningBase : +public TensorflowMultisourceModelBase<TInputImage> +{ +public: + + /** Standard class typedefs. */ + typedef TensorflowMultisourceModelLearningBase Self; + typedef TensorflowMultisourceModelBase<TInputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Run-time type information (and related methods). */ + itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase); + + /** Images typedefs */ + typedef typename Superclass::ImageType ImageType; + typedef typename Superclass::ImagePointerType ImagePointerType; + typedef typename Superclass::RegionType RegionType; + typedef typename Superclass::SizeType SizeType; + typedef typename Superclass::IndexType IndexType; + + /* Typedefs for parameters */ + typedef typename Superclass::DictType DictType; + typedef typename Superclass::DictElementType DictElementType; + typedef typename Superclass::StringList StringList; + typedef typename Superclass::SizeListType SizeListType; + typedef typename Superclass::TensorListType TensorListType; + + /* Typedefs for index */ + typedef typename ImageType::IndexValueType IndexValueType; + typedef std::vector<IndexValueType> IndexListType; + + // Batch size + itkSetMacro(BatchSize, IndexValueType); + itkGetMacro(BatchSize, IndexValueType); + + // Use streaming + itkSetMacro(UseStreaming, bool); + itkGetMacro(UseStreaming, bool); + + // Get number of samples + itkGetMacro(NumberOfSamples, IndexValueType); + +protected: + TensorflowMultisourceModelLearningBase(); + virtual ~TensorflowMultisourceModelLearningBase() {}; + + virtual void GenerateOutputInformation(void); + + virtual void GenerateInputRequestedRegion(); + + virtual void GenerateData(); + + virtual void PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize, const IndexListType & order = IndexListType()); + + virtual void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize) = 0; + +private: + TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented + void operator=(const Self&); //purposely not implemented + + unsigned int m_BatchSize; // Batch size + bool m_UseStreaming; // Use streaming on/off + + // Read only + IndexValueType m_NumberOfSamples; // Number of samples + +}; // end class + + +} // end namespace otb + +#include "otbTensorflowMultisourceModelLearningBase.hxx" + +#endif diff --git a/include/otbTensorflowMultisourceModelLearningBase.hxx b/include/otbTensorflowMultisourceModelLearningBase.hxx new file mode 100644 index 0000000000000000000000000000000000000000..9c913cd2d3d6230a0356bec742ae749f333949f8 --- /dev/null +++ b/include/otbTensorflowMultisourceModelLearningBase.hxx @@ -0,0 +1,211 @@ +/*========================================================================= + + Copyright (c) Remi Cresson (IRSTEA). All rights reserved. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbTensorflowMultisourceModelLearningBase_txx +#define otbTensorflowMultisourceModelLearningBase_txx + +#include "otbTensorflowMultisourceModelLearningBase.h" + +namespace otb +{ + +template <class TInputImage> +TensorflowMultisourceModelLearningBase<TInputImage> +::TensorflowMultisourceModelLearningBase(): m_BatchSize(100), + m_NumberOfSamples(0), m_UseStreaming(false) + { + } + + +template <class TInputImage> +void +TensorflowMultisourceModelLearningBase<TInputImage> +::GenerateOutputInformation() + { + Superclass::GenerateOutputInformation(); + + ImageType * outputPtr = this->GetOutput(); + RegionType nullRegion; + nullRegion.GetModifiableSize().Fill(1); + outputPtr->SetNumberOfComponentsPerPixel(1); + outputPtr->SetLargestPossibleRegion( nullRegion ); + + // Count the number of samples + m_NumberOfSamples = 0; + for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) + { + // Input image pointer + ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); + + // Make sure input is available + if ( inputPtr.IsNull() ) + { + itkExceptionMacro(<< "Input " << i << " is null!"); + } + + // Update input information + inputPtr->UpdateOutputInformation(); + + // Patch size of tensor #i + const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i); + + // Input image requested region + const RegionType reqRegion = inputPtr->GetLargestPossibleRegion(); + + // Check size X + if (inputPatchSize[0] != reqRegion.GetSize(0)) + itkExceptionMacro("Patch size for input " << i + << " is " << inputPatchSize + << " but input patches image size is " << reqRegion.GetSize()); + + // Check size Y + if (reqRegion.GetSize(1) % inputPatchSize[1] != 0) + itkExceptionMacro("Input patches image must have a number of rows which is " + << "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) + << " rows but patch size Y is " << inputPatchSize[1] << " for input " << i); + + // Get the batch size + const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1]; + + // Check the consistency with other inputs + if (m_NumberOfSamples == 0) + { + m_NumberOfSamples = currNumberOfSamples; + } + else if (m_NumberOfSamples != currNumberOfSamples) + { + itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples + << " but input " << i + << " has a batch size of " << currNumberOfSamples ); + } + } // next input + } + +template <class TInputImage> +void +TensorflowMultisourceModelLearningBase<TInputImage> +::GenerateInputRequestedRegion() + { + Superclass::GenerateInputRequestedRegion(); + + // For each image, set no image region + for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) + { + RegionType nullRegion; + ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) ); + + // If the streaming is enabled, we don't read the full image + if (m_UseStreaming) + { + inputImage->SetRequestedRegion(nullRegion); + } + else + { + inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion()); + } + } // next image + } + +/** + * + */ +template <class TInputImage> +void +TensorflowMultisourceModelLearningBase<TInputImage> +::GenerateData() + { + + // Batches loop + const IndexValueType nBatches = vcl_ceil(m_NumberOfSamples / m_BatchSize); + const IndexValueType rest = m_NumberOfSamples % m_BatchSize; + + itk::ProgressReporter progress(this, 0, nBatches); + + for (IndexValueType batch = 0 ; batch < nBatches ; batch++) + { + + // Create input tensors list + TensorListType inputs; + + // Batch start and size + const IndexValueType sampleStart = batch * m_BatchSize; + IndexValueType batchSize = m_BatchSize; + if (rest != 0) + { + batchSize = rest; + } + + // Process the batch + ProcessBatch(inputs, sampleStart, batchSize); + + progress.CompletedPixel(); + } // Next batch + + } + +template <class TInputImage> +void +TensorflowMultisourceModelLearningBase<TInputImage> +::PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize, const IndexListType & order) + { + const bool reorder = order.size(); + + // Populate input tensors + for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) + { + // Input image pointer + ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); + + // Patch size of tensor #i + const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i); + + // Create the tensor for the batch + const tensorflow::int64 sz_n = batchSize; + const tensorflow::int64 sz_y = inputPatchSize[1]; + const tensorflow::int64 sz_x = inputPatchSize[0]; + const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); + const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); + tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); + + // Populate the tensor + for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++) + { + const tensorflow::uint64 samplePos = sampleStart + elem; + IndexType start; + start[0] = 0; + if (reorder) + { + start[1] = order[samplePos] * sz_y; + } + else + { + start[1] = samplePos * sz_y;; + } + RegionType patchRegion(start, inputPatchSize); + if (m_UseStreaming) + { + // If streaming is enabled, we need to explicitly propagate requested region + tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion); + } + tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem ); + } + + // Input #i : the tensor of patches (aka the batch) + DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; + inputs.push_back(input1); + } // next input tensor + } + + +} // end namespace otb + + +#endif diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h index 8308f2ed5bbbb64b949163f33bea3a3b5eb1fa80..ec4ce349c6af76d36c6d2280c080b041d4b06e2d 100644 --- a/include/otbTensorflowMultisourceModelTrain.h +++ b/include/otbTensorflowMultisourceModelTrain.h @@ -16,7 +16,7 @@ #include "itkSimpleDataObjectDecorator.h" // Base -#include "otbTensorflowMultisourceModelBase.h" +#include "otbTensorflowMultisourceModelLearningBase.h" // Shuffle #include <random> @@ -31,67 +31,47 @@ namespace otb * \brief This filter train a TensorFlow model over multiple input images. * * The filter takes N input images and feed the TensorFlow model. - * Names of input placeholders must be specified using the - * SetInputPlaceholdersNames method * - * TODO: Add an option to disable streaming * * \ingroup OTBTensorflow */ template <class TInputImage> class ITK_EXPORT TensorflowMultisourceModelTrain : -public TensorflowMultisourceModelBase<TInputImage> +public TensorflowMultisourceModelLearningBase<TInputImage> { public: /** Standard class typedefs. */ - typedef TensorflowMultisourceModelTrain Self; - typedef TensorflowMultisourceModelBase<TInputImage> Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; + typedef TensorflowMultisourceModelTrain Self; + typedef TensorflowMultisourceModelLearningBase<TInputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; /** Method for creation through the object factory. */ itkNewMacro(Self); /** Run-time type information (and related methods). */ - itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelBase); - - /** Images typedefs */ - typedef typename Superclass::ImageType ImageType; - typedef typename Superclass::ImagePointerType ImagePointerType; - typedef typename Superclass::RegionType RegionType; - typedef typename Superclass::SizeType SizeType; - typedef typename Superclass::IndexType IndexType; - - /* Typedefs for parameters */ - typedef typename Superclass::DictType DictType; - typedef typename Superclass::StringList StringList; - typedef typename Superclass::SizeListType SizeListType; - typedef typename Superclass::DictListType DictListType; - typedef typename Superclass::TensorListType TensorListType; - - itkSetMacro(BatchSize, unsigned int); - itkGetMacro(BatchSize, unsigned int); - itkGetMacro(NumberOfSamples, unsigned int); - - virtual void GenerateOutputInformation(void); + itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelLearningBase); - virtual void GenerateInputRequestedRegion(); + /** Superclass typedefs */ + typedef typename Superclass::IndexValueType IndexValueType; + typedef typename Superclass::TensorListType TensorListType; + typedef typename Superclass::IndexListType IndexListType; - virtual void GenerateData(); protected: TensorflowMultisourceModelTrain(); virtual ~TensorflowMultisourceModelTrain() {}; + void GenerateData(); + void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize); + private: TensorflowMultisourceModelTrain(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - unsigned int m_BatchSize; // Batch size - - // Read only - unsigned int m_NumberOfSamples; // Number of samples + IndexListType m_RandomIndices; // Reordered indices }; // end class diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx index aa47237c30ba886c5bc7b9effa5028d9cdefde91..2ad029129f0c75bf940cd7b8badd025a98e564ac 100644 --- a/include/otbTensorflowMultisourceModelTrain.hxx +++ b/include/otbTensorflowMultisourceModelTrain.hxx @@ -20,183 +20,44 @@ template <class TInputImage> TensorflowMultisourceModelTrain<TInputImage> ::TensorflowMultisourceModelTrain() { - m_BatchSize = 100; - m_NumberOfSamples = 0; } - template <class TInputImage> void TensorflowMultisourceModelTrain<TInputImage> -::GenerateOutputInformation() +::GenerateData() { - Superclass::GenerateOutputInformation(); - - ImageType * outputPtr = this->GetOutput(); - RegionType nullRegion; - nullRegion.GetModifiableSize().Fill(1); - outputPtr->SetNumberOfComponentsPerPixel(1); - outputPtr->SetLargestPossibleRegion( nullRegion ); - - ////////////////////////////////////////////////////////////////////////////////////////// - // Check the number of samples - ////////////////////////////////////////////////////////////////////////////////////////// - - m_NumberOfSamples = 0; - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { - // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); - - // Make sure input is available - if ( inputPtr.IsNull() ) - { - itkExceptionMacro(<< "Input " << i << " is null!"); - } - - // Update input information - inputPtr->UpdateOutputInformation(); - - // Patch size of tensor #i - const SizeType inputPatchSize = this->GetInputFOVSizes().at(i); - - // Input image requested region - const RegionType reqRegion = inputPtr->GetLargestPossibleRegion(); - - // Check size X - if (inputPatchSize[0] != reqRegion.GetSize(0)) - itkExceptionMacro("Patch size for input " << i << " is " << inputPatchSize << - " but input patches image size is " << reqRegion.GetSize()); - - // Check size Y - if (reqRegion.GetSize(1) % inputPatchSize[1] != 0) - itkExceptionMacro("Input patches image must have a number of rows which is " - "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) << - " rows but patch size Y is " << inputPatchSize[1] << " for input " << i); - - // Get the batch size - const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1]; - - // Check the consistency with other inputs - if (m_NumberOfSamples == 0) - { - m_NumberOfSamples = currNumberOfSamples; - } - else if (m_NumberOfSamples != currNumberOfSamples) - { - itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples << " but input " << i - << " has a batch size of " << currNumberOfSamples ); - } - } // next input - } -template <class TInputImage> -void -TensorflowMultisourceModelTrain<TInputImage> -::GenerateInputRequestedRegion() - { - Superclass::GenerateInputRequestedRegion(); - - // For each image, set no image region - for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) - { - RegionType nullRegion; - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) ); -// TODO: streaming mode on/off -// inputImage->SetRequestedRegion(nullRegion); - inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion()); - } // next image + // Initial sequence 1...N + m_RandomIndices.resize(this->GetNumberOfSamples()); + std::iota (std::begin(m_RandomIndices), std::end(m_RandomIndices), 0); + + // Shuffle the sequence + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(m_RandomIndices.begin(), m_RandomIndices.end(), g); + + // Call the generic method + Superclass::GenerateData(); + } -/** - * - */ template <class TInputImage> void TensorflowMultisourceModelTrain<TInputImage> -::GenerateData() +::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize) { + // Populate input tensors + PopulateInputTensor(inputs, sampleStart, batchSize, m_RandomIndices); - // Random sequence - std::vector<int> v(m_NumberOfSamples) ; - std::iota (std::begin(v), std::end(v), 0); - - // Shuffle - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(v.begin(), v.end(), g); - - // Batches loop - const tensorflow::uint64 nBatches = vcl_ceil(m_NumberOfSamples / m_BatchSize); - const tensorflow::uint64 rest = m_NumberOfSamples % m_BatchSize; - itk::ProgressReporter progress(this, 0, nBatches); - for (tensorflow::uint64 batch = 0 ; batch < nBatches ; batch++) - { - // Update progress - this->UpdateProgress((float) batch / (float) nBatches); - - // Create input tensors list - DictListType inputs; - - // Batch start and size - const tensorflow::uint64 sampleStart = batch * m_BatchSize; - tensorflow::uint64 batchSize = m_BatchSize; - if (rest != 0) - { - batchSize = rest; - } - - // Populate input tensors - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { - // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); - - // Patch size of tensor #i - const SizeType inputPatchSize = this->GetInputFOVSizes().at(i); - - // Create the tensor for the batch - const tensorflow::int64 sz_n = batchSize; - const tensorflow::int64 sz_y = inputPatchSize[1]; - const tensorflow::int64 sz_x = inputPatchSize[0]; - const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); - const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); - tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); - - // Populate the tensor - for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++) - { - const tensorflow::uint64 samplePos = sampleStart + elem; - const tensorflow::uint64 randPos = v[samplePos]; - IndexType start; - start[0] = 0; - start[1] = randPos * sz_y; - RegionType patchRegion(start, inputPatchSize); -// TODO: streaming mode on/off -// tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion); - tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem ); - } - - // Input #i : the tensor of patches (aka the batch) - DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; - inputs.push_back(input1); - } // next input tensor - - // Run the TF session here - TensorListType outputs; - this->RunSession(inputs, outputs); - - // Get output tensors - for (auto& output: outputs) - { - std::cout << tf::PrintTensorInfos(output) << std::endl; - } - - progress.CompletedPixel(); - } // Next batch + // Run the TF session here + TensorListType outputs; + this->RunSession(inputs, outputs); } + } // end namespace otb diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h index 868317f9344ba1567deaf71aecd7b18a2ccd635f..048ff70cdfd4319abeba653c0121940a05af6127 100644 --- a/include/otbTensorflowMultisourceModelValidate.h +++ b/include/otbTensorflowMultisourceModelValidate.h @@ -33,16 +33,12 @@ namespace otb * \brief This filter validates a TensorFlow model over multiple input images. * * The filter takes N input images and feed the TensorFlow model. - * Names of input placeholders must be specified using the - * SetInputPlaceholdersNames method - * - * TODO: Add an option to disable streaming * * \ingroup OTBTensorflow */ template <class TInputImage> class ITK_EXPORT TensorflowMultisourceModelValidate : -public TensorflowMultisourceModelBase<TInputImage> +public TensorflowMultisourceModelLearningBase<TInputImage> { public: @@ -56,7 +52,7 @@ public: itkNewMacro(Self); /** Run-time type information (and related methods). */ - itkTypeMacro(TensorflowMultisourceModelValidate, TensorflowMultisourceModelBase); + itkTypeMacro(TensorflowMultisourceModelValidate, TensorflowMultisourceModelLearningBase); /** Images typedefs */ typedef typename Superclass::ImageType ImageType; @@ -70,8 +66,8 @@ public: typedef typename Superclass::DictType DictType; typedef typename Superclass::StringList StringList; typedef typename Superclass::SizeListType SizeListType; - typedef typename Superclass::DictListType DictListType; typedef typename Superclass::TensorListType TensorListType; + typedef typename Superclass::IndexValueType IndexValueType; /* Typedefs for validation */ typedef unsigned long CountValueType; @@ -84,23 +80,11 @@ public: typedef std::vector<ConfMatType> ConfMatListType; typedef itk::ImageRegionConstIterator<ImageType> IteratorType; - /* Set and Get the batch size - itkSetMacro(BatchSize, unsigned int); - itkGetMacro(BatchSize, unsigned int); - - /** Get the number of samples */ - itkGetMacro(NumberOfSamples, unsigned int); - - virtual void GenerateOutputInformation(void); - - virtual void GenerateInputRequestedRegion(); /** Set and Get the input references */ virtual void SetInputReferences(ImageListType input); ImagePointerType GetInputReference(unsigned int index); - virtual void GenerateData(); - /** Get the confusion matrix */ const ConfMatType GetConfusionMatrix(unsigned int target); @@ -111,15 +95,18 @@ protected: TensorflowMultisourceModelValidate(); virtual ~TensorflowMultisourceModelValidate() {}; + void GenerateOutputInformation(void); + void GenerateData(); + void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize); + private: TensorflowMultisourceModelValidate(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - unsigned int m_BatchSize; // Batch size ImageListType m_References; // The references images // Read only - unsigned int m_NumberOfSamples; // Number of samples ConfMatListType m_ConfusionMatrices; // Confusion matrix MapOfClassesListType m_MapsOfClasses; // Maps of classes diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx index b0ec5e2dd264aa5369cf6f7e4d364a7cbc9fbbdb..3f4fd9b91973eac710a7b9495664cdffef2e5887 100644 --- a/include/otbTensorflowMultisourceModelValidate.hxx +++ b/include/otbTensorflowMultisourceModelValidate.hxx @@ -20,8 +20,6 @@ template <class TInputImage> TensorflowMultisourceModelValidate<TInputImage> ::TensorflowMultisourceModelValidate() { - m_BatchSize = 100; - m_NumberOfSamples = 0; } @@ -32,63 +30,6 @@ TensorflowMultisourceModelValidate<TInputImage> { Superclass::GenerateOutputInformation(); - ImageType * outputPtr = this->GetOutput(); - RegionType nullRegion; - nullRegion.GetModifiableSize().Fill(1); - outputPtr->SetNumberOfComponentsPerPixel(1); - outputPtr->SetLargestPossibleRegion( nullRegion ); - - ////////////////////////////////////////////////////////////////////////////////////////// - // Check the number of samples - ////////////////////////////////////////////////////////////////////////////////////////// - - m_NumberOfSamples = 0; - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { - // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); - - // Make sure input is available - if ( inputPtr.IsNull() ) - { - itkExceptionMacro(<< "Input " << i << " is null!"); - } - - // Update input information - inputPtr->UpdateOutputInformation(); - - // Patch size of tensor #i - const SizeType inputPatchSize = this->GetInputFOVSizes().at(i); - - // Input image requested region - const RegionType reqRegion = inputPtr->GetLargestPossibleRegion(); - - // Check size X - if (inputPatchSize[0] != reqRegion.GetSize(0)) - itkExceptionMacro("Patch size for input " << i << " is " << inputPatchSize << - " but input patches image size is " << reqRegion.GetSize()); - - // Check size Y - if (reqRegion.GetSize(1) % inputPatchSize[1] != 0) - itkExceptionMacro("Input patches image must have a number of rows which is " - "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) << - " rows but patch size Y is " << inputPatchSize[1] << " for input " << i); - - // Get the batch size - const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1]; - - // Check the consistency with other inputs - if (m_NumberOfSamples == 0) - { - m_NumberOfSamples = currNumberOfSamples; - } - else if (m_NumberOfSamples != currNumberOfSamples) - { - itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples << " but input " << i - << " has a batch size of " << currNumberOfSamples ); - } - } // next input - ////////////////////////////////////////////////////////////////////////////////////////// // Check the references ////////////////////////////////////////////////////////////////////////////////////////// @@ -98,7 +39,7 @@ TensorflowMultisourceModelValidate<TInputImage> { itkExceptionMacro("No reference is set"); } - SizeListType outputEFSizes = this->GetOutputFOESizes(); + SizeListType outputEFSizes = this->GetOutputExpressionFields(); if (nbOfRefs != outputEFSizes.size()) { itkExceptionMacro("There is " << nbOfRefs << " but only " << @@ -115,32 +56,16 @@ TensorflowMultisourceModelValidate<TInputImage> itkExceptionMacro("Reference image " << i << " width is " << refRegion.GetSize(0) << " but field of expression width is " << outputFOESize[0]); } - if (refRegion.GetSize(1) / outputFOESize[1] != m_NumberOfSamples) + if (refRegion.GetSize(1) / outputFOESize[1] != this->GetNumberOfSamples()) { itkExceptionMacro("Reference image " << i << " height is " << refRegion.GetSize(1) << " but field of expression width is " << outputFOESize[1] << - " which is not consistent with the number of samples (" << m_NumberOfSamples << ")"); + " which is not consistent with the number of samples (" << this->GetNumberOfSamples() << ")"); } } } -template <class TInputImage> -void -TensorflowMultisourceModelValidate<TInputImage> -::GenerateInputRequestedRegion() - { - Superclass::GenerateInputRequestedRegion(); - - // For each image, set no image region - for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i) - { - RegionType nullRegion; - ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) ); - inputImage->SetRequestedRegion(nullRegion); - } // next image - - } /* * Set the references images @@ -195,127 +120,8 @@ TensorflowMultisourceModelValidate<TInputImage> confMatMaps.push_back(mat); } - // Batches loop - const tensorflow::uint64 nBatches = vcl_ceil(m_NumberOfSamples / m_BatchSize); - const tensorflow::uint64 rest = m_NumberOfSamples % m_BatchSize; - itk::ProgressReporter progress(this, 0, nBatches); - for (tensorflow::uint64 batch = 0 ; batch < nBatches ; batch++) - { - // Update progress - this->UpdateProgress((float) batch / (float) nBatches); - - // Sample start of this batch - const tensorflow::uint64 sampleStart = batch * m_BatchSize; - tensorflow::uint64 batchSize = m_BatchSize; - if (rest != 0) - { - batchSize = rest; - } - - // Create input tensors list - DictListType inputs; - - // Populate input tensors - for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++) - { - // Input image pointer - ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i)); - - // Patch size of tensor #i - const SizeType inputPatchSize = this->GetInputFOVSizes().at(i); - - // Create the tensor for the batch - const tensorflow::int64 sz_n = batchSize; - const tensorflow::int64 sz_y = inputPatchSize[1]; - const tensorflow::int64 sz_x = inputPatchSize[0]; - const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel(); - const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c}); - tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape); - - // Populate the tensor - for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++) - { - const tensorflow::uint64 samplePos = sampleStart + elem; - IndexType start; - start[0] = 0; - start[1] = samplePos * sz_y; - RegionType patchRegion(start, inputPatchSize); - tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion); - tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem ); - } - - // Input #i : the tensor of patches (aka the batch) - DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor }; - inputs.push_back(input1); - } // next input tensor - - // Run the TF session here - TensorListType outputs; - this->RunSession(inputs, outputs); - - // Perform the validation - if (outputs.size() != m_References.size()) - { - itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " << - "but only " << m_References.size() << " reference(s) set"); - } - SizeListType outputEFSizes = this->GetOutputFOESizes(); - for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++) - { - // Recopy the chunk - const SizeType outputFOESize = outputEFSizes[refIdx]; - IndexType cpyStart; - cpyStart.Fill(0); - IndexType refRegStart; - refRegStart.Fill(0); - refRegStart[1] = outputFOESize[1] * sampleStart; - SizeType cpySize; - cpySize[0] = outputFOESize[0]; - cpySize[1] = outputFOESize[1] * batchSize; - RegionType cpyRegion(cpyStart, cpySize); - RegionType refRegion(refRegStart, cpySize); - - // Allocate a temporary image - ImagePointerType img = ImageType::New(); - img->SetRegions(cpyRegion); - img->SetNumberOfComponentsPerPixel(1); - img->Allocate(); - - int co = 0; - tf::CopyTensorToImageRegion<TInputImage>(outputs[refIdx], cpyRegion, img, cpyRegion, co); - - // Retrieve the reference image region - tf::PropagateRequestedRegion<TInputImage>(m_References[refIdx], refRegion); - - // Update the confusion matrices - IteratorType inIt(img, cpyRegion); - IteratorType refIt(m_References[refIdx], refRegion); - for (inIt.GoToBegin(), refIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++refIt) - { - const int classIn = static_cast<LabelValueType>(inIt.Get()[0]); - const int classRef = static_cast<LabelValueType>(refIt.Get()[0]); - - if (confMatMaps[refIdx].count(classRef) == 0) - { - MapType newMap; - newMap[classIn] = 1; - confMatMaps[refIdx][classRef] = newMap; - } - else - { - if (confMatMaps[refIdx][classRef].count(classIn) == 0) - { - confMatMaps[refIdx][classRef][classIn] = 1; - } - else - { - confMatMaps[refIdx][classRef][classIn]++; - } - } - } - } - progress.CompletedPixel(); - } // Next batch + // Run all the batches + Superclass::GenerateData(); // Compute confusion matrices for (unsigned int i = 0 ; i < confMatMaps.size() ; i++) @@ -353,6 +159,88 @@ TensorflowMultisourceModelValidate<TInputImage> m_ConfusionMatrices.push_back(matrix); m_MapsOfClasses.push_back(values); + + + + + } + + } + + +template <class TInputImage> +void +TensorflowMultisourceModelValidate<TInputImage> +::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart, + const IndexValueType & batchSize) + { + // Populate input tensors + PopulateInputTensor(inputs, sampleStart, batchSize); + + // Run the TF session here + TensorListType outputs; + this->RunSession(inputs, outputs); + + // Perform the validation + if (outputs.size() != m_References.size()) + { + itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " << + "but only " << m_References.size() << " reference(s) set"); + } + SizeListType outputEFSizes = this->GetOutputExpressionFields(); + for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++) + { + // Recopy the chunk + const SizeType outputFOESize = outputEFSizes[refIdx]; + IndexType cpyStart; + cpyStart.Fill(0); + IndexType refRegStart; + refRegStart.Fill(0); + refRegStart[1] = outputFOESize[1] * sampleStart; + SizeType cpySize; + cpySize[0] = outputFOESize[0]; + cpySize[1] = outputFOESize[1] * batchSize; + RegionType cpyRegion(cpyStart, cpySize); + RegionType refRegion(refRegStart, cpySize); + + // Allocate a temporary image + ImagePointerType img = ImageType::New(); + img->SetRegions(cpyRegion); + img->SetNumberOfComponentsPerPixel(1); + img->Allocate(); + + int co = 0; + tf::CopyTensorToImageRegion<TInputImage>(outputs[refIdx], cpyRegion, img, cpyRegion, co); + + // Retrieve the reference image region + tf::PropagateRequestedRegion<TInputImage>(m_References[refIdx], refRegion); + + // Update the confusion matrices + IteratorType inIt(img, cpyRegion); + IteratorType refIt(m_References[refIdx], refRegion); + for (inIt.GoToBegin(), refIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++refIt) + { + const int classIn = static_cast<LabelValueType>(inIt.Get()[0]); + const int classRef = static_cast<LabelValueType>(refIt.Get()[0]); + + if (confMatMaps[refIdx].count(classRef) == 0) + { + MapType newMap; + newMap[classIn] = 1; + confMatMaps[refIdx][classRef] = newMap; + } + else + { + if (confMatMaps[refIdx][classRef].count(classIn) == 0) + { + confMatMaps[refIdx][classRef][classIn] = 1; + } + else + { + confMatMaps[refIdx][classRef][classIn]++; + } + } + } } }