Commit ddf221b3 authored by Cresson Remi's avatar Cresson Remi
Browse files

REFAC: wip#2

parent e42a2d60
......@@ -248,16 +248,16 @@ public:
m_TFFilter = TFModelFilterType::New();
m_TFFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TFFilter->SetSession(m_SavedModel.session.get());
m_TFFilter->SetOutputTensorsNames(GetParameterStringList("output.names"));
m_TFFilter->SetOutputTensors(GetParameterStringList("output.names"));
m_TFFilter->SetOutputSpacingScale(GetParameterFloat("output.spcscale"));
otbAppLogINFO("Output spacing ratio: " << m_TFFilter->GetOutputSpacingScale());
// Get user placeholders
TFModelFilterType::DictListType dict;
TFModelFilterType::StringList expressions = GetParameterStringList("model.userplaceholders");
TFModelFilterType::DictType dict;
for (auto& exp: expressions)
{
TFModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
TFModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
......@@ -267,7 +267,7 @@ public:
// Input sources
for (auto& bundle: m_Bundles)
{
m_TFFilter->PushBackInputBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
m_TFFilter->PushBackInputTensorBundle(bundle.m_Placeholder, bundle.m_PatchSize, bundle.m_ImageSource.Get());
}
// Fully convolutional mode on/off
......@@ -281,9 +281,9 @@ public:
FloatVectorImageType::SizeType foe;
foe[0] = GetParameterInt("output.foex");
foe[1] = GetParameterInt("output.foey");
m_TFFilter->SetOutputFOESize(foe);
m_TFFilter->SetOutputExpressionFields({foe});
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputFOESize());
otbAppLogINFO("Output field of expression: " << m_TFFilter->GetOutputExpressionFields()[0]);
// Streaming
if (GetParameterInt("finetuning.disabletiling")!=1)
......
......@@ -53,7 +53,7 @@ public:
itkNewMacro(Self);
itkTypeMacro(TensorflowModelTrain, Application);
/** Typedefs for tensorflow */
/** Typedefs for TensorFlow */
typedef otb::TensorflowMultisourceModelTrain<FloatVectorImageType> TrainModelFilterType;
typedef otb::TensorflowMultisourceModelValidate<FloatVectorImageType> ValidateModelFilterType;
typedef otb::TensorflowSource<FloatVectorImageType> TFSource;
......@@ -75,8 +75,8 @@ public:
// Parameters keys
std::string m_KeyInForTrain; // Key of input image list (training)
std::string m_KeyInForValid; // Key of input image list (validation)
std::string m_KeyPHNameForTrain; // Key for placeholder name in the tensorflow model (training)
std::string m_KeyPHNameForValid; // Key for placeholder name in the tensorflow model (validation)
std::string m_KeyPHNameForTrain; // Key for placeholder name in the TensorFlow model (training)
std::string m_KeyPHNameForValid; // Key for placeholder name in the TensorFlow model (validation)
std::string m_KeyPszX; // Key for samples sizes X
std::string m_KeyPszY; // Key for samples sizes Y
};
......@@ -194,10 +194,10 @@ public:
AddParameter(ParameterType_StringList, "training.userplaceholders",
"Additional single-valued placeholders for training. Supported types: int, float, bool.");
MandatoryOff ("training.userplaceholders");
AddParameter(ParameterType_StringList, "training.targetnodesnames", "Names of the target nodes");
MandatoryOn ("training.targetnodesnames");
AddParameter(ParameterType_StringList, "training.outputtensorsnames", "Names of the output tensors to display");
MandatoryOff ("training.outputtensorsnames");
AddParameter(ParameterType_StringList, "training.targetnodes", "Names of the target nodes");
MandatoryOn ("training.targetnodes");
AddParameter(ParameterType_StringList, "training.outputtensors", "Names of the output tensors to display");
MandatoryOff ("training.outputtensors");
// Metrics
AddParameter(ParameterType_Group, "validation", "Validation parameters");
......@@ -228,7 +228,7 @@ public:
SetDocExampleParameterValue("source2.fovy", "1");
SetDocExampleParameterValue("model.dir", "/tmp/my_saved_model/");
SetDocExampleParameterValue("training.userplaceholders", "is_training=true dropout=0.2");
SetDocExampleParameterValue("training.targetnodenames", "optimizer");
SetDocExampleParameterValue("training.targetnodes", "optimizer");
SetDocExampleParameterValue("model.saveto", "/tmp/my_saved_model_vars1");
}
......@@ -350,13 +350,13 @@ public:
//
// Get user placeholders
//
TrainModelFilterType::DictListType GetUserPlaceholders(const std::string key)
TrainModelFilterType::DictType GetUserPlaceholders(const std::string key)
{
TrainModelFilterType::DictListType dict;
TrainModelFilterType::DictType dict;
TrainModelFilterType::StringList expressions = GetParameterStringList(key);
for (auto& exp: expressions)
{
TrainModelFilterType::DictType entry = tf::ExpressionToTensor(exp);
TrainModelFilterType::DictElementType entry = tf::ExpressionToTensor(exp);
dict.push_back(entry);
otbAppLogINFO("Using placeholder " << entry.first << " with " << tf::PrintTensorInfos(entry.second));
......@@ -414,15 +414,15 @@ public:
m_TrainModelFilter = TrainModelFilterType::New();
m_TrainModelFilter->SetGraph(m_SavedModel.meta_graph_def.graph_def());
m_TrainModelFilter->SetSession(m_SavedModel.session.get());
m_TrainModelFilter->SetOutputTensorsNames(GetParameterStringList("training.outputtensorsnames"));
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodesnames"));
m_TrainModelFilter->SetOutputTensors(GetParameterStringList("training.outputtensors"));
m_TrainModelFilter->SetTargetNodesNames(GetParameterStringList("training.targetnodes"));
m_TrainModelFilter->SetBatchSize(GetParameterInt("training.batchsize"));
m_TrainModelFilter->SetUserPlaceholders(GetUserPlaceholders("training.userplaceholders"));
// Set inputs
for (unsigned int i = 0 ; i < m_InputSourcesForTraining.size() ; i++)
{
m_TrainModelFilter->PushBackInputBundle(
m_TrainModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForTraining[i],
m_InputPatchesSizeForTraining[i],
m_InputSourcesForTraining[i]);
......@@ -458,14 +458,14 @@ public:
for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
{
m_ValidateModelFilter->PushBackInputBundle(
m_ValidateModelFilter->PushBackInputTensorBundle(
m_InputPlaceholdersForValidation[i],
m_InputPatchesSizeForValidation[i],
m_InputSourcesForEvaluationAgainstLearningData[i]);
}
m_ValidateModelFilter->SetOutputTensorsNames(m_TargetTensorsNames);
m_ValidateModelFilter->SetOutputTensors(m_TargetTensorsNames);
m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
m_ValidateModelFilter->SetOutputFOESizes(m_TargetPatchesSize);
m_ValidateModelFilter->SetOutputExpressionFields(m_TargetPatchesSize);
// Update
AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
......@@ -509,10 +509,13 @@ public:
private:
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
// Filters
TrainModelFilterType::Pointer m_TrainModelFilter;
ValidateModelFilterType::Pointer m_ValidateModelFilter;
tensorflow::SavedModelBundle m_SavedModel; // must be alive during all the execution of the application !
// Inputs
BundleList m_Bundles;
// Patches size
......
......@@ -97,7 +97,6 @@ public:
typedef typename Superclass::DictType DictType;
typedef typename Superclass::StringList StringList;
typedef typename Superclass::SizeListType SizeListType;
typedef typename Superclass::DictListType DictListType;
typedef typename Superclass::TensorListType TensorListType;
typedef std::vector<float> ScaleListType;
......@@ -139,6 +138,7 @@ private:
SpacingType m_OutputSpacing; // Output image spacing
PointType m_OutputOrigin; // Output image origin
SizeType m_OutputSize; // Output image size
PixelType m_NullPixel; // Pixel filled with zeros
}; // end class
......
......@@ -284,6 +284,10 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
outputPtr->SetSignedSpacing ( m_OutputSpacing );
outputPtr->SetLargestPossibleRegion( largestPossibleRegion);
// Set null pixel
m_NullPixel.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
m_NullPixel.Fill(0);
}
template <class TInputImage, class TOutputImage>
......@@ -395,13 +399,12 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
tf::RecopyImageRegionToTensorWithCast<TInputImage>(inputPtr, reqRegion, inputTensor, 0);
// Input #1 : the tensor of patches (aka the batch)
DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
inputs.push_back(input);
}
else
{
// Preparing patches (not very optimized ! )
// It would be better to perform the loop inside the TF session using TF operators
// Preparing patches
// Shape of input tensor #i
tensorflow::int64 sz_n = outputReqRegion.GetNumberOfPixels();
tensorflow::int64 sz_y = inputPatchSize[1];
......@@ -429,8 +432,9 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
// Input #1 : the tensor of patches (aka the batch)
DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
inputs.push_back(input);
} // mode is not full convolutional
} // next input tensor
......@@ -442,10 +446,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
// Fill the output buffer with zero value
outputPtr->SetBufferedRegion(outputReqRegion);
outputPtr->Allocate();
OutputPixelType nullpix;
nullpix.SetSize(outputPtr->GetNumberOfComponentsPerPixel());
nullpix.Fill(0);
outputPtr->FillBuffer(nullpix);
outputPtr->FillBuffer(m_NullPixel);
// Get output tensors
int bandOffset = 0;
......@@ -453,7 +454,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
{
// The offset (i.e. the starting index of the channel for the output tensor) is updated
// during this call
// TODO: implement a generic strategy enabling FOE copy in patch-based mode (see tf::CopyTensorToImageRegion)
// TODO: implement a generic strategy enabling expression field copy in patch-based mode (see tf::CopyTensorToImageRegion)
try
{
tf::CopyTensorToImageRegion<TOutputImage> (outputs[i],
......@@ -461,7 +462,7 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
}
catch( itk::ExceptionObject & err )
{
std::stringstream debugMsg = this->GenerateDebugReport(inputs, outputs);
std::stringstream debugMsg = this->GenerateDebugReport(inputs);
itkExceptionMacro("Error occured during tensor to image conversion.\n"
<< "Context: " << debugMsg.str()
<< "Error:" << err);
......
......@@ -18,11 +18,6 @@
// Base
#include "otbTensorflowMultisourceModelBase.h"
// Shuffle
#include <random>
#include <algorithm>
#include <iterator>
namespace otb
{
......@@ -86,10 +81,10 @@ protected:
virtual void GenerateData();
virtual void PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize, const IndexListType & order = IndexListType());
virtual void PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize, const IndexListType & order);
virtual void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize) = 0;
private:
......
......@@ -19,7 +19,7 @@ namespace otb
template <class TInputImage>
TensorflowMultisourceModelLearningBase<TInputImage>
::TensorflowMultisourceModelLearningBase(): m_BatchSize(100),
m_NumberOfSamples(0), m_UseStreaming(false)
m_UseStreaming(false), m_NumberOfSamples(0)
{
}
......@@ -31,6 +31,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
{
Superclass::GenerateOutputInformation();
// Set an empty output buffered region
ImageType * outputPtr = this->GetOutput();
RegionType nullRegion;
nullRegion.GetModifiableSize().Fill(1);
......@@ -72,7 +73,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
<< " rows but patch size Y is " << inputPatchSize[1] << " for input " << i);
// Get the batch size
const tensorflow::uint64 currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
const IndexValueType currNumberOfSamples = reqRegion.GetSize(1) / inputPatchSize[1];
// Check the consistency with other inputs
if (m_NumberOfSamples == 0)
......@@ -95,21 +96,21 @@ TensorflowMultisourceModelLearningBase<TInputImage>
{
Superclass::GenerateInputRequestedRegion();
// For each image, set no image region
// For each image, set the requested region
RegionType nullRegion;
for(unsigned int i = 0; i < this->GetNumberOfInputs(); ++i)
{
RegionType nullRegion;
ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(i) );
// If the streaming is enabled, we don't read the full image
if (m_UseStreaming)
{
{
inputImage->SetRequestedRegion(nullRegion);
}
}
else
{
{
inputImage->SetRequestedRegion(inputImage->GetLargestPossibleRegion());
}
}
} // next image
}
......@@ -131,8 +132,8 @@ TensorflowMultisourceModelLearningBase<TInputImage>
for (IndexValueType batch = 0 ; batch < nBatches ; batch++)
{
// Create input tensors list
TensorListType inputs;
// Feed dict
DictType inputs;
// Batch start and size
const IndexValueType sampleStart = batch * m_BatchSize;
......@@ -143,7 +144,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
}
// Process the batch
ProcessBatch(inputs, sampleStart, batchSize);
this->ProcessBatch(inputs, sampleStart, batchSize);
progress.CompletedPixel();
} // Next batch
......@@ -153,7 +154,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
template <class TInputImage>
void
TensorflowMultisourceModelLearningBase<TInputImage>
::PopulateInputTensors(TensorListType & inputs, const IndexValueType & sampleStart,
::PopulateInputTensors(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize, const IndexListType & order)
{
const bool reorder = order.size();
......@@ -176,7 +177,7 @@ TensorflowMultisourceModelLearningBase<TInputImage>
tensorflow::Tensor inputTensor(this->GetInputTensorsDataTypes()[i], inputTensorShape);
// Populate the tensor
for (tensorflow::uint64 elem = 0 ; elem < batchSize ; elem++)
for (IndexValueType elem = 0 ; elem < batchSize ; elem++)
{
const tensorflow::uint64 samplePos = sampleStart + elem;
IndexType start;
......@@ -199,8 +200,8 @@ TensorflowMultisourceModelLearningBase<TInputImage>
}
// Input #i : the tensor of patches (aka the batch)
DictElementType input1 = { this->GetInputPlaceholdersNames()[i], inputTensor };
inputs.push_back(input1);
DictElementType input = { this->GetInputPlaceholders()[i], inputTensor };
inputs.push_back(input);
} // next input tensor
}
......
......@@ -54,8 +54,9 @@ public:
itkTypeMacro(TensorflowMultisourceModelTrain, TensorflowMultisourceModelLearningBase);
/** Superclass typedefs */
typedef typename Superclass::IndexValueType IndexValueType;
typedef typename Superclass::DictType DictType;
typedef typename Superclass::TensorListType TensorListType;
typedef typename Superclass::IndexValueType IndexValueType;
typedef typename Superclass::IndexListType IndexListType;
......@@ -63,8 +64,8 @@ protected:
TensorflowMultisourceModelTrain();
virtual ~TensorflowMultisourceModelTrain() {};
void GenerateData();
void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
virtual void GenerateData();
virtual void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize);
private:
......
......@@ -45,11 +45,11 @@ TensorflowMultisourceModelTrain<TInputImage>
template <class TInputImage>
void
TensorflowMultisourceModelTrain<TInputImage>
::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
::ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize)
{
// Populate input tensors
PopulateInputTensor(inputs, sampleStart, batchSize, m_RandomIndices);
this->PopulateInputTensors(inputs, sampleStart, batchSize, m_RandomIndices);
// Run the TF session here
TensorListType outputs;
......
......@@ -43,10 +43,10 @@ public TensorflowMultisourceModelLearningBase<TInputImage>
public:
/** Standard class typedefs. */
typedef TensorflowMultisourceModelValidate Self;
typedef TensorflowMultisourceModelBase<TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef TensorflowMultisourceModelValidate Self;
typedef TensorflowMultisourceModelLearningBase<TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for creation through the object factory. */
itkNewMacro(Self);
......@@ -68,6 +68,7 @@ public:
typedef typename Superclass::SizeListType SizeListType;
typedef typename Superclass::TensorListType TensorListType;
typedef typename Superclass::IndexValueType IndexValueType;
typedef typename Superclass::IndexListType IndexListType;
/* Typedefs for validation */
typedef unsigned long CountValueType;
......@@ -97,7 +98,7 @@ protected:
void GenerateOutputInformation(void);
void GenerateData();
void ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
void ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize);
private:
......@@ -110,6 +111,9 @@ private:
ConfMatListType m_ConfusionMatrices; // Confusion matrix
MapOfClassesListType m_MapsOfClasses; // Maps of classes
// Internal
std::vector<MatMapType> m_ConfMatMaps; // Accumulators
}; // end class
......
......@@ -110,24 +110,24 @@ TensorflowMultisourceModelValidate<TInputImage>
// Temporary images for outputs
m_ConfusionMatrices.clear();
m_MapsOfClasses.clear();
std::vector<MatMapType> confMatMaps;
m_ConfMatMaps.clear();
for (auto const& ref: m_References)
{
(void) ref;
// New confusion matrix
MatMapType mat;
confMatMaps.push_back(mat);
m_ConfMatMaps.push_back(mat);
}
// Run all the batches
Superclass::GenerateData();
// Compute confusion matrices
for (unsigned int i = 0 ; i < confMatMaps.size() ; i++)
for (unsigned int i = 0 ; i < m_ConfMatMaps.size() ; i++)
{
// Confusion matrix (map) for current target
MatMapType mat = confMatMaps[i];
MatMapType mat = m_ConfMatMaps[i];
// List all values
MapOfClassesType values;
......@@ -159,10 +159,6 @@ TensorflowMultisourceModelValidate<TInputImage>
m_ConfusionMatrices.push_back(matrix);
m_MapsOfClasses.push_back(values);
}
}
......@@ -171,11 +167,12 @@ TensorflowMultisourceModelValidate<TInputImage>
template <class TInputImage>
void
TensorflowMultisourceModelValidate<TInputImage>
::ProcessBatch(TensorListType & inputs, const IndexValueType & sampleStart,
::ProcessBatch(DictType & inputs, const IndexValueType & sampleStart,
const IndexValueType & batchSize)
{
// Populate input tensors
PopulateInputTensor(inputs, sampleStart, batchSize);
IndexListType empty;
this->PopulateInputTensors(inputs, sampleStart, batchSize, empty);
// Run the TF session here
TensorListType outputs;
......@@ -223,21 +220,21 @@ TensorflowMultisourceModelValidate<TInputImage>
const int classIn = static_cast<LabelValueType>(inIt.Get()[0]);
const int classRef = static_cast<LabelValueType>(refIt.Get()[0]);
if (confMatMaps[refIdx].count(classRef) == 0)
if (m_ConfMatMaps[refIdx].count(classRef) == 0)
{
MapType newMap;
newMap[classIn] = 1;
confMatMaps[refIdx][classRef] = newMap;
m_ConfMatMaps[refIdx][classRef] = newMap;
}
else
{
if (confMatMaps[refIdx][classRef].count(classIn) == 0)
if (m_ConfMatMaps[refIdx][classRef].count(classIn) == 0)
{
confMatMaps[refIdx][classRef][classIn] = 1;
m_ConfMatMaps[refIdx][classRef][classIn] = 1;
}
else
{
confMatMaps[refIdx][classRef][classIn]++;
m_ConfMatMaps[refIdx][classRef][classIn]++;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment