An error occurred while loading the file. Please try again.
-
Thibault Hallouin authored
Internally, rather than using the multi-dimensional character of tensors to compute all sites and all lead times at once, loops are performed for each site and each lead time, in turn, in order to minimise memory imprint. Although at the moment, the input tensors are expected to feature the sites and lead times dimensions. If memory is an issue, the user can still send smaller tensors with size 1 for those dimensions and recompose multi-sites/multi-lead times output arrays externally.
295b3208
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelBase_txx
#define otbTensorflowMultisourceModelBase_txx
#include "otbTensorflowMultisourceModelBase.h"
#include <utility>
namespace otb
{
/**
 * Constructor.
 * The filter starts with no TensorFlow session attached: m_Session is null.
 */
template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::TensorflowMultisourceModelBase()
{
  // No session at construction time.
  this->m_Session = nullptr;
}
/**
 * Register one model input as a bundle of three items appended at the same
 * index: the input image, its receptive field (the "patch size" checked in
 * GenerateOutputInformation()), and the name of the placeholder it feeds.
 *
 * \param placeholder    name of the input placeholder (taken by value and moved in)
 * \param receptiveField receptive field associated with this input
 * \param image          input image, appended to the filter inputs
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
  Superclass::PushBackInput(image);
  m_InputReceptiveFields.push_back(receptiveField);
  // 'placeholder' is a by-value sink parameter: move it rather than copy it again.
  m_InputPlaceholders.push_back(std::move(placeholder));
}
/**
 * Build a human-readable report of the filter state. It is used to enrich
 * the exception thrown when the TensorFlow session fails.
 *
 * The report lists:
 *  - the requested region of the output image;
 *  - for each input image, its requested region and the name and shape of the
 *    tensor found at the same index in \p inputs;
 *  - the user placeholders (name and tensor information).
 *
 * \param inputs (placeholder name, tensor) pairs fed to the session. The first
 *               GetNumberOfInputs() entries are expected to correspond to the
 *               input images, in order. Missing entries are reported rather
 *               than read out of bounds.
 * \return a stream holding the report text.
 */
template <class TInputImage, class TOutputImage>
std::stringstream
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateDebugReport(DictType & inputs)
{
  std::stringstream debugReport;

  // Describe the output requested region. 'auto' keeps the output image type
  // instead of forcing it into the input image pointer type, which only
  // compiled when TInputImage and TOutputImage are the same type.
  const auto * outputPtr = this->GetOutput();
  debugReport << "Output image requested region: " << outputPtr->GetRequestedRegion() << "\n";

  // Describe inputs: read-only access, so no const_cast is needed.
  const unsigned int nbInputs = this->GetNumberOfInputs();
  for (unsigned int i = 0; i < nbInputs; ++i)
  {
    debugReport << "Input #" << i << ":\n";

    const TInputImage * inputPtr = this->GetInput(i);
    if (inputPtr != nullptr)
    {
      debugReport << "Requested region: " << inputPtr->GetRequestedRegion() << "\n";
    }
    else
    {
      debugReport << "Requested region: <input image not set>\n";
    }

    // This runs on the error path: never index past the end of 'inputs'.
    if (i < inputs.size())
    {
      debugReport << "Tensor shape (\"" << inputs[i].first << "\"): "
                  << tf::PrintTensorShape(inputs[i].second.shape()) << "\n";
    }
    else
    {
      debugReport << "Tensor shape: <no tensor provided for this input>\n";
    }
  }

  // Show user placeholders
  debugReport << "User placeholders:\n";
  for (const auto & dict : this->GetUserPlaceholders())
  {
    debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n";
  }

  return debugReport;
}
template <class TInputImage, class TOutputImage>
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::RunSession(DictType & inputs, TensorListType & outputs)
{
// Add the user's placeholders
for (auto& dict: this->GetUserPlaceholders())
{
inputs.push_back(dict);
}
// Run the TF session here
// The session will initialize the outputs
// Run the session, evaluating our output tensors from the graph
auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
if (!status.ok()) {
// Create a debug report
std::stringstream debugReport = GenerateDebugReport(inputs);
// Throw an exception with the report
itkExceptionMacro("Can't run the tensorflow session !\n" <<
"Tensorflow error message:\n" << status.ToString() << "\n"
"OTB Filter debug message:\n" << debugReport.str() );
}
}
/**
 * Validate the filter configuration, then fetch tensor attributes from the graph.
 *
 * Two checks are made, and itk::ExceptionObject is thrown if either fails:
 *  - input images, input receptive fields and input placeholder names come in equal numbers;
 *  - output tensor names and output expression fields come in equal numbers.
 *
 * Then tf::GetTensorAttributes() fills the shapes and datatypes of the input
 * placeholders and output tensors from m_Graph.
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateOutputInformation()
{
  // Inputs: each input image needs one receptive field and one placeholder name.
  const unsigned int nbInputs = this->GetNumberOfInputs();
  const auto nbReceptiveFields = m_InputReceptiveFields.size();
  const auto nbInputPlaceholders = m_InputPlaceholders.size();
  const bool inputsConsistent = (nbReceptiveFields == nbInputs) && (nbInputPlaceholders == nbInputs);
  if (!inputsConsistent)
  {
    itkExceptionMacro("Number of input images is " << nbInputs <<
                      " but the number of input patches size is " << nbReceptiveFields <<
                      " and the number of input tensors names is " << nbInputPlaceholders);
  }

  // Outputs: each output tensor name needs one expression field.
  const auto nbOutputTensors = m_OutputTensors.size();
  const auto nbOutputFields = m_OutputExpressionFields.size();
  if (nbOutputFields != nbOutputTensors)
  {
    itkExceptionMacro("Number of output tensors names is " << nbOutputTensors <<
                      " but the number of output fields of expression is " << nbOutputFields);
  }

  // Query the graph for datatypes and shapes of the input and output tensors.
  tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
  tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
}
} // end namespace otb
141142143
#endif