diff --git a/include/otbTensorflowMultisourceModelBase.h b/include/otbTensorflowMultisourceModelBase.h
index e0c23846fd45110efe30704acde79c35060e5310..8f4440730946d947f83fb92d960e027485be5ec0 100644
--- a/include/otbTensorflowMultisourceModelBase.h
+++ b/include/otbTensorflowMultisourceModelBase.h
@@ -72,10 +72,10 @@ public:
   typedef typename TInputImage::RegionType           RegionType;
 
   /** Typedefs for parameters */
-  typedef std::pair<std::string, tensorflow::Tensor> DictType;
+  typedef std::pair<std::string, tensorflow::Tensor> DictElementType;
   typedef std::vector<std::string>                   StringList;
   typedef std::vector<SizeType>                      SizeListType;
-  typedef std::vector<DictType>                      DictListType;
+  typedef std::vector<DictElementType>               DictType;
   typedef std::vector<tensorflow::DataType>          DataTypeListType;
   typedef std::vector<tensorflow::TensorShapeProto>  TensorShapeProtoList;
   typedef std::vector<tensorflow::Tensor>            TensorListType;
@@ -87,27 +87,28 @@ public:
   tensorflow::Session * GetSession()             { return m_Session;    }
 
   /** Model parameters */
-  void PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image);
+  void PushBackInputTensorBundle(std::string name, SizeType receptiveField, ImagePointerType image);
+  void PushBackOuputTensorBundle(std::string name, SizeType expressionField);
 
-//  /** Input placeholders names */
-//  itkSetMacro(InputPlaceholdersNames, StringList);
-  itkGetMacro(InputPlaceholdersNames, StringList);
-//
-//  /** Receptive field */
-//  itkSetMacro(InputFOVSizes, SizeListType);
-  itkGetMacro(InputFOVSizes, SizeListType);
+  /** Input placeholders names */
+  itkSetMacro(InputPlaceholders, StringList);
+  itkGetMacro(InputPlaceholders, StringList);
+
+  /** Receptive field */
+  itkSetMacro(InputReceptiveFields, SizeListType);
+  itkGetMacro(InputReceptiveFields, SizeListType);
 
   /** Output tensors names */
-  itkSetMacro(OutputTensorsNames, StringList);
-  itkGetMacro(OutputTensorsNames, StringList);
+  itkSetMacro(OutputTensors, StringList);
+  itkGetMacro(OutputTensors, StringList);
 
   /** Expression field */
-  itkSetMacro(OutputFOESizes, SizeListType);
-  itkGetMacro(OutputFOESizes, SizeListType);
+  itkSetMacro(OutputExpressionFields, SizeListType);
+  itkGetMacro(OutputExpressionFields, SizeListType);
 
   /** User placeholders */
-  void SetUserPlaceholders(DictListType dict) { m_UserPlaceholders = dict; }
-  DictListType GetUserPlaceholders()          { return m_UserPlaceholders; }
+  void SetUserPlaceholders(DictType dict) { m_UserPlaceholders = dict; }
+  DictType GetUserPlaceholders()          { return m_UserPlaceholders; }
 
   /** Target nodes names */
   itkSetMacro(TargetNodesNames, StringList);
@@ -125,9 +126,9 @@ protected:
   TensorflowMultisourceModelBase();
   virtual ~TensorflowMultisourceModelBase() {};
 
-  virtual std::stringstream GenerateDebugReport(DictListType & inputs, TensorListType & outputs);
+  virtual std::stringstream GenerateDebugReport(DictType & inputs);
 
-  virtual void RunSession(DictListType & inputs, TensorListType & outputs);
+  virtual void RunSession(DictType & inputs, TensorListType & outputs);
 
 private:
   TensorflowMultisourceModelBase(const Self&); //purposely not implemented
@@ -138,11 +139,11 @@ private:
   tensorflow::Session *      m_Session;                 // The tensorflow session
 
   // Model parameters
-  StringList                 m_InputPlaceholdersNames;  // Input placeholders names
-  SizeListType               m_InputFOVSizes;           // Input tensors field of view (FOV) sizes
-  SizeListType               m_OutputFOESizes;          // Output tensors field of expression (FOE) sizes
-  DictListType               m_UserPlaceholders;        // User placeholders
-  StringList                 m_OutputTensorsNames;      // User tensors
+  StringList                 m_InputPlaceholders;       // Input placeholders names
+  SizeListType               m_InputReceptiveFields;    // Input receptive fields
+  StringList                 m_OutputTensors;           // Output tensors names
+  SizeListType               m_OutputExpressionFields;  // Output expression fields
+  DictType                   m_UserPlaceholders;        // User placeholders
   StringList                 m_TargetNodesNames;        // User target tensors
 
   // Read-only
diff --git a/include/otbTensorflowMultisourceModelBase.hxx b/include/otbTensorflowMultisourceModelBase.hxx
index 07b2cf71d5afdab187b494d3573a44dcd74c07d3..938c903544d8d5eed0fddf6596f0224ae019ca1f 100644
--- a/include/otbTensorflowMultisourceModelBase.hxx
+++ b/include/otbTensorflowMultisourceModelBase.hxx
@@ -20,22 +20,23 @@ template <class TInputImage, class TOutputImage>
 TensorflowMultisourceModelBase<TInputImage, TOutputImage>
 ::TensorflowMultisourceModelBase()
  {
+  m_Session = nullptr;
  }
 
 template <class TInputImage, class TOutputImage>
 void
 TensorflowMultisourceModelBase<TInputImage, TOutputImage>
-::PushBackInputBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
+::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
  {
   Superclass::PushBackInput(image);
-  m_InputFOVSizes.push_back(receptiveField);
-  m_InputPlaceholdersNames.push_back(placeholder);
+  m_InputReceptiveFields.push_back(receptiveField);
+  m_InputPlaceholders.push_back(placeholder);
  }
 
 template <class TInputImage, class TOutputImage>
 std::stringstream
 TensorflowMultisourceModelBase<TInputImage, TOutputImage>
-::GenerateDebugReport(DictListType & inputs, TensorListType & outputs)
+::GenerateDebugReport(DictType & inputs)
  {
   // Create a debug report
   std::stringstream debugReport;
@@ -69,7 +70,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
 template <class TInputImage, class TOutputImage>
 void
 TensorflowMultisourceModelBase<TInputImage, TOutputImage>
-::RunSession(DictListType & inputs, TensorListType & outputs)
+::RunSession(DictType & inputs, TensorListType & outputs)
  {
 
   // Add the user's placeholders
@@ -82,11 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
   // The session will initialize the outputs
 
   // Run the session, evaluating our output tensors from the graph
-  auto status = this->GetSession()->Run(inputs, m_OutputTensorsNames, m_TargetNodesNames, &outputs);
+  auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
   if (!status.ok()) {
 
     // Create a debug report
-    std::stringstream debugReport = GenerateDebugReport(inputs, outputs);
+    std::stringstream debugReport = GenerateDebugReport(inputs);
 
     // Throw an exception with the report
     itkExceptionMacro("Can't run the tensorflow session !\n" <<
@@ -108,11 +109,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
   // - patches sizes
   // - input image
   const unsigned int nbInputs = this->GetNumberOfInputs();
-  if (nbInputs != m_InputFOVSizes.size() || nbInputs != m_InputPlaceholdersNames.size())
+  if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size())
     {
     itkExceptionMacro("Number of input images is " << nbInputs <<
-                      " but the number of input patches size is " << m_InputFOVSizes.size() <<
-                      " and the number of input tensors names is " << m_InputPlaceholdersNames.size());
+                      " but the number of input patches size is " << m_InputReceptiveFields.size() <<
+                      " and the number of input tensors names is " << m_InputPlaceholders.size());
     }
 
   //////////////////////////////////////////////////////////////////////////////////////////
@@ -120,8 +121,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
   //////////////////////////////////////////////////////////////////////////////////////////
 
   // Get input and output tensors datatypes and shapes
-  tf::GetTensorAttributes(m_Graph, m_InputPlaceholdersNames, m_InputTensorsShapes, m_InputTensorsDataTypes);
-  tf::GetTensorAttributes(m_Graph, m_OutputTensorsNames, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
+  tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
+  tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
 
  }
 
diff --git a/include/otbTensorflowMultisourceModelFilter.h b/include/otbTensorflowMultisourceModelFilter.h
index 833294ada5c05bf469c509aa58ea69ca04cfd1d4..fdc799a5cd69f86a77bc396f3f159f4ffde101e5 100644
--- a/include/otbTensorflowMultisourceModelFilter.h
+++ b/include/otbTensorflowMultisourceModelFilter.h
@@ -30,27 +30,26 @@ namespace otb
  *
  * The filter takes N input images and feed the TensorFlow model to produce
  * one output image of desired TF op results.
- * Names of input/output placeholders/tensors must be specified using the
- * SetInputPlaceholdersNames/SetOutputTensorNames.
+ * Names of input placeholders and output tensors must be specified using the
+ * SetPlaceholders() and SetTensors() methods.
  *
- * Example: we have a tensorflow model which runs the input images "x1" and "x2"
+ * Example: we have a TensorFlow model which runs the input images "x1" and "x2"
  *          and produces the output image "y".
- *          "x1" and "x2" are two TF placeholders, we set InputTensorNames={"x1","x2"}
- *          "y1" corresponds to one TF op output, we set OutputTensorNames={"y1"}
+ *          "x1" and "x2" are two TF placeholders, we set InputPlaceholders={"x1","x2"}
+ *          "y1" corresponds to one TF op output, we set OutputTensors={"y1"}
  *
  * The reference grid for the output image is the same as the first input image.
  * This grid can be scaled by setting the OutputSpacingScale value.
  * This can be used to run models which downsize the output image spacing
- * (typically fully convolutional model with strides) or to produce the result
+ * (e.g. fully convolutional model with strides) or to produce the result
  * of a patch-based network at regular intervals.
  *
- * For each input, input field of view (FOV) must be set.
+ * For each input (resp. output), receptive field (resp. expression field) must be set.
  * If the number of values in the output tensors (produced by the model) don't
- * fit with the output image region, exception will be thrown.
+ * fit with the output image region, an exception will be thrown.
  *
- *
- * The tensorflow Graph is passed using the SetGraph() method
- * The tensorflow Session is passed using the SetSession() method
+ * The TensorFlow Graph is passed using the SetGraph() method
+ * The TensorFlow Session is passed using the SetSession() method
  *
  * \ingroup OTBTensorflow
  */
@@ -94,6 +93,7 @@ public:
   typedef typename itk::ImageRegionConstIterator<TInputImage>              InputConstIteratorType;
 
   /* Typedefs for parameters */
+  typedef typename Superclass::DictElementType     DictElementType;
   typedef typename Superclass::DictType            DictType;
   typedef typename Superclass::StringList          StringList;
   typedef typename Superclass::SizeListType        SizeListType;
@@ -101,8 +101,6 @@ public:
   typedef typename Superclass::TensorListType      TensorListType;
   typedef std::vector<float>                       ScaleListType;
 
-  itkSetMacro(OutputFOESize, SizeType);
-  itkGetMacro(OutputFOESize, SizeType);
   itkSetMacro(OutputGridSize, SizeType);
   itkGetMacro(OutputGridSize, SizeType);
   itkSetMacro(ForceOutputGridSize, bool);
@@ -132,7 +130,6 @@ private:
   TensorflowMultisourceModelFilter(const Self&); //purposely not implemented
   void operator=(const Self&); //purposely not implemented
 
-  SizeType                   m_OutputFOESize;        // Output tensors field of expression (FOE) sizes
   SizeType                   m_OutputGridSize;       // Output grid size
   bool                       m_ForceOutputGridSize;  // Force output grid size
   bool                       m_FullyConvolutional;   // Convolution mode
diff --git a/include/otbTensorflowMultisourceModelFilter.hxx b/include/otbTensorflowMultisourceModelFilter.hxx
index c7a547e240b2ccc8e60a0ced3d5c734b501537bd..e5fa16b9b2947bc91bd7adf337a971473ea8a313 100644
--- a/include/otbTensorflowMultisourceModelFilter.hxx
+++ b/include/otbTensorflowMultisourceModelFilter.hxx
@@ -216,7 +216,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
 
     // Update output image extent
     PointType currentInputImageExtentInf, currentInputImageExtentSup;
-    ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputFOVSizes()[imageIndex]);
+    ImageToExtent(currentImage, currentInputImageExtentInf, currentInputImageExtentSup, this->GetInputReceptiveFields()[imageIndex]);
     for(unsigned int dim = 0; dim<ImageType::ImageDimension; ++dim)
       {
       extentInf[dim] = vnl_math_max(currentInputImageExtentInf[dim], extentInf[dim]);
@@ -236,7 +236,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
   if (!m_ForceOutputGridSize)
     {
     // Default is the output field of expression
-    m_OutputGridSize = m_OutputFOESize;
+    m_OutputGridSize = this->GetOutputExpressionFields().at(0);
     }
 
   // Resize the largestPossibleRegion to be a multiple of the grid size
@@ -315,9 +315,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
       }
 
     // Compute the FOV-scale*FOE radius to pad
-    SizeType toPad(this->GetInputFOVSizes().at(i));
-    toPad[0] -= 1 + (m_OutputFOESize[0] - 1) * m_OutputSpacingScale;
-    toPad[1] -= 1 + (m_OutputFOESize[1] - 1) * m_OutputSpacingScale;
+    SizeType toPad(this->GetInputReceptiveFields().at(i));
+    toPad[0] -= 1 + (this->GetOutputExpressionFields().at(0)[0] - 1) * m_OutputSpacingScale;
+    toPad[1] -= 1 + (this->GetOutputExpressionFields().at(0)[1] - 1) * m_OutputSpacingScale;
 
     // Pad with radius
     SmartPad(inRegion, toPad);
@@ -365,7 +365,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
   const unsigned int nInputs = this->GetNumberOfInputs();
 
   // Create input tensors list
-  DictListType inputs;
+  DictType inputs;
 
   // Populate input tensors
   for (unsigned int i = 0 ; i < nInputs ; i++)
@@ -374,7 +374,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
     const ImagePointerType inputPtr = const_cast<TInputImage*>(this->GetInput(i));
 
     // Patch size of tensor #i
-    const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
+    const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
 
     // Input image requested region
     const RegionType reqRegion = inputPtr->GetRequestedRegion();
@@ -395,7 +395,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
       tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);
 
       // Input #1 : the tensor of patches (aka the batch)
-      DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
+      DictElementType input1 = { this->GetInputPlaceholders()[i], inputTensor };
       inputs.push_back(input1);
       }
     else
@@ -429,7 +429,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
         }
 
       // Input #1 : the tensor of patches (aka the batch)
-      DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
+      DictElementType input1 = { this->GetInputPlaceholders()[i], inputTensor };
       inputs.push_back(input1);
       } // mode is not full convolutional
 
diff --git a/include/otbTensorflowMultisourceModelLearningBase.h b/include/otbTensorflowMultisourceModelLearningBase.h
new file mode 100644
index 0000000000000000000000000000000000000000..d6feb24e83ca08eabb1e2566337511b5df0bb614
--- /dev/null
+++ b/include/otbTensorflowMultisourceModelLearningBase.h
@@ -0,0 +1,112 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowMultisourceModelLearningBase_h
+#define otbTensorflowMultisourceModelLearningBase_h
+
+#include "itkProcessObject.h"
+#include "itkNumericTraits.h"
+#include "itkSimpleDataObjectDecorator.h"
+
+// Base
+#include "otbTensorflowMultisourceModelBase.h"
+
+// Shuffle
+#include <random>
+#include <algorithm>
+#include <iterator>
+
+namespace otb
+{
+
+/**
+ * \class TensorflowMultisourceModelLearningBase
+ * \brief This filter is the base class for learning filters.
+ *
+ * \ingroup OTBTensorflow
+ */
+template <class TInputImage>
+class ITK_EXPORT TensorflowMultisourceModelLearningBase :
+public TensorflowMultisourceModelBase<TInputImage>
+{
+public:
+
+  /** Standard class typedefs. */
+  typedef TensorflowMultisourceModelLearningBase       Self;
+  typedef TensorflowMultisourceModelBase<TInputImage>  Superclass;
+  typedef itk::SmartPointer<Self>                      Pointer;
+  typedef itk::SmartPointer<const Self>                ConstPointer;
+
+  /** Run-time type information (and related methods). */
+  itkTypeMacro(TensorflowMultisourceModelLearningBase, TensorflowMultisourceModelBase);
+
+  /** Images typedefs */
+  typedef typename Superclass::ImageType         ImageType;
+  typedef typename Superclass::ImagePointerType  ImagePointerType;
+  typedef typename Superclass::RegionType        RegionType;
+  typedef typename Superclass::SizeType          SizeType;
+  typedef typename Superclass::IndexType         IndexType;
+
+  /* Typedefs for parameters */
+  typedef typename Superclass::DictType          DictType;
+  typedef typename Superclass::DictElementType   DictElementType;
+  typedef typename Superclass::StringList        StringList;
+  typedef typename Superclass::SizeListType      SizeListType;
+  typedef typename Superclass::TensorListType    TensorListType;
+
+  /* Typedefs for index */
+  typedef typename ImageType::IndexValueType     IndexValueType;
+  typedef std::vector<IndexValueType>            IndexListType;
+
+  // Batch size
+  itkSetMacro(BatchSize, IndexValueType);
+  itkGetMacro(BatchSize, IndexValueType);
+
+  // Use streaming
+  itkSetMacro(UseStreaming, bool);
+  itkGetMacro(UseStreaming, bool);
+
+  // Get number of samples
+  itkGetMacro(NumberOfSamples, IndexValueType);
+
+protected:
+  TensorflowMultisourceModelLearningBase();
+  virtual ~TensorflowMultisourceModelLearningBase() {};
+
+  virtual void GenerateOutputInformation(void);
+
+  virtual void GenerateInputRequestedRegion();
+
+  virtual void GenerateData();
+
+  virtual void PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
+      const IndexValueType & batchSize, const IndexListType & order = IndexListType());
+
+  virtual void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
+      const IndexValueType & batchSize) = 0;
+
+private:
+  TensorflowMultisourceModelLearningBase(const Self&); //purposely not implemented
+  void operator=(const Self&); //purposely not implemented
+
+  IndexValueType        m_BatchSize;       // Batch size
+  bool                  m_UseStreaming;    // Use streaming on/off
+
+  // Read only
+  IndexValueType        m_NumberOfSamples; // Number of samples
+
+}; // end class
+
+
+} // end namespace otb
+
+#include "otbTensorflowMultisourceModelLearningBase.hxx"
+
+#endif
diff --git a/include/otbTensorflowMultisourceModelLearningBase.hxx b/include/otbTensorflowMultisourceModelLearningBase.hxx
new file mode 100644
index 0000000000000000000000000000000000000000..9c913cd2d3d6230a0356bec742ae749f333949f8
--- /dev/null
+++ b/include/otbTensorflowMultisourceModelLearningBase.hxx
@@ -0,0 +1,211 @@
+/*=========================================================================
+
+  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef otbTensorflowMultisourceModelLearningBase_txx
+#define otbTensorflowMultisourceModelLearningBase_txx
+
+#include "otbTensorflowMultisourceModelLearningBase.h"
+
+namespace otb
+{
+
+template <class TInputImage>
+TensorflowMultisourceModelLearningBase<TInputImage>
+::TensorflowMultisourceModelLearningBase(): m_BatchSize(100),
+ m_NumberOfSamples(0), m_UseStreaming(false)
+ {
+ }
+
+
+template <class TInputImage>
+void
+TensorflowMultisourceModelLearningBase<TInputImage>
+::GenerateOutputInformation()
+ {
+  Superclass::GenerateOutputInformation();
+
+  ImageType * outputPtr = this->GetOutput();
+  RegionType nullRegion;
+  nullRegion.GetModifiableSize().Fill(1);
+  outputPtr->SetNumberOfComponentsPerPixel(1);
+  outputPtr->SetLargestPossibleRegion( nullRegion );
+
+  // Count the number of samples
+  m_NumberOfSamples = 0;
+  for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
+    {
+    // Input image pointer
+    ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
+
+    // Make sure input is available
+    if ( inputPtr.IsNull() )
+      {
+      itkExceptionMacro(<< "Input " << i << " is null!");
+      }
+
+    // Update input information
+    inputPtr->UpdateOutputInformation();
+
+    // Patch size of tensor #i
+    const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
+
+    // Input image requested region
+    const RegionType reqRegion = inputPtr->GetLargestPossibleRegion();
+
+    // Check size X
+    if (inputPatchSize[0] != reqRegion.GetSize(0))
+      itkExceptionMacro("Patch size for input " << i
+          << " is " << inputPatchSize
+          << " but input patches image size is " << reqRegion.GetSize());
+
+    // Check size Y
+    if (reqRegion.GetSize(1) % inputPatchSize[1] != 0)
+      itkExceptionMacro("Input patches image must have a number of rows which is "
+          << "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1)
+          << " rows but patch size Y is " <<  inputPatchSize[1] << " for input " << i);
+
+    // Get the batch size
+    const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
+
+    // Check the consistency with other inputs
+    if (m_NumberOfSamples == 0)
+      {
+      m_NumberOfSamples = currNumberOfSamples;
+      }
+    else if (m_NumberOfSamples != currNumberOfSamples)
+      {
+      itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples
+          << " but input " << i
+          << " has a batch size of " << currNumberOfSamples );
+      }
+    } // next input
+ }
+
+template <class TInputImage>
+void
+TensorflowMultisourceModelLearningBase<TInputImage>
+::GenerateInputRequestedRegion()
+ {
+  Superclass::GenerateInputRequestedRegion();
+
+  // For each image, set no image region
+  for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
+    {
+    RegionType nullRegion;
+    ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );
+
+    // If the streaming is enabled, we don't read the full image
+    if (m_UseStreaming)
+    {
+      inputImage->SetRequestedRegion(nullRegion);
+    }
+    else
+    {
+      inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion());
+    }
+    } // next image
+ }
+
+/**
+ *
+ */
+template <class TInputImage>
+void
+TensorflowMultisourceModelLearningBase<TInputImage>
+::GenerateData()
+ {
+
+  // Batches loop
+  const IndexValueType nBatches = static_cast<IndexValueType>(vcl_ceil(static_cast<double>(m_NumberOfSamples) / static_cast<double>(m_BatchSize)));
+  const IndexValueType rest = m_NumberOfSamples % m_BatchSize;
+
+  itk::ProgressReporter progress(this, 0, nBatches);
+
+  for (IndexValueType batch = 0 ; batch < nBatches ; batch++)
+    {
+
+    // Create input tensors list
+    TensorListType inputs;
+
+    // Batch start and size
+    const IndexValueType sampleStart = batch * m_BatchSize;
+    IndexValueType batchSize = m_BatchSize;
+    if (rest != 0 && batch == nBatches - 1)
+    {
+      batchSize = rest;
+    }
+
+    // Process the batch
+    ProcessBatch(inputs, sampleStart, batchSize);
+
+    progress.CompletedPixel();
+    } // Next batch
+
+ }
+
+template <class TInputImage>
+void
+TensorflowMultisourceModelLearningBase<TInputImage>
+::PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
+    const IndexValueType & batchSize, const IndexListType & order)
+ {
+  const bool reorder = !order.empty();
+
+  // Populate input tensors
+  for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
+    {
+    // Input image pointer
+    ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
+
+    // Patch size of tensor #i
+    const SizeType inputPatchSize = this->GetInputReceptiveFields().at(i);
+
+    // Create the tensor for the batch
+    const tensorflow::int64 sz_n = batchSize;
+    const tensorflow::int64 sz_y = inputPatchSize[1];
+    const tensorflow::int64 sz_x = inputPatchSize[0];
+    const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
+    const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});
+    tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);
+
+    // Populate the tensor
+    for (IndexValueType elem = 0 ; elem < batchSize ; elem++)
+      {
+      const tensorflow::uint64 samplePos = sampleStart + elem;
+      IndexType start;
+      start[0] = 0;
+      if (reorder)
+      {
+        start[1] = order[samplePos] * sz_y;
+      }
+      else
+      {
+        start[1] = samplePos * sz_y;
+      }
+      RegionType patchRegion(start, inputPatchSize);
+      if (m_UseStreaming)
+      {
+        // If streaming is enabled, we need to explicitly propagate requested region
+        tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion);
+      }
+      tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem );
+      }
+
+    // Input #i : the tensor of patches (aka the batch)
+    DictElementType input1 = { this->GetInputPlaceholders()[i], inputTensor };
+    inputs.push_back(input1);
+    } // next input tensor
+ }
+
+
+} // end namespace otb
+
+
+#endif
diff --git a/include/otbTensorflowMultisourceModelTrain.h b/include/otbTensorflowMultisourceModelTrain.h
index 8308f2ed5bbbb64b949163f33bea3a3b5eb1fa80..ec4ce349c6af76d36c6d2280c080b041d4b06e2d 100644
--- a/include/otbTensorflowMultisourceModelTrain.h
+++ b/include/otbTensorflowMultisourceModelTrain.h
@@ -16,7 +16,7 @@
 #include "itkSimpleDataObjectDecorator.h"
 
 // Base
-#include "otbTensorflowMultisourceModelBase.h"
+#include "otbTensorflowMultisourceModelLearningBase.h"
 
 // Shuffle
 #include <random>
@@ -31,67 +31,47 @@ namespace otb
  * \brief This filter train a TensorFlow model over multiple input images.
  *
  * The filter takes N input images and feed the TensorFlow model.
- * Names of input placeholders must be specified using the
- * SetInputPlaceholdersNames method
  *
- * TODO: Add an option to disable streaming
  *
  * \ingroup OTBTensorflow
  */
 template <class TInputImage>
 class ITK_EXPORT TensorflowMultisourceModelTrain :
-public TensorflowMultisourceModelBase<TInputImage>
+public TensorflowMultisourceModelLearningBase<TInputImage>
 {
 public:
 
   /** Standard class typedefs. */
-  typedef TensorflowMultisourceModelTrain                    Self;
-  typedef TensorflowMultisourceModelBase<TInputImage>        Superclass;
-  typedef itk::SmartPointer<Self>                            Pointer;
-  typedef itk::SmartPointer<const Self>                      ConstPointer;
+  typedef TensorflowMultisourceModelTrain                     Self;
+  typedef TensorflowMultisourceModelLearningBase<TInputImage> Superclass;
+  typedef itk::SmartPointer<Self>                             Pointer;
+  typedef itk::SmartPointer<const Self>                       ConstPointer;
 
   /** Method for creation through the object factory. */
   itkNewMacro(Self);
 
   /** Run-time type information (and related methods). */
-  itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelBase);
-
-  /** Images typedefs */
-  typedef typename Superclass::ImageType         ImageType;
-  typedef typename Superclass::ImagePointerType  ImagePointerType;
-  typedef typename Superclass::RegionType        RegionType;
-  typedef typename Superclass::SizeType          SizeType;
-  typedef typename Superclass::IndexType         IndexType;
-
-  /* Typedefs for parameters */
-  typedef typename Superclass::DictType          DictType;
-  typedef typename Superclass::StringList        StringList;
-  typedef typename Superclass::SizeListType      SizeListType;
-  typedef typename Superclass::DictListType      DictListType;
-  typedef typename Superclass::TensorListType    TensorListType;
-
-  itkSetMacro(BatchSize, unsigned int);
-  itkGetMacro(BatchSize, unsigned int);
-  itkGetMacro(NumberOfSamples, unsigned int);
-
-  virtual void GenerateOutputInformation(void);
+  itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelLearningBase);
 
-  virtual void GenerateInputRequestedRegion();
+  /** Superclass typedefs */
+  typedef typename Superclass::IndexValueType    IndexValueType;
+  typedef typename Superclass::TensorListType    TensorListType;
+  typedef typename Superclass::IndexListType     IndexListType;
 
-  virtual void GenerateData();
 
 protected:
   TensorflowMultisourceModelTrain();
   virtual ~TensorflowMultisourceModelTrain() {};
 
+  void GenerateData();
+  void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
+      const IndexValueType & batchSize);
+
 private:
   TensorflowMultisourceModelTrain(const Self&); //purposely not implemented
   void operator=(const Self&); //purposely not implemented
 
-  unsigned int               m_BatchSize;               // Batch size
-
-  // Read only
-  unsigned int               m_NumberOfSamples;         // Number of samples
+  IndexListType     m_RandomIndices;           // Reordered indices
 
 }; // end class
 
diff --git a/include/otbTensorflowMultisourceModelTrain.hxx b/include/otbTensorflowMultisourceModelTrain.hxx
index aa47237c30ba886c5bc7b9effa5028d9cdefde91..2ad029129f0c75bf940cd7b8badd025a98e564ac 100644
--- a/include/otbTensorflowMultisourceModelTrain.hxx
+++ b/include/otbTensorflowMultisourceModelTrain.hxx
@@ -20,183 +20,44 @@ template <class TInputImage>
 TensorflowMultisourceModelTrain<TInputImage>
 ::TensorflowMultisourceModelTrain()
  {
-  m_BatchSize = 100;
-  m_NumberOfSamples = 0;
  }
 
-
 template <class TInputImage>
 void
 TensorflowMultisourceModelTrain<TInputImage>
-::GenerateOutputInformation()
+::GenerateData()
  {
-  Superclass::GenerateOutputInformation();
-
-  ImageType * outputPtr = this->GetOutput();
-  RegionType nullRegion;
-  nullRegion.GetModifiableSize().Fill(1);
-  outputPtr->SetNumberOfComponentsPerPixel(1);
-  outputPtr->SetLargestPossibleRegion( nullRegion );
-
-  //////////////////////////////////////////////////////////////////////////////////////////
-  //                               Check the number of samples
-  //////////////////////////////////////////////////////////////////////////////////////////
-
-  m_NumberOfSamples = 0;
-  for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
-    {
-    // Input image pointer
-    ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
-
-    // Make sure input is available
-    if ( inputPtr.IsNull() )
-      {
-      itkExceptionMacro(<< "Input " << i << " is null!");
-      }
-
-    // Update input information
-    inputPtr->UpdateOutputInformation();
-
-    // Patch size of tensor #i
-    const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
-
-    // Input image requested region
-    const RegionType reqRegion = inputPtr->GetLargestPossibleRegion();
-
-    // Check size X
-    if (inputPatchSize[0] != reqRegion.GetSize(0))
-      itkExceptionMacro("Patch size for input " << i << " is " << inputPatchSize <<
-                        " but input patches image size is " << reqRegion.GetSize());
-
-    // Check size Y
-    if (reqRegion.GetSize(1) % inputPatchSize[1] != 0)
-      itkExceptionMacro("Input patches image must have a number of rows which is "
-          "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) <<
-          " rows but patch size Y is " <<  inputPatchSize[1] << " for input " << i);
-
-    // Get the batch size
-    const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
-
-    // Check the consistency with other inputs
-    if (m_NumberOfSamples == 0)
-      {
-      m_NumberOfSamples = currNumberOfSamples;
-      }
-    else if (m_NumberOfSamples != currNumberOfSamples)
-      {
-      itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples << " but input " << i
-                               << " has a batch size of " << currNumberOfSamples );
-      }
-    } // next input
- }
 
-template <class TInputImage>
-void
-TensorflowMultisourceModelTrain<TInputImage>
-::GenerateInputRequestedRegion()
- {
-  Superclass::GenerateInputRequestedRegion();
-
-  // For each image, set no image region
-  for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
-    {
-    RegionType nullRegion;
-    ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );
-// TODO: streaming mode on/off
-//    inputImage->SetRequestedRegion(nullRegion);
-    inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion());
-    } // next image
+  // Initial sequence 0...N-1
+  m_RandomIndices.resize(this->GetNumberOfSamples());
+  std::iota (std::begin(m_RandomIndices), std::end(m_RandomIndices), 0);
+
+  // Shuffle the sequence
+  std::random_device rd;
+  std::mt19937 g(rd());
+  std::shuffle(m_RandomIndices.begin(), m_RandomIndices.end(), g);
+
+  // Call the generic method
+  Superclass::GenerateData();
+
  }
 
-/**
- *
- */
 template <class TInputImage>
 void
 TensorflowMultisourceModelTrain<TInputImage>
-::GenerateData()
+::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
+    const IndexValueType & batchSize)
  {
+  // Populate input tensors
+  this->PopulateInputTensor(inputs, sampleStart, batchSize, m_RandomIndices);
 
-  // Random sequence
-  std::vector<int> v(m_NumberOfSamples) ;
-  std::iota (std::begin(v), std::end(v), 0);
-
-  // Shuffle
-  std::random_device rd;
-  std::mt19937 g(rd());
-  std::shuffle(v.begin(), v.end(), g);
-
-  // Batches loop
-  const tensorflow::uint64 nBatches = vcl_ceil(m_NumberOfSamples / m_BatchSize);
-  const tensorflow::uint64 rest = m_NumberOfSamples % m_BatchSize;
-  itk::ProgressReporter progress(this, 0, nBatches);
-  for (tensorflow::uint64 batch = 0 ; batch < nBatches ; batch++)
-    {
-    // Update progress
-    this->UpdateProgress((float) batch / (float) nBatches);
-
-    // Create input tensors list
-    DictListType inputs;
-
-    // Batch start and size
-    const tensorflow::uint64 sampleStart = batch * m_BatchSize;
-    tensorflow::uint64 batchSize = m_BatchSize;
-    if (rest != 0)
-    {
-      batchSize = rest;
-    }
-
-    // Populate input tensors
-    for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
-      {
-      // Input image pointer
-      ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
-
-      // Patch size of tensor #i
-      const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
-
-      // Create the tensor for the batch
-      const tensorflow::int64 sz_n = batchSize;
-      const tensorflow::int64 sz_y = inputPatchSize[1];
-      const tensorflow::int64 sz_x = inputPatchSize[0];
-      const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
-      const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});
-      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);
-
-      // Populate the tensor
-      for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++)
-        {
-        const tensorflow::uint64 samplePos = sampleStart + elem;
-        const tensorflow::uint64 randPos = v[samplePos];
-        IndexType start;
-        start[0] = 0;
-        start[1] = randPos * sz_y;
-        RegionType patchRegion(start, inputPatchSize);
-// TODO: streaming mode on/off
-//        tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion);
-        tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem );
-        }
-
-      // Input #i : the tensor of patches (aka the batch)
-      DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
-      inputs.push_back(input1);
-      } // next input tensor
-
-    // Run the TF session here
-    TensorListType outputs;
-    this->RunSession(inputs, outputs);
-
-    // Get output tensors
-    for (auto& output: outputs)
-      {
-      std::cout << tf::PrintTensorInfos(output) << std::endl;
-      }
-
-    progress.CompletedPixel();
-    } // Next batch
+  // Run the TF session here
+  TensorListType outputs;
+  this->RunSession(inputs, outputs);
 
  }
 
+
 } // end namespace otb
 
 
diff --git a/include/otbTensorflowMultisourceModelValidate.h b/include/otbTensorflowMultisourceModelValidate.h
index 868317f9344ba1567deaf71aecd7b18a2ccd635f..048ff70cdfd4319abeba653c0121940a05af6127 100644
--- a/include/otbTensorflowMultisourceModelValidate.h
+++ b/include/otbTensorflowMultisourceModelValidate.h
@@ -33,16 +33,12 @@ namespace otb
  * \brief This filter validates a TensorFlow model over multiple input images.
  *
  * The filter takes N input images and feed the TensorFlow model.
- * Names of input placeholders must be specified using the
- * SetInputPlaceholdersNames method
- *
- * TODO: Add an option to disable streaming
  *
  * \ingroup OTBTensorflow
  */
 template <class TInputImage>
 class ITK_EXPORT TensorflowMultisourceModelValidate :
-public TensorflowMultisourceModelBase<TInputImage>
+public TensorflowMultisourceModelLearningBase<TInputImage>
 {
 public:
 
@@ -56,7 +52,7 @@ public:
   itkNewMacro(Self);
 
   /** Run-time type information (and related methods). */
-  itkTypeMacro(TensorflowMultisourceModelValidate, TensorflowMultisourceModelBase);
+  itkTypeMacro(TensorflowMultisourceModelValidate, TensorflowMultisourceModelLearningBase);
 
   /** Images typedefs */
   typedef typename Superclass::ImageType         ImageType;
@@ -70,8 +66,8 @@ public:
   typedef typename Superclass::DictType          DictType;
   typedef typename Superclass::StringList        StringList;
   typedef typename Superclass::SizeListType      SizeListType;
-  typedef typename Superclass::DictListType      DictListType;
   typedef typename Superclass::TensorListType    TensorListType;
+  typedef typename Superclass::IndexValueType    IndexValueType;
 
   /* Typedefs for validation */
   typedef unsigned long                            CountValueType;
@@ -84,23 +80,11 @@ public:
   typedef std::vector<ConfMatType>                 ConfMatListType;
   typedef itk::ImageRegionConstIterator<ImageType> IteratorType;
 
-  /* Set and Get the batch size
-  itkSetMacro(BatchSize, unsigned int);
-  itkGetMacro(BatchSize, unsigned int);
-
-  /** Get the number of samples */
-  itkGetMacro(NumberOfSamples, unsigned int);
-
-  virtual void GenerateOutputInformation(void);
-
-  virtual void GenerateInputRequestedRegion();
 
   /** Set and Get the input references */
   virtual void SetInputReferences(ImageListType input);
   ImagePointerType GetInputReference(unsigned int index);
 
-  virtual void GenerateData();
-
   /** Get the confusion matrix */
   const ConfMatType GetConfusionMatrix(unsigned int target);
 
@@ -111,15 +95,18 @@ protected:
   TensorflowMultisourceModelValidate();
   virtual ~TensorflowMultisourceModelValidate() {};
 
+  void GenerateOutputInformation(void);
+  void GenerateData();
+  void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
+      const IndexValueType & batchSize);
+
 private:
   TensorflowMultisourceModelValidate(const Self&); //purposely not implemented
   void operator=(const Self&); //purposely not implemented
 
-  unsigned int               m_BatchSize;               // Batch size
   ImageListType              m_References;              // The references images
 
   // Read only
-  unsigned int               m_NumberOfSamples;         // Number of samples
   ConfMatListType            m_ConfusionMatrices;       // Confusion matrix
   MapOfClassesListType       m_MapsOfClasses;           // Maps of classes
 
diff --git a/include/otbTensorflowMultisourceModelValidate.hxx b/include/otbTensorflowMultisourceModelValidate.hxx
index b0ec5e2dd264aa5369cf6f7e4d364a7cbc9fbbdb..3f4fd9b91973eac710a7b9495664cdffef2e5887 100644
--- a/include/otbTensorflowMultisourceModelValidate.hxx
+++ b/include/otbTensorflowMultisourceModelValidate.hxx
@@ -20,8 +20,6 @@ template <class TInputImage>
 TensorflowMultisourceModelValidate<TInputImage>
 ::TensorflowMultisourceModelValidate()
  {
-  m_BatchSize = 100;
-  m_NumberOfSamples = 0;
  }
 
 
@@ -32,63 +30,6 @@ TensorflowMultisourceModelValidate<TInputImage>
  {
   Superclass::GenerateOutputInformation();
 
-  ImageType * outputPtr = this->GetOutput();
-  RegionType nullRegion;
-  nullRegion.GetModifiableSize().Fill(1);
-  outputPtr->SetNumberOfComponentsPerPixel(1);
-  outputPtr->SetLargestPossibleRegion( nullRegion );
-
-  //////////////////////////////////////////////////////////////////////////////////////////
-  //                               Check the number of samples
-  //////////////////////////////////////////////////////////////////////////////////////////
-
-  m_NumberOfSamples = 0;
-  for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
-    {
-    // Input image pointer
-    ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
-
-    // Make sure input is available
-    if ( inputPtr.IsNull() )
-      {
-      itkExceptionMacro(<< "Input " << i << " is null!");
-      }
-
-    // Update input information
-    inputPtr->UpdateOutputInformation();
-
-    // Patch size of tensor #i
-    const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
-
-    // Input image requested region
-    const RegionType reqRegion = inputPtr->GetLargestPossibleRegion();
-
-    // Check size X
-    if (inputPatchSize[0] != reqRegion.GetSize(0))
-      itkExceptionMacro("Patch size for input " << i << " is " << inputPatchSize <<
-                        " but input patches image size is " << reqRegion.GetSize());
-
-    // Check size Y
-    if (reqRegion.GetSize(1) % inputPatchSize[1] != 0)
-      itkExceptionMacro("Input patches image must have a number of rows which is "
-          "a multiple of the patch size Y! Patches image has " << reqRegion.GetSize(1) <<
-          " rows but patch size Y is " <<  inputPatchSize[1] << " for input " << i);
-
-    // Get the batch size
-    const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
-
-    // Check the consistency with other inputs
-    if (m_NumberOfSamples == 0)
-      {
-      m_NumberOfSamples = currNumberOfSamples;
-      }
-    else if (m_NumberOfSamples != currNumberOfSamples)
-      {
-      itkGenericExceptionMacro("Previous batch size is " << m_NumberOfSamples << " but input " << i
-                               << " has a batch size of " << currNumberOfSamples );
-      }
-    } // next input
-
   //////////////////////////////////////////////////////////////////////////////////////////
   //                               Check the references
   //////////////////////////////////////////////////////////////////////////////////////////
@@ -98,7 +39,7 @@ TensorflowMultisourceModelValidate<TInputImage>
     {
     itkExceptionMacro("No reference is set");
     }
-  SizeListType outputEFSizes = this->GetOutputFOESizes();
+  SizeListType outputEFSizes = this->GetOutputExpressionFields();
   if (nbOfRefs != outputEFSizes.size())
     {
     itkExceptionMacro("There is " << nbOfRefs << " but only " <<
@@ -115,32 +56,16 @@ TensorflowMultisourceModelValidate<TInputImage>
       itkExceptionMacro("Reference image " << i << " width is " << refRegion.GetSize(0) <<
                         " but field of expression width is " << outputFOESize[0]);
       }
-    if (refRegion.GetSize(1) / outputFOESize[1] != m_NumberOfSamples)
+    if (refRegion.GetSize(1) / outputFOESize[1] != this->GetNumberOfSamples())
       {
       itkExceptionMacro("Reference image " << i << " height is " << refRegion.GetSize(1) <<
                         " but field of expression width is " << outputFOESize[1] <<
-                        " which is not consistent with the number of samples (" << m_NumberOfSamples << ")");
+                        " which is not consistent with the number of samples (" << this->GetNumberOfSamples() << ")");
       }
     }
 
  }
 
-template <class TInputImage>
-void
-TensorflowMultisourceModelValidate<TInputImage>
-::GenerateInputRequestedRegion()
- {
-  Superclass::GenerateInputRequestedRegion();
-
-  // For each image, set no image region
-  for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
-    {
-    RegionType nullRegion;
-    ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );
-    inputImage->SetRequestedRegion(nullRegion);
-    } // next image
-
- }
 
 /*
  * Set the references images
@@ -195,127 +120,8 @@ TensorflowMultisourceModelValidate<TInputImage>
     confMatMaps.push_back(mat);
     }
 
-  // Batches loop
-  const tensorflow::uint64 nBatches = vcl_ceil(m_NumberOfSamples / m_BatchSize);
-  const tensorflow::uint64 rest = m_NumberOfSamples % m_BatchSize;
-  itk::ProgressReporter progress(this, 0, nBatches);
-  for (tensorflow::uint64 batch = 0 ; batch < nBatches ; batch++)
-    {
-    // Update progress
-    this->UpdateProgress((float) batch / (float) nBatches);
-
-    // Sample start of this batch
-    const tensorflow::uint64 sampleStart = batch * m_BatchSize;
-    tensorflow::uint64 batchSize = m_BatchSize;
-    if (rest != 0)
-    {
-      batchSize = rest;
-    }
-
-    // Create input tensors list
-    DictListType inputs;
-
-    // Populate input tensors
-    for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
-      {
-      // Input image pointer
-      ImagePointerType inputPtr = const_cast<ImageType*>(this->GetInput(i));
-
-      // Patch size of tensor #i
-      const SizeType inputPatchSize = this->GetInputFOVSizes().at(i);
-
-      // Create the tensor for the batch
-      const tensorflow::int64 sz_n = batchSize;
-      const tensorflow::int64 sz_y = inputPatchSize[1];
-      const tensorflow::int64 sz_x = inputPatchSize[0];
-      const tensorflow::int64 sz_c = inputPtr->GetNumberOfComponentsPerPixel();
-      const tensorflow::TensorShape inputTensorShape({sz_n, sz_y, sz_x, sz_c});
-      tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);
-
-      // Populate the tensor
-      for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++)
-        {
-        const tensorflow::uint64 samplePos = sampleStart + elem;
-        IndexType start;
-        start[0] = 0;
-        start[1] = samplePos * sz_y;
-        RegionType patchRegion(start, inputPatchSize);
-        tf::PropagateRequestedRegion<TInputImage>(inputPtr, patchRegion);
-        tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, patchRegion, inputTensor, elem );
-        }
-
-      // Input #i : the tensor of patches (aka the batch)
-      DictType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
-      inputs.push_back(input1);
-      } // next input tensor
-
-    // Run the TF session here
-    TensorListType outputs;
-    this->RunSession(inputs, outputs);
-
-    // Perform the validation
-    if (outputs.size() != m_References.size())
-      {
-      itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " <<
-                      "but only " << m_References.size() << " reference(s) set");
-      }
-    SizeListType outputEFSizes = this->GetOutputFOESizes();
-    for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++)
-      {
-      // Recopy the chunk
-      const SizeType outputFOESize = outputEFSizes[refIdx];
-      IndexType cpyStart;
-      cpyStart.Fill(0);
-      IndexType refRegStart;
-      refRegStart.Fill(0);
-      refRegStart[1] = outputFOESize[1] * sampleStart;
-      SizeType cpySize;
-      cpySize[0] = outputFOESize[0];
-      cpySize[1] = outputFOESize[1] * batchSize;
-      RegionType cpyRegion(cpyStart, cpySize);
-      RegionType refRegion(refRegStart, cpySize);
-
-      // Allocate a temporary image
-      ImagePointerType img = ImageType::New();
-      img->SetRegions(cpyRegion);
-      img->SetNumberOfComponentsPerPixel(1);
-      img->Allocate();
-
-      int co = 0;
-      tf::CopyTensorToImageRegion<TInputImage>(outputs[refIdx], cpyRegion, img, cpyRegion, co);
-
-      // Retrieve the reference image region
-      tf::PropagateRequestedRegion<TInputImage>(m_References[refIdx], refRegion);
-
-      // Update the confusion matrices
-      IteratorType inIt(img, cpyRegion);
-      IteratorType refIt(m_References[refIdx], refRegion);
-      for (inIt.GoToBegin(), refIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++refIt)
-        {
-        const int classIn = static_cast<LabelValueType>(inIt.Get()[0]);
-        const int classRef = static_cast<LabelValueType>(refIt.Get()[0]);
-
-        if (confMatMaps[refIdx].count(classRef) == 0)
-          {
-          MapType newMap;
-          newMap[classIn] = 1;
-          confMatMaps[refIdx][classRef] = newMap;
-          }
-        else
-          {
-          if (confMatMaps[refIdx][classRef].count(classIn) == 0)
-            {
-            confMatMaps[refIdx][classRef][classIn] = 1;
-            }
-          else
-            {
-            confMatMaps[refIdx][classRef][classIn]++;
-            }
-          }
-        }
-      }
-    progress.CompletedPixel();
-    } // Next batch
+  // Run all the batches
+  Superclass::GenerateData();
 
   // Compute confusion matrices
   for (unsigned int i = 0 ; i < confMatMaps.size() ; i++)
@@ -353,6 +159,88 @@ TensorflowMultisourceModelValidate<TInputImage>
     m_ConfusionMatrices.push_back(matrix);
     m_MapsOfClasses.push_back(values);
 
+
+
+
+
+    }
+
+ }
+
+
+template <class TInputImage>
+void
+TensorflowMultisourceModelValidate<TInputImage>
+::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
+    const IndexValueType & batchSize)
+ {
+  // Populate input tensors
+  this->PopulateInputTensor(inputs, sampleStart, batchSize);
+
+  // Run the TF session here
+  TensorListType outputs;
+  this->RunSession(inputs, outputs);
+
+  // Perform the validation
+  if (outputs.size() != m_References.size())
+    {
+    itkWarningMacro("There is " << outputs.size() << " outputs returned after session run, " <<
+                    "but only " << m_References.size() << " reference(s) set");
+    }
+  SizeListType outputEFSizes = this->GetOutputExpressionFields();
+  for (unsigned int refIdx = 0 ; refIdx < outputs.size() ; refIdx++)
+    {
+    // Recopy the chunk
+    const SizeType outputFOESize = outputEFSizes[refIdx];
+    IndexType cpyStart;
+    cpyStart.Fill(0);
+    IndexType refRegStart;
+    refRegStart.Fill(0);
+    refRegStart[1] = outputFOESize[1] * sampleStart;
+    SizeType cpySize;
+    cpySize[0] = outputFOESize[0];
+    cpySize[1] = outputFOESize[1] * batchSize;
+    RegionType cpyRegion(cpyStart, cpySize);
+    RegionType refRegion(refRegStart, cpySize);
+
+    // Allocate a temporary image
+    ImagePointerType img = ImageType::New();
+    img->SetRegions(cpyRegion);
+    img->SetNumberOfComponentsPerPixel(1);
+    img->Allocate();
+
+    int co = 0;
+    tf::CopyTensorToImageRegion<TInputImage>(outputs[refIdx], cpyRegion, img, cpyRegion, co);
+
+    // Retrieve the reference image region
+    tf::PropagateRequestedRegion<TInputImage>(m_References[refIdx], refRegion);
+
+    // Update the confusion matrices
+    IteratorType inIt(img, cpyRegion);
+    IteratorType refIt(m_References[refIdx], refRegion);
+    for (inIt.GoToBegin(), refIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++refIt)
+      {
+      const int classIn = static_cast<LabelValueType>(inIt.Get()[0]);
+      const int classRef = static_cast<LabelValueType>(refIt.Get()[0]);
+
+      if (confMatMaps[refIdx].count(classRef) == 0)
+        {
+        MapType newMap;
+        newMap[classIn] = 1;
+        confMatMaps[refIdx][classRef] = newMap;
+        }
+      else
+        {
+        if (confMatMaps[refIdx][classRef].count(classIn) == 0)
+          {
+          confMatMaps[refIdx][classRef][classIn] = 1;
+          }
+        else
+          {
+          confMatMaps[refIdx][classRef][classIn]++;
+          }
+        }
+      }
     }
 
  }