Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Lozac'h Loic
otbtf
Commits
e42a2d60
Commit
e42a2d60
authored
Aug 30, 2018
by
remi cresson
Browse files
REFAC: wip
parent
20bb0192
Changes
10
Hide whitespace changes
Inline
Side-by-side
include/otbTensorflowMultisourceModelBase.h
View file @
e42a2d60
...
...
@@ -72,10 +72,10 @@ public:
typedef
typename
TInputImage
::
RegionType
RegionType
;
/** Typedefs for parameters */
typedef
std
::
pair
<
std
::
string
,
tensorflow
::
Tensor
>
DictType
;
typedef
std
::
pair
<
std
::
string
,
tensorflow
::
Tensor
>
Dict
Element
Type
;
typedef
std
::
vector
<
std
::
string
>
StringList
;
typedef
std
::
vector
<
SizeType
>
SizeListType
;
typedef
std
::
vector
<
DictType
>
DictLis
tType
;
typedef
std
::
vector
<
Dict
Element
Type
>
Dic
tType
;
typedef
std
::
vector
<
tensorflow
::
DataType
>
DataTypeListType
;
typedef
std
::
vector
<
tensorflow
::
TensorShapeProto
>
TensorShapeProtoList
;
typedef
std
::
vector
<
tensorflow
::
Tensor
>
TensorListType
;
...
...
@@ -87,27 +87,28 @@ public:
tensorflow
::
Session
*
GetSession
()
{
return
m_Session
;
}
/** Model parameters */
void
PushBackInputBundle
(
std
::
string
placeholder
,
SizeType
receptiveField
,
ImagePointerType
image
);
void
PushBackInputTensorBundle
(
std
::
string
name
,
SizeType
receptiveField
,
ImagePointerType
image
);
void
PushBackOuputTensorBundle
(
std
::
string
name
,
SizeType
expressionField
);
//
/** Input placeholders names */
//
itkSetMacro(InputPlaceholders
Names
, StringList);
itkGetMacro
(
InputPlaceholders
Names
,
StringList
);
//
//
/** Receptive field */
//
itkSetMacro(Input
FOVSize
s, SizeListType);
itkGetMacro
(
Input
FOVSize
s
,
SizeListType
);
/** Input placeholders names */
itkSetMacro
(
InputPlaceholders
,
StringList
);
itkGetMacro
(
InputPlaceholders
,
StringList
);
/** Receptive field */
itkSetMacro
(
Input
ReceptiveField
s
,
SizeListType
);
itkGetMacro
(
Input
ReceptiveField
s
,
SizeListType
);
/** Output tensors names */
itkSetMacro
(
OutputTensors
Names
,
StringList
);
itkGetMacro
(
OutputTensors
Names
,
StringList
);
itkSetMacro
(
OutputTensors
,
StringList
);
itkGetMacro
(
OutputTensors
,
StringList
);
/** Expression field */
itkSetMacro
(
Output
FOESize
s
,
SizeListType
);
itkGetMacro
(
Output
FOESize
s
,
SizeListType
);
itkSetMacro
(
Output
ExpressionField
s
,
SizeListType
);
itkGetMacro
(
Output
ExpressionField
s
,
SizeListType
);
/** User placeholders */
void
SetUserPlaceholders
(
Dict
List
Type
dict
)
{
m_UserPlaceholders
=
dict
;
}
Dict
List
Type
GetUserPlaceholders
()
{
return
m_UserPlaceholders
;
}
void
SetUserPlaceholders
(
DictType
dict
)
{
m_UserPlaceholders
=
dict
;
}
DictType
GetUserPlaceholders
()
{
return
m_UserPlaceholders
;
}
/** Target nodes names */
itkSetMacro
(
TargetNodesNames
,
StringList
);
...
...
@@ -125,9 +126,9 @@ protected:
TensorflowMultisourceModelBase
();
virtual
~
TensorflowMultisourceModelBase
()
{};
virtual
std
::
stringstream
GenerateDebugReport
(
Dict
List
Type
&
inputs
,
TensorListType
&
outputs
);
virtual
std
::
stringstream
GenerateDebugReport
(
DictType
&
inputs
);
virtual
void
RunSession
(
Dict
List
Type
&
inputs
,
TensorListType
&
outputs
);
virtual
void
RunSession
(
DictType
&
inputs
,
TensorListType
&
outputs
);
private:
TensorflowMultisourceModelBase
(
const
Self
&
);
//purposely not implemented
...
...
@@ -138,11 +139,11 @@ private:
tensorflow
::
Session
*
m_Session
;
// The tensorflow session
// Model parameters
StringList
m_InputPlaceholders
Names
;
// Input placeholders names
SizeListType
m_Input
FOVSizes
;
// Input tensors field of view (FOV) size
s
S
izeListType
m_Output
FOESize
s
;
// Output tensors
field of expression (FOE) siz
es
Dict
ListType
m_
UserPlaceholders
;
// User placeholder
s
StringList
m_OutputTensorsName
s
;
// User
tenso
rs
StringList
m_InputPlaceholders
;
// Input placeholders names
SizeListType
m_Input
ReceptiveFields
;
// Input receptive field
s
S
tringList
m_Output
Tensor
s
;
// Output tensors
nam
es
Size
ListType
m_
OutputExpressionFields
;
// Output expression field
s
DictType
m_UserPlaceholder
s
;
// User
placeholde
rs
StringList
m_TargetNodesNames
;
// User target tensors
// Read-only
...
...
include/otbTensorflowMultisourceModelBase.hxx
View file @
e42a2d60
...
...
@@ -20,22 +20,23 @@ template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase
<
TInputImage
,
TOutputImage
>
::
TensorflowMultisourceModelBase
()
{
m_Session
=
nullptr
;
}
template
<
class
TInputImage
,
class
TOutputImage
>
void
TensorflowMultisourceModelBase
<
TInputImage
,
TOutputImage
>
::
PushBackInputBundle
(
std
::
string
placeholder
,
SizeType
receptiveField
,
ImagePointerType
image
)
::
PushBackInput
Tensor
Bundle
(
std
::
string
placeholder
,
SizeType
receptiveField
,
ImagePointerType
image
)
{
Superclass
::
PushBackInput
(
image
);
m_Input
FOVSize
s
.
push_back
(
receptiveField
);
m_InputPlaceholders
Names
.
push_back
(
placeholder
);
m_Input
ReceptiveField
s
.
push_back
(
receptiveField
);
m_InputPlaceholders
.
push_back
(
placeholder
);
}
template
<
class
TInputImage
,
class
TOutputImage
>
std
::
stringstream
TensorflowMultisourceModelBase
<
TInputImage
,
TOutputImage
>
::
GenerateDebugReport
(
Dict
List
Type
&
inputs
,
TensorListType
&
outputs
)
::
GenerateDebugReport
(
DictType
&
inputs
)
{
// Create a debug report
std
::
stringstream
debugReport
;
...
...
@@ -69,7 +70,7 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
template
<
class
TInputImage
,
class
TOutputImage
>
void
TensorflowMultisourceModelBase
<
TInputImage
,
TOutputImage
>
::
RunSession
(
Dict
List
Type
&
inputs
,
TensorListType
&
outputs
)
::
RunSession
(
DictType
&
inputs
,
TensorListType
&
outputs
)
{
// Add the user's placeholders
...
...
@@ -82,11 +83,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// The session will initialize the outputs
// Run the session, evaluating our output tensors from the graph
auto
status
=
this
->
GetSession
()
->
Run
(
inputs
,
m_OutputTensors
Names
,
m_TargetNodesNames
,
&
outputs
);
auto
status
=
this
->
GetSession
()
->
Run
(
inputs
,
m_OutputTensors
,
m_TargetNodesNames
,
&
outputs
);
if
(
!
status
.
ok
())
{
// Create a debug report
std
::
stringstream
debugReport
=
GenerateDebugReport
(
inputs
,
outputs
);
std
::
stringstream
debugReport
=
GenerateDebugReport
(
inputs
);
// Throw an exception with the report
itkExceptionMacro
(
"Can't run the tensorflow session !
\n
"
<<
...
...
@@ -108,11 +109,11 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
// - patches sizes
// - input image
const
unsigned
int
nbInputs
=
this
->
GetNumberOfInputs
();
if
(
nbInputs
!=
m_Input
FOVSize
s
.
size
()
||
nbInputs
!=
m_InputPlaceholders
Names
.
size
())
if
(
nbInputs
!=
m_Input
ReceptiveField
s
.
size
()
||
nbInputs
!=
m_InputPlaceholders
.
size
())
{
itkExceptionMacro
(
"Number of input images is "
<<
nbInputs
<<
" but the number of input patches size is "
<<
m_Input
FOVSize
s
.
size
()
<<
" and the number of input tensors names is "
<<
m_InputPlaceholders
Names
.
size
());
" but the number of input patches size is "
<<
m_Input
ReceptiveField
s
.
size
()
<<
" and the number of input tensors names is "
<<
m_InputPlaceholders
.
size
());
}
//////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -120,8 +121,8 @@ TensorflowMultisourceModelBase<TInputImage, TOutputImage>
//////////////////////////////////////////////////////////////////////////////////////////
// Get input and output tensors datatypes and shapes
tf
::
GetTensorAttributes
(
m_Graph
,
m_InputPlaceholders
Names
,
m_InputTensorsShapes
,
m_InputTensorsDataTypes
);
tf
::
GetTensorAttributes
(
m_Graph
,
m_OutputTensors
Names
,
m_OutputTensorsShapes
,
m_OutputTensorsDataTypes
);
tf
::
GetTensorAttributes
(
m_Graph
,
m_InputPlaceholders
,
m_InputTensorsShapes
,
m_InputTensorsDataTypes
);
tf
::
GetTensorAttributes
(
m_Graph
,
m_OutputTensors
,
m_OutputTensorsShapes
,
m_OutputTensorsDataTypes
);
}
...
...
include/otbTensorflowMultisourceModelFilter.h
View file @
e42a2d60
...
...
@@ -30,27 +30,26 @@ namespace otb
*
* The filter takes N input images and feed the TensorFlow model to produce
* one output image of desired TF op results.
* Names of input
/output
placeholders
/
tensors must be specified using the
* Set
Input
Placeholders
Names/SetOutputTensorName
s.
* Names of input placeholders
and output
tensors must be specified using the
* SetPlaceholders
() and SetTensors() method
s.
*
* Example: we have a
t
ensor
f
low model which runs the input images "x1" and "x2"
* Example: we have a
T
ensor
F
low model which runs the input images "x1" and "x2"
* and produces the output image "y".
* "x1" and "x2" are two TF placeholders, we set Input
TensorNames
={"x1","x2"}
* "y1" corresponds to one TF op output, we set OutputTensor
Name
s={"y1"}
* "x1" and "x2" are two TF placeholders, we set Input
Placeholder
={"x1","x2"}
* "y1" corresponds to one TF op output, we set OutputTensors={"y1"}
*
* The reference grid for the output image is the same as the first input image.
* This grid can be scaled by setting the OutputSpacingScale value.
* This can be used to run models which downsize the output image spacing
* (
typically
fully convolutional model with strides) or to produce the result
* (
e.g.
fully convolutional model with strides) or to produce the result
* of a patch-based network at regular intervals.
*
* For each input
, input field of view (FOV
) must be set.
* For each input
(resp. output), receptive field (resp. expression field
) must be set.
* If the number of values in the output tensors (produced by the model) don't
* fit with the output image region, exception will be thrown.
* fit with the output image region,
an
exception will be thrown.
*
*
* The tensorflow Graph is passed using the SetGraph() method
* The tensorflow Session is passed using the SetSession() method
* The TensorFlow Graph is passed using the SetGraph() method
* The TensorFlow Session is passed using the SetSession() method
*
* \ingroup OTBTensorflow
*/
...
...
@@ -94,6 +93,7 @@ public:
typedef
typename
itk
::
ImageRegionConstIterator
<
TInputImage
>
InputConstIteratorType
;
/* Typedefs for parameters */
typedef
typename
Superclass
::
DictElementType
DictElementType
;
typedef
typename
Superclass
::
DictType
DictType
;
typedef
typename
Superclass
::
StringList
StringList
;
typedef
typename
Superclass
::
SizeListType
SizeListType
;
...
...
@@ -101,8 +101,6 @@ public:
typedef
typename
Superclass
::
TensorListType
TensorListType
;
typedef
std
::
vector
<
float
>
ScaleListType
;
itkSetMacro
(
OutputFOESize
,
SizeType
);
itkGetMacro
(
OutputFOESize
,
SizeType
);
itkSetMacro
(
OutputGridSize
,
SizeType
);
itkGetMacro
(
OutputGridSize
,
SizeType
);
itkSetMacro
(
ForceOutputGridSize
,
bool
);
...
...
@@ -132,7 +130,6 @@ private:
TensorflowMultisourceModelFilter
(
const
Self
&
);
//purposely not implemented
void
operator
=
(
const
Self
&
);
//purposely not implemented
SizeType
m_OutputFOESize
;
// Output tensors field of expression (FOE) sizes
SizeType
m_OutputGridSize
;
// Output grid size
bool
m_ForceOutputGridSize
;
// Force output grid size
bool
m_FullyConvolutional
;
// Convolution mode
...
...
include/otbTensorflowMultisourceModelFilter.hxx
View file @
e42a2d60
...
...
@@ -216,7 +216,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Update output image extent
PointType
currentInputImageExtentInf
,
currentInputImageExtentSup
;
ImageToExtent
(
currentImage
,
currentInputImageExtentInf
,
currentInputImageExtentSup
,
this
->
GetInput
FOVSize
s
()[
imageIndex
]);
ImageToExtent
(
currentImage
,
currentInputImageExtentInf
,
currentInputImageExtentSup
,
this
->
GetInput
ReceptiveField
s
()[
imageIndex
]);
for
(
unsigned
int
dim
=
0
;
dim
<
ImageType
::
ImageDimension
;
++
dim
)
{
extentInf
[
dim
]
=
vnl_math_max
(
currentInputImageExtentInf
[
dim
],
extentInf
[
dim
]);
...
...
@@ -236,7 +236,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
if
(
!
m_ForceOutputGridSize
)
{
// Default is the output field of expression
m_OutputGridSize
=
m_OutputFOESize
;
m_OutputGridSize
=
this
->
GetOutputExpressionFields
().
at
(
0
)
;
}
// Resize the largestPossibleRegion to be a multiple of the grid size
...
...
@@ -315,9 +315,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Compute the FOV-scale*FOE radius to pad
SizeType
toPad
(
this
->
GetInput
FOVSize
s
().
at
(
i
));
toPad
[
0
]
-=
1
+
(
m_OutputFOESize
[
0
]
-
1
)
*
m_OutputSpacingScale
;
toPad
[
1
]
-=
1
+
(
m_OutputFOESize
[
1
]
-
1
)
*
m_OutputSpacingScale
;
SizeType
toPad
(
this
->
GetInput
ReceptiveField
s
().
at
(
i
));
toPad
[
0
]
-=
1
+
(
this
->
GetOutputExpressionFields
().
at
(
0
)
[
0
]
-
1
)
*
m_OutputSpacingScale
;
toPad
[
1
]
-=
1
+
(
this
->
GetOutputExpressionFields
().
at
(
0
)
[
1
]
-
1
)
*
m_OutputSpacingScale
;
// Pad with radius
SmartPad
(
inRegion
,
toPad
);
...
...
@@ -365,7 +365,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const
unsigned
int
nInputs
=
this
->
GetNumberOfInputs
();
// Create input tensors list
Dict
List
Type
inputs
;
DictType
inputs
;
// Populate input tensors
for
(
unsigned
int
i
=
0
;
i
<
nInputs
;
i
++
)
...
...
@@ -374,7 +374,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
const
ImagePointerType
inputPtr
=
const_cast
<
TInputImage
*>
(
this
->
GetInput
(
i
));
// Patch size of tensor #i
const
SizeType
inputPatchSize
=
this
->
GetInput
FOVSize
s
().
at
(
i
);
const
SizeType
inputPatchSize
=
this
->
GetInput
ReceptiveField
s
().
at
(
i
);
// Input image requested region
const
RegionType
reqRegion
=
inputPtr
->
GetRequestedRegion
();
...
...
@@ -395,7 +395,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
tf
::
RecopyImageRegionToTensorWithCast
<
TInputImage
>
(
inputPtr
,
reqRegion
,
inputTensor
,
0
);
// Input #1 : the tensor of patches (aka the batch)
DictType
input1
=
{
this
->
GetInputPlaceholdersNames
()[
i
],
inputTensor
};
Dict
Element
Type
input1
=
{
this
->
GetInputPlaceholdersNames
()[
i
],
inputTensor
};
inputs
.
push_back
(
input1
);
}
else
...
...
@@ -429,7 +429,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Input #1 : the tensor of patches (aka the batch)
DictType
input1
=
{
this
->
GetInputPlaceholdersNames
()[
i
],
inputTensor
};
Dict
Element
Type
input1
=
{
this
->
GetInputPlaceholdersNames
()[
i
],
inputTensor
};
inputs
.
push_back
(
input1
);
}
// mode is not full convolutional
...
...
include/otbTensorflowMultisourceModelLearningBase.h
0 → 100644
View file @
e42a2d60
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_h
#define otbTensorflowMultisourceModelLearningBase_h
#include "itkProcessObject.h"
#include "itkNumericTraits.h"
#include "itkSimpleDataObjectDecorator.h"
// Base
#include "otbTensorflowMultisourceModelBase.h"
// Shuffle
#include <random>
#include <algorithm>
#include <iterator>
namespace
otb
{
/**
* \class TensorflowMultisourceModelLearningBase
* \brief This filter is the base class for learning filters.
*
* \ingroup OTBTensorflow
*/
template
<
class
TInputImage
>
class
ITK_EXPORT
TensorflowMultisourceModelLearningBase
:
public
TensorflowMultisourceModelBase
<
TInputImage
>
{
public:
/** Standard class typedefs. */
typedef
TensorflowMultisourceModelLearningBase
Self
;
typedef
TensorflowMultisourceModelBase
<
TInputImage
>
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Run-time type information (and related methods). */
itkTypeMacro
(
TensorflowMultisourceModelLearningBase
,
TensorflowMultisourceModelBase
);
/** Images typedefs */
typedef
typename
Superclass
::
ImageType
ImageType
;
typedef
typename
Superclass
::
ImagePointerType
ImagePointerType
;
typedef
typename
Superclass
::
RegionType
RegionType
;
typedef
typename
Superclass
::
SizeType
SizeType
;
typedef
typename
Superclass
::
IndexType
IndexType
;
/* Typedefs for parameters */
typedef
typename
Superclass
::
DictType
DictType
;
typedef
typename
Superclass
::
DictElementType
DictElementType
;
typedef
typename
Superclass
::
StringList
StringList
;
typedef
typename
Superclass
::
SizeListType
SizeListType
;
typedef
typename
Superclass
::
TensorListType
TensorListType
;
/* Typedefs for index */
typedef
typename
ImageType
::
IndexValueType
IndexValueType
;
typedef
std
::
vector
<
IndexValueType
>
IndexListType
;
// Batch size
itkSetMacro
(
BatchSize
,
IndexValueType
);
itkGetMacro
(
BatchSize
,
IndexValueType
);
// Use streaming
itkSetMacro
(
UseStreaming
,
bool
);
itkGetMacro
(
UseStreaming
,
bool
);
// Get number of samples
itkGetMacro
(
NumberOfSamples
,
IndexValueType
);
protected:
TensorflowMultisourceModelLearningBase
();
virtual
~
TensorflowMultisourceModelLearningBase
()
{};
virtual
void
GenerateOutputInformation
(
void
);
virtual
void
GenerateInputRequestedRegion
();
virtual
void
GenerateData
();
virtual
void
PopulateInputTensors
(
TensorListType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
,
const
IndexListType
&
order
=
IndexListType
());
virtual
void
ProcessBatch
(
TensorListType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
)
=
0
;
private:
TensorflowMultisourceModelLearningBase
(
const
Self
&
);
//purposely not implemented
void
operator
=
(
const
Self
&
);
//purposely not implemented
unsigned
int
m_BatchSize
;
// Batch size
bool
m_UseStreaming
;
// Use streaming on/off
// Read only
IndexValueType
m_NumberOfSamples
;
// Number of samples
};
// end class
}
// end namespace otb
#include "otbTensorflowMultisourceModelLearningBase.hxx"
#endif
include/otbTensorflowMultisourceModelLearningBase.hxx
0 → 100644
View file @
e42a2d60
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelLearningBase_txx
#define otbTensorflowMultisourceModelLearningBase_txx
#include "otbTensorflowMultisourceModelLearningBase.h"
namespace
otb
{
template
<
class
TInputImage
>
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
TensorflowMultisourceModelLearningBase
()
:
m_BatchSize
(
100
),
m_NumberOfSamples
(
0
),
m_UseStreaming
(
false
)
{
}
template
<
class
TInputImage
>
void
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
GenerateOutputInformation
()
{
Superclass
::
GenerateOutputInformation
();
ImageType
*
outputPtr
=
this
->
GetOutput
();
RegionType
nullRegion
;
nullRegion
.
GetModifiableSize
().
Fill
(
1
);
outputPtr
->
SetNumberOfComponentsPerPixel
(
1
);
outputPtr
->
SetLargestPossibleRegion
(
nullRegion
);
// Count the number of samples
m_NumberOfSamples
=
0
;
for
(
unsigned
int
i
=
0
;
i
<
this
->
GetNumberOfInputs
()
;
i
++
)
{
// Input image pointer
ImagePointerType
inputPtr
=
const_cast
<
ImageType
*>
(
this
->
GetInput
(
i
));
// Make sure input is available
if
(
inputPtr
.
IsNull
()
)
{
itkExceptionMacro
(
<<
"Input "
<<
i
<<
" is null!"
);
}
// Update input information
inputPtr
->
UpdateOutputInformation
();
// Patch size of tensor #i
const
SizeType
inputPatchSize
=
this
->
GetInputReceptiveFields
().
at
(
i
);
// Input image requested region
const
RegionType
reqRegion
=
inputPtr
->
GetLargestPossibleRegion
();
// Check size X
if
(
inputPatchSize
[
0
]
!=
reqRegion
.
GetSize
(
0
))
itkExceptionMacro
(
"Patch size for input "
<<
i
<<
" is "
<<
inputPatchSize
<<
" but input patches image size is "
<<
reqRegion
.
GetSize
());
// Check size Y
if
(
reqRegion
.
GetSize
(
1
)
%
inputPatchSize
[
1
]
!=
0
)
itkExceptionMacro
(
"Input patches image must have a number of rows which is "
<<
"a multiple of the patch size Y! Patches image has "
<<
reqRegion
.
GetSize
(
1
)
<<
" rows but patch size Y is "
<<
inputPatchSize
[
1
]
<<
" for input "
<<
i
);
// Get the batch size
const
tensorflow
::
uint64
currNumberOfSamples
=
reqRegion
.
GetSize
(
1
)
/
inputPatchSize
[
1
];
// Check the consistency with other inputs
if
(
m_NumberOfSamples
==
0
)
{
m_NumberOfSamples
=
currNumberOfSamples
;
}
else
if
(
m_NumberOfSamples
!=
currNumberOfSamples
)
{
itkGenericExceptionMacro
(
"Previous batch size is "
<<
m_NumberOfSamples
<<
" but input "
<<
i
<<
" has a batch size of "
<<
currNumberOfSamples
);
}
}
// next input
}
template
<
class
TInputImage
>
void
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
GenerateInputRequestedRegion
()
{
Superclass
::
GenerateInputRequestedRegion
();
// For each image, set no image region
for
(
unsigned
int
i
=
0
;
i
<
this
->
GetNumberOfInputs
();
++
i
)
{
RegionType
nullRegion
;
ImageType
*
inputImage
=
static_cast
<
ImageType
*
>
(
Superclass
::
ProcessObject
::
GetInput
(
i
)
);
// If the streaming is enabled, we don't read the full image
if
(
m_UseStreaming
)
{
inputImage
->
SetRequestedRegion
(
nullRegion
);
}
else
{
inputImage
->
SetRequestedRegion
(
inputImage
->
GetLargestPossibleRegion
());
}
}
// next image
}
/**
*
*/
template
<
class
TInputImage
>
void
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
GenerateData
()
{
// Batches loop
const
IndexValueType
nBatches
=
vcl_ceil
(
m_NumberOfSamples
/
m_BatchSize
);
const
IndexValueType
rest
=
m_NumberOfSamples
%
m_BatchSize
;
itk
::
ProgressReporter
progress
(
this
,
0
,
nBatches
);
for
(
IndexValueType
batch
=
0
;
batch
<
nBatches
;
batch
++
)
{
// Create input tensors list
TensorListType
inputs
;
// Batch start and size
const
IndexValueType
sampleStart
=
batch
*
m_BatchSize
;
IndexValueType
batchSize
=
m_BatchSize
;
if
(
rest
!=
0
)
{
batchSize
=
rest
;
}
// Process the batch
ProcessBatch
(
inputs
,
sampleStart
,
batchSize
);
progress
.
CompletedPixel
();
}
// Next batch
}
template
<
class
TInputImage
>
void
TensorflowMultisourceModelLearningBase
<
TInputImage
>
::
PopulateInputTensors
(
TensorListType
&
inputs
,
const
IndexValueType
&
sampleStart
,
const
IndexValueType
&
batchSize
,
const
IndexListType
&
order
)
{
const
bool
reorder
=
order
.
size
();
// Populate input tensors
for
(
unsigned
int
i
=
0
;
i
<
this
->
GetNumberOfInputs
()
;
i
++
)
{
// Input image pointer
ImagePointerType
inputPtr
=
const_cast
<
ImageType
*>
(
this
->
GetInput
(
i
));
// Patch size of tensor #i
const
SizeType
inputPatchSize
=
this
->
GetInputReceptiveFields
().
at
(
i
);
// Create the tensor for the batch
const
tensorflow
::
int64
sz_n
=
batchSize
;
const
tensorflow
::
int64
sz_y
=
inputPatchSize
[
1
];
const
tensorflow
::
int64
sz_x
=
inputPatchSize
[
0
];
const
tensorflow
::
int64
sz_c
=
inputPtr
->
GetNumberOfComponentsPerPixel
();
const
tensorflow
::
TensorShape
inputTensorShape
({
sz_n
,
sz_y
,
sz_x
,
sz_c
});
tensorflow
::
Tensor
inputTensor
(
this
->
GetInputTensorsDataTypes
()[
i
],
inputTensorShape
);
// Populate the tensor
for
(
tensorflow
::
uint64
elem
=
0
;
elem
<
batchSize
;
elem
++
)
{
const
tensorflow
::
uint64
samplePos
=
sampleStart
+
elem
;
IndexType
start
;
start
[
0
]
=
0
;
if
(
reorder
)
{
start
[
1
]
=
order
[
samplePos
]
*
sz_y
;
}
else
{
start
[
1
]
=
samplePos
*
sz_y
;;
}
RegionType
patchRegion
(
start
,
inputPatchSize
);
if
(
m_UseStreaming
)
{
// If streaming is enabled, we need to explicitly propagate requested region
tf
::
PropagateRequestedRegion
<
TInputImage
>
(
inputPtr
,
patchRegion
);
}
tf
::
RecopyImageRegionToTensorWithCast
<
TInputImage
>
(
inputPtr
,
patchRegion
,
inputTensor
,
elem
);