• Thibault Hallouin's avatar
    add dimensions for sites/lead times to probabilistic evaluator · 295b3208
    Thibault Hallouin authored
    Internally, rather than exploiting the multi-dimensional nature of
    tensors to compute all sites and all lead times at once, loops are
    performed over each site and each lead time in turn, in order to
    minimise the memory footprint. At the moment, however, the input
    tensors are still expected to include the sites and lead times
    dimensions. If memory is an issue, the user can still pass smaller
    tensors with size 1 along those dimensions and recompose the
    multi-site/multi-lead-time output arrays externally.
    295b3208
otbTensorflowMultisourceModelBase.hxx 4.78 KiB
/*=========================================================================
  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelBase_txx
#define otbTensorflowMultisourceModelBase_txx
#include "otbTensorflowMultisourceModelBase.h"

#include <sstream>
#include <utility>
namespace otb
{

/**
 * Constructor: the filter starts without a TensorFlow session.
 * NOTE(review): RunSession() dereferences GetSession(), so a session must be
 * provided (presumably through a setter declared in the header) before use.
 */
template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::TensorflowMultisourceModelBase()
{
  m_Session = nullptr;
}

/**
 * Register one more source for the model.
 *
 * The image is appended to the filter inputs, and its receptive field and
 * placeholder name are appended to their respective lists. As long as inputs
 * are only added through this method, input #i, m_InputReceptiveFields[i] and
 * m_InputPlaceholders[i] describe the same source (GenerateOutputInformation()
 * checks that the three counts match).
 *
 * @param placeholder     name of the input tensor (placeholder) in the graph
 * @param receptiveField  receptive field associated with this source
 * @param image           image fed to that placeholder
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::PushBackInputTensorBundle(std::string placeholder, SizeType receptiveField, ImagePointerType image)
{
  Superclass::PushBackInput(image);
  m_InputReceptiveFields.push_back(receptiveField);
  // placeholder is already our own copy (taken by value): move it instead of copying again
  m_InputPlaceholders.push_back(std::move(placeholder));
}

/**
 * Build a human-readable report of the filter state, appended to the
 * exception message when the TensorFlow session fails (see RunSession()).
 *
 * The report contains the output requested region, then for each filter input
 * its requested region and the shape of the corresponding fed tensor, then the
 * user placeholders.
 *
 * @param inputs  (name, tensor) pairs fed to the session. Entry i is reported
 *                next to filter input #i; they are assumed to be in the same
 *                order — TODO confirm against callers.
 * @return a stream holding the report.
 */
template <class TInputImage, class TOutputImage>
std::stringstream
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateDebugReport(DictType & inputs)
{
  // Create a debug report
  std::stringstream debugReport;

  // Describe the output requested region
  ImagePointerType outputPtr = this->GetOutput();
  const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
  debugReport << "Output image requested region: " << outputReqRegion << "\n";

  // Describe inputs
  for (unsigned int i = 0 ; i < this->GetNumberOfInputs() ; i++)
  {
    // Only const accessors are needed here, so no const_cast is required
    const TInputImage * inputPtr = this->GetInput(i);
    const RegionType reqRegion = inputPtr->GetRequestedRegion();
    debugReport << "Input #" << i << ":\n";
    debugReport << "Requested region: " << reqRegion << "\n";

    // This runs on the error path: a mismatch between the number of images and
    // the number of fed tensors must not turn into an out-of-bounds access.
    if (i < inputs.size())
    {
      debugReport << "Tensor shape (\"" << inputs[i].first << "\"): "
                  << tf::PrintTensorShape(inputs[i].second.shape()) << "\n";
    }
    else
    {
      debugReport << "Tensor shape: no tensor fed for this input\n";
    }
  }

  // Show user placeholders (each entry is followed by a blank line)
  debugReport << "User placeholders:\n";
  for (auto& dict: this->GetUserPlaceholders())
  {
    debugReport << dict.first << " " << tf::PrintTensorInfos(dict.second) << "\n\n";
  }

  return debugReport;
}

/**
 * Run the TensorFlow session on the given inputs.
 *
 * The user placeholders are appended to @p inputs (which is therefore modified
 * in place), then the session evaluates m_OutputTensors and the target nodes
 * m_TargetNodesNames, writing the results into @p outputs.
 *
 * @param inputs   (name, tensor) pairs to feed; user placeholders are appended
 * @param outputs  filled by the session with the evaluated output tensors
 * @throws itk::ExceptionObject if the session fails, with the TensorFlow status
 *         message and the report from GenerateDebugReport()
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::RunSession(DictType & inputs, TensorListType & outputs)
{
  // Add the user's placeholders
  for (auto& dict: this->GetUserPlaceholders())
  {
    inputs.push_back(dict);
  }

  // Run the session, evaluating our output tensors from the graph.
  // The session initializes the outputs.
  auto status = this->GetSession()->Run(inputs, m_OutputTensors, m_TargetNodesNames, &outputs);
  if (!status.ok())
  {
    // Create a debug report
    std::stringstream debugReport = GenerateDebugReport(inputs);

    // Throw an exception with the report
    itkExceptionMacro("Can't run the tensorflow session !\n" <<
        "Tensorflow error message:\n" << status.ToString() << "\n"
        "OTB Filter debug message:\n" << debugReport.str() );
  }
}

/**
 * Validate the filter configuration and fetch the tensors' attributes.
 *
 * Checks that there are as many input images as receptive fields and
 * placeholder names, and as many output tensor names as output expression
 * fields. Then it queries m_Graph for the shapes and data types of the input
 * placeholders and output tensors.
 *
 * @throws itk::ExceptionObject if any of these counts mismatch
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>
::GenerateOutputInformation()
{
  // Check that the number of the following is the same
  // - input placeholders names
  // - input receptive fields
  // - input images
  const unsigned int nbInputs = this->GetNumberOfInputs();
  if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size())
  {
    itkExceptionMacro("Number of input images is " << nbInputs <<
        " but the number of input patches size is " << m_InputReceptiveFields.size() <<
        " and the number of input tensors names is " << m_InputPlaceholders.size());
  }

  // Check that the number of the following is the same
  // - output tensors names
  // - output expression fields
  if (m_OutputExpressionFields.size() != m_OutputTensors.size())
  {
    itkExceptionMacro("Number of output tensors names is " << m_OutputTensors.size() <<
        " but the number of output fields of expression is " << m_OutputExpressionFields.size());
  }

  //////////////////////////////////////////////////////////////////////////////////////////
  //                               Get tensors information
  //////////////////////////////////////////////////////////////////////////////////////////

  // Get input and output tensors datatypes and shapes
  tf::GetTensorAttributes(m_Graph, m_InputPlaceholders, m_InputTensorsShapes, m_InputTensorsDataTypes);
  tf::GetTensorAttributes(m_Graph, m_OutputTensors, m_OutputTensorsShapes, m_OutputTensorsDataTypes);
}

} // end namespace otb
#endif