// otbTensorflowMultisourceModelBase.hxx
/*=========================================================================
     Copyright (c) 2018-2019 IRSTEA
     Copyright (c) 2020-2021 INRAE
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowMultisourceModelBase_txx
#define otbTensorflowMultisourceModelBase_txx
#include "otbTensorflowMultisourceModelBase.h"
namespace otb
/**
 * Constructor.
 *
 * Raises both superclass tolerances (coordinate and direction) to the largest
 * representable double and starts with no saved model attached.
 */
template <class TInputImage, class TOutputImage>
TensorflowMultisourceModelBase<TInputImage, TOutputImage>::TensorflowMultisourceModelBase()
{
  // NOTE(review): presumably this relaxes the superclass' physical-space
  // consistency checks so that inputs with different grids are accepted.
  // Confirm against the superclass documentation.
  Superclass::SetCoordinateTolerance(itk::NumericTraits<double>::max());
  Superclass::SetDirectionTolerance(itk::NumericTraits<double>::max());

  // No model until the caller provides one.
  m_SavedModel = nullptr;
}
/**
 * Select the signature of the saved model that the filter works with.
 *
 * The signature stored under tensorflow::kDefaultServingSignatureDefKey is
 * returned when the model has one; otherwise the first signature of the
 * model's signature map is returned.
 *
 * \return a copy of the selected tensorflow::SignatureDef
 * \throw itk::ExceptionObject (through itkExceptionMacro) when no saved model
 *        is set, or when the saved model exposes no signature at all
 */
template <class TInputImage, class TOutputImage>
tensorflow::SignatureDef
TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GetSignatureDef()
{
  // The constructor leaves the saved model unset: fail with a clear message
  // instead of dereferencing a null model.
  if (this->GetSavedModel() == nullptr)
  {
    itkExceptionMacro("No saved model is set: the signatures cannot be retrieved.");
  }

  // Bind by const reference: the signature map is only read here, so there is
  // no reason to deep-copy it.
  const auto & signatures = this->GetSavedModel()->GetSignatures();

  if (signatures.size() == 0)
  {
    itkExceptionMacro("There are no available signatures for this tag-set. \n"
                      << "Please check which tag-set to use by running "
                      << "`saved_model_cli show --dir your_model_dir --all`");
  }

  // If serving_default key exists (which is the default for TF saved model), choose it as signature
  // Else, choose the first one
  // A single find() replaces the former contains() + at() pair of lookups.
  const auto servingDefault = signatures.find(tensorflow::kDefaultServingSignatureDefKey);
  if (servingDefault != signatures.end())
  {
    return servingDefault->second;
  }
  return signatures.begin()->second;
}
/**
 * Register one more input of the filter, together with its per-input settings.
 *
 * The image is appended to the filter inputs, and each of the other arguments
 * is appended to the matching per-input list, so that all lists grow by one
 * element per call.
 *
 * \param placeholder    name appended to the input placeholders list
 * \param receptiveField size appended to the input receptive fields list
 * \param image          image appended to the filter inputs
 * \param useNodata      flag appended to the "use no-data" list
 * \param nodataValue    value appended to the no-data values list
 */
template <class TInputImage, class TOutputImage>
void
TensorflowMultisourceModelBase<TInputImage, TOutputImage>::PushBackInputTensorBundle(
  std::string       placeholder,
  SizeType          receptiveField,
  ImagePointerType  image,
  bool              useNodata,
  InternalPixelType nodataValue)
{
  // The image itself goes to the pipeline inputs...
  Superclass::PushBackInput(image);

  // ...and its description goes to the parallel per-input lists.
  m_InputPlaceholders.push_back(placeholder);
  m_InputReceptiveFields.push_back(receptiveField);

  // No-data handling of this input.
  m_InputNodataValues.push_back(nodataValue);
  m_InputUseNodata.push_back(useNodata);
}
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
template <class TInputImage, class TOutputImage> std::stringstream TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GenerateDebugReport(DictType & inputs) { // Create a debug report std::stringstream debugReport; // Describe the output buffered region ImagePointerType outputPtr = this->GetOutput(); const RegionType outputReqRegion = outputPtr->GetRequestedRegion(); debugReport << "Output image buffered region: " << outputReqRegion << "\n"; // Describe inputs for (unsigned int i = 0; i < this->GetNumberOfInputs(); i++) { const ImagePointerType inputPtr = const_cast<TInputImage *>(this->GetInput(i)); const RegionType reqRegion = inputPtr->GetRequestedRegion(); debugReport << "Input #" << i << ":\n"; debugReport << "Requested region: " << reqRegion << "\n"; debugReport << "Tensor \"" << inputs[i].first << "\": " << tf::PrintTensorInfos(inputs[i].second) << "\n"; } // Show user placeholders debugReport << "User placeholders:\n"; for (auto & dict : this->GetUserPlaceholders()) { debugReport << "Tensor \"" << dict.first << "\": " << tf::PrintTensorInfos(dict.second) << "\n" << std::endl; } return debugReport; } template <class TInputImage, class TOutputImage> void TensorflowMultisourceModelBase<TInputImage, TOutputImage>::RunSession(DictType & inputs, TensorListType & outputs, bool & nodata) { // Run the TF session here // The session will initialize the outputs // `inputs` corresponds to a mapping {name, tensor}, with the name being specified by the user when calling // TensorFlowModelServe we must adapt it to `inputs_new`, that corresponds to a mapping {layerName, tensor}, with the // layerName being from the model DictType inputs_new; // Add the user's placeholders std::size_t k = 0; for (auto & dict : this->GetUserPlaceholders()) { inputs_new.emplace_back(m_InputConstants[k], dict.second); k++; } // Add input tensors // During this step we also check for nodata values nodata = false; k = 0; for (auto & dict : inputs) { auto inputTensor = dict.second; 
inputs_new.emplace_back(m_InputLayers[k], inputTensor); if (m_InputUseNodata[k] == true) { const auto nodataValue = m_InputNodataValues[k]; const tensorflow::int64 nElmT = inputTensor.NumElements(); tensorflow::int64 ndCount = 0; auto array = inputTensor.flat<InternalPixelType>(); for (tensorflow::int64 i = 0 ; i < nElmT ; i++) if (array(i) == nodataValue)
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
ndCount++; if (ndCount == nElmT) { nodata = true; return; } } k += 1; } // Run the session, evaluating our output tensors from the graph auto status = this->GetSavedModel()->session.get()->Run(inputs_new, m_OutputLayers, m_TargetNodesNames, &outputs); if (!status.ok()) { // Create a debug report std::stringstream debugReport = GenerateDebugReport(inputs); // Throw an exception with the report itkExceptionMacro("Can't run the tensorflow session !\n" << "Tensorflow error message:\n" << status.ToString() << "\n" "OTB Filter debug message:\n" << debugReport.str()); } } template <class TInputImage, class TOutputImage> void TensorflowMultisourceModelBase<TInputImage, TOutputImage>::RunSession(DictType & inputs, TensorListType & outputs) { bool nodata; this->RunSession(inputs, outputs, nodata); } template <class TInputImage, class TOutputImage> void TensorflowMultisourceModelBase<TInputImage, TOutputImage>::GenerateOutputInformation() { // Check that the number of the following is the same // - input placeholders names // - input receptive fields // - input images const unsigned int nbInputs = this->GetNumberOfInputs(); if (nbInputs != m_InputReceptiveFields.size() || nbInputs != m_InputPlaceholders.size()) { itkExceptionMacro("Number of input images is " << nbInputs << " but the number of input patches size is " << m_InputReceptiveFields.size() << " and the number of input tensors names is " << m_InputPlaceholders.size()); } // Check that no-data values size is consistent with the inputs // If no value is specified, set a vector of the same size as the inputs if (m_InputNodataValues.size() == 0 && m_InputUseNodata.size() == 0) { m_InputUseNodata = BoolListType(nbInputs, false); m_InputNodataValues = ValueListType(nbInputs, 0.0); } if (nbInputs != m_InputNodataValues.size() || nbInputs != m_InputUseNodata.size()) { itkExceptionMacro("Number of input images is " << nbInputs << " but the number of no-data values is not consistent"); } 
////////////////////////////////////////////////////////////////////////////////////////// // Get tensors information ////////////////////////////////////////////////////////////////////////////////////////// // Set all subelement of the model
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
auto signaturedef = this->GetSignatureDef(); // Given the inputs/outputs names that the user specified, get the names of the inputs/outputs contained in the model // and other infos (shapes, dtypes) // For example, for output names specified by the user m_OutputTensors = ['s2t', 's2t_pad'], // this will return m_OutputLayers = ['PartitionedCall:0', 'PartitionedCall:1'] // In case the user hasn't named the output, i.e. m_OutputTensors = [''], // this will return the first output m_OutputLayers = ['PartitionedCall:0'] StringList constantsNames; std::transform(m_UserPlaceholders.begin(), m_UserPlaceholders.end(), std::back_inserter(constantsNames), [](const DictElementType & p) { return p.first; }); if (m_UserPlaceholders.size() > 0) { // Avoid the unnecessary warning when no placeholder is fed tf::GetTensorAttributes(signaturedef.inputs(), constantsNames, m_InputConstants, m_InputConstantsShapes, m_InputConstantsDataTypes); } tf::GetTensorAttributes(signaturedef.inputs(), m_InputPlaceholders, m_InputLayers, m_InputTensorsShapes, m_InputTensorsDataTypes, constantsNames); tf::GetTensorAttributes( signaturedef.outputs(), m_OutputTensors, m_OutputLayers, m_OutputTensorsShapes, m_OutputTensorsDataTypes); } } // end namespace otb #endif