/*=========================================================================
  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowCopyUtils.h"
namespace otb {
namespace tf {
// Display a TensorShape
std::string PrintTensorShape(const tensorflow::TensorShape & shp)
{
  std::stringstream s;
  unsigned int nDims = shp.dims();
  s << "{" << shp.dim_size(0);
  for (unsigned int d = 1 ; d < nDims ; d++)
    s << ", " << shp.dim_size(d);
  s << "}";
  return s.str();
}
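
// Illustrative example (assumption, not part of the original file): PrintTensorShape
// produces a brace-delimited list of the shape's dimensions, e.g.
//
//   tensorflow::TensorShape shp({1, 256, 256, 4});
//   std::string s = otb::tf::PrintTensorShape(shp);  // s == "{1, 256, 256, 4}"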
// Display information about a tensor
std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
{
  std::stringstream s;
  s << "Tensor ";
  // Show dims
  s << "shape is " << PrintTensorShape(tensor.shape());
  // Data type
  s << " data type is " << tensor.dtype();
  return s.str();
}
// Create a tensor with the appropriate data type
template<class TImage>
tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
{
  tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>();
  tensorflow::Tensor out_tensor(ts_dt, shape);
  return out_tensor;
}
// Populate a tensor with the buffered region of a vector image using std::copy
// Warning: tensor datatype must be consistent with the image value type
template<class TImage>
void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
{
  size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
      bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
  std::copy_n(bufferedimagePtr->GetBufferPointer(),
      n_elem,
      out_tensor.flat<typename TImage::InternalPixelType>().data());
}
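
// Illustrative usage sketch (assumption, not part of the original file): create a tensor
// whose data type and element count match an otb::VectorImage buffer, then copy the image
// values into it. 'ImageType' and 'img' are hypothetical names chosen for the example;
// the shape ordering {1, sz_y, sz_x, nbands} matches the interleaved ITK buffer layout.
//
//   using ImageType = otb::VectorImage<float, 2>;
//   ImageType::Pointer img = ...; // an image with an allocated buffered region
//   ImageType::SizeType sz = img->GetBufferedRegion().GetSize();
//   tensorflow::TensorShape shape({1,
//       static_cast<tensorflow::int64>(sz[1]),
//       static_cast<tensorflow::int64>(sz[0]),
//       static_cast<tensorflow::int64>(img->GetNumberOfComponentsPerPixel())});
//   tensorflow::Tensor t = otb::tf::CreateTensor<ImageType>(shape); // DT_FLOAT for float pixels
//   otb::tf::PopulateTensorFromBufferedVectorImage<ImageType>(img, t);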
//
// Recopy a VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands})
//
template<class TImage, class TValueType=typename TImage::InternalPixelType>
void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr,
    const typename TImage::RegionType & region,
    tensorflow::Tensor & tensor,
    unsigned int elemIdx) // element position along the 1st dimension
{
  typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region);
  unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel();
  auto tMap = tensor.tensor<TValueType, 4>();
  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
  {
    const int y = inIt.GetIndex()[1] - region.GetIndex()[1];
    const int x = inIt.GetIndex()[0] - region.GetIndex()[0];
    for (unsigned int band = 0 ; band < nBands ; band++)
      tMap(elemIdx, y, x, band) = inIt.Get()[band];
  }
}

//
// Type-agnostic version of the 'RecopyImageRegionToTensor' function
// TODO: add some numeric types
//
template<class TImage>
void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr,
    const typename TImage::RegionType & region,
    tensorflow::Tensor & tensor,
    unsigned int elemIdx) // element position along the 1st dimension
{
  tensorflow::DataType dt = tensor.dtype();
  if (dt == tensorflow::DT_FLOAT)
    RecopyImageRegionToTensor<TImage, float>        (inputPtr, region, tensor, elemIdx);
  else if (dt == tensorflow::DT_DOUBLE)
    RecopyImageRegionToTensor<TImage, double>       (inputPtr, region, tensor, elemIdx);
  else if (dt == tensorflow::DT_INT64)
    RecopyImageRegionToTensor<TImage, long long int>(inputPtr, region, tensor, elemIdx);
  else if (dt == tensorflow::DT_INT32)
    RecopyImageRegionToTensor<TImage, int>          (inputPtr, region, tensor, elemIdx);
  else
    itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented!");
}

//
// Sample a centered patch (from index)
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr,
    const typename TImage::IndexType & centerIndex,
    const typename TImage::SizeType & patchSize,
    tensorflow::Tensor & tensor,
    unsigned int elemIdx)
{
  typename TImage::IndexType regionStart;
  regionStart[0] = centerIndex[0] - patchSize[0] / 2;
  regionStart[1] = centerIndex[1] - patchSize[1] / 2;
  typename TImage::RegionType patchRegion(regionStart, patchSize);
  RecopyImageRegionToTensorWithCast<TImage>(inputPtr, patchRegion, tensor, elemIdx);
}

//
// Sample a centered patch (from coordinates)
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr,
    const typename TImage::PointType & centerCoord,
    const typename TImage::SizeType & patchSize,
    tensorflow::Tensor & tensor,
    unsigned int elemIdx)
{
  // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands}
  // Get the index of the center
  typename TImage::IndexType centerIndex;
  inputPtr->TransformPhysicalPointToIndex(centerCoord, centerIndex);
  SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx);
}
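
// Illustrative usage sketch (assumption, not part of the original file): extract a single
// 16x16 patch centered on a pixel index into element 0 of a 4D tensor shaped
// {1, sz_y, sz_x, nbands}. 'ImageType', 'img' and the index values are hypothetical.
//
//   ImageType::SizeType patchSize;
//   patchSize.Fill(16);
//   ImageType::IndexType center;
//   center[0] = 128;
//   center[1] = 64;
//   tensorflow::TensorShape shape({1, 16, 16,
//       static_cast<tensorflow::int64>(img->GetNumberOfComponentsPerPixel())});
//   tensorflow::Tensor patches = otb::tf::CreateTensor<ImageType>(shape);
//   otb::tf::SampleCenteredPatch<ImageType>(img, center, patchSize, patches, 0);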
//
// Return the number of channels that the output tensor will occupy in the output image
//
// shape {n}          --> 1 (e.g. a label)
// shape {n, c}       --> c (e.g. a vector)
// shape {x, y, c}    --> c (e.g. a patch)
// shape {n, x, y, c} --> c (e.g. some patches)
//
tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor)
{
  const tensorflow::TensorShape shape = tensor.shape();
  const int nDims = shape.dims();
  if (nDims == 1)
    return 1;
  return shape.dim_size(nDims - 1);
}

//
// Copy a tensor into the image region
// TODO: Allow changing the mapping from source tensor to image to make it more generic
//
// Right now, only the following output tensor shapes can be processed:
// shape {n}       --> 1 (e.g. a label)
// shape {n, c}    --> c (e.g. a vector)
// shape {x, y, c} --> c (e.g. a multichannel image)
//
template<class TImage, class TValueType>
void CopyTensorToImageRegion(const tensorflow::Tensor & tensor,
    const typename TImage::RegionType & bufferRegion,
    typename TImage::Pointer outputPtr,
    const typename TImage::RegionType & outputRegion,
    int & channelOffset)
{
  // Flatten the tensor
  auto tFlat = tensor.flat<TValueType>();

  // Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)')
  const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor);

  // Number of columns (size x of the buffer)
  const tensorflow::int64 nCols = bufferRegion.GetSize(0);

  // Check the tensor size vs the outputRegion size
  const tensorflow::int64 nElmT = tensor.NumElements();
  const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C;
  if (nElmI != nElmT)
  {
    itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT
        << " but image outputRegion has " << nElmI << " values to fill.\nBuffer region:\n"
        << bufferRegion << "\nNumber of components: " << outputDimSize_C
        << "\nTensor shape:\n " << PrintTensorShape(tensor.shape())
        << "\nPlease check the input(s) field of view (FOV), "
        << "the output field of expression (FOE), and the "
        << "output spacing scale if you run the model in fully "
        << "convolutional mode (how many strides in your model?)");
  }

  // Iterate over the image
  typename itk::ImageRegionIterator<TImage> outIt(outputPtr, outputRegion);
  for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
  {
    const int x = outIt.GetIndex()[0] - bufferRegion.GetIndex(0);
    const int y = outIt.GetIndex()[1] - bufferRegion.GetIndex(1);

    // TODO: it could be useful to change the tensor-->image mapping here.
    // e.g. use a lambda for "pos" calculation
    const int pos = outputDimSize_C * (y * nCols + x);
    for (unsigned int c = 0 ; c < outputDimSize_C ; c++)
      outIt.Get()[channelOffset + c] = tFlat(pos + c);
  }

  // Update the offset
  channelOffset += outputDimSize_C;
}
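
// Illustrative usage sketch (assumption, not part of the original file): write a DT_FLOAT
// tensor of shape {sz_y, sz_x, 2} into the first two channels of an output VectorImage
// (at least 2 bands) whose buffered region matches the tensor extent. 'outputTensor' and
// 'outImg' are hypothetical names; after the call, channelOffset has advanced by 2, so a
// second tensor would fill the next channels of the same image.
//
//   int channelOffset = 0;
//   otb::tf::CopyTensorToImageRegion<ImageType, float>(outputTensor,
//       outImg->GetBufferedRegion(), outImg,
//       outImg->GetBufferedRegion(), channelOffset);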
//
// Type-agnostic version of the 'CopyTensorToImageRegion' function
// TODO: add some numeric types
//
template<class TImage>
void CopyTensorToImageRegion(const tensorflow::Tensor & tensor,
    const typename TImage::RegionType & bufferRegion,
    typename TImage::Pointer outputPtr,
    const typename TImage::RegionType & region,
    int & channelOffset)
{
  tensorflow::DataType dt = tensor.dtype();
  if (dt == tensorflow::DT_FLOAT)
    CopyTensorToImageRegion<TImage, float>        (tensor, bufferRegion, outputPtr, region, channelOffset);
  else if (dt == tensorflow::DT_DOUBLE)
    CopyTensorToImageRegion<TImage, double>       (tensor, bufferRegion, outputPtr, region, channelOffset);
  else if (dt == tensorflow::DT_INT64)
    CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset);
  else if (dt == tensorflow::DT_INT32)
    CopyTensorToImageRegion<TImage, int>          (tensor, bufferRegion, outputPtr, region, channelOffset);
  else
    itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented!");
}

//
// Compare two strings (case-insensitive)
//
bool iequals(const std::string& a, const std::string& b)
{
  return std::equal(a.begin(), a.end(), b.begin(), b.end(),
      [](char cha, char chb) {
        return tolower(cha) == tolower(chb);
      });
}

// Convert an expression into a dict
//
// The following types are supported:
// -bool
// -int
// -float
//
// e.g. is_training=true, dropout=0.2, nfeat=14
std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression)
{
  std::pair<std::string, tensorflow::Tensor> dict;

  std::size_t found = expression.find("=");
  if (found != std::string::npos)
  {
    // Find name and value
    std::string name = expression.substr(0, found);
    std::string value = expression.substr(found + 1);

    dict.first = name;

    // Find type
    std::size_t found_dot = value.find(".") != std::string::npos;
    std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos;
    if (is_digit)
    {
      if (found_dot)
      {
        // FLOAT
        try
        {
          float val = std::stof(value);
          tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape());
          out.scalar<float>()() = val;
          dict.second = out;
        }
        catch(...)
        {
          itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as float");
        }
      }
      else
      {
        // INT
        try
        {
          int val = std::stoi(value);
          tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape());
          out.scalar<int>()() = val;
          dict.second = out;
        }
        catch(...)
        {
          itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as int");
        }
      }
    }
    else
    {
      // BOOL
      bool val = true;
      if (iequals(value, "true"))
      {
        val = true;
      }
      else if (iequals(value, "false"))
      {
        val = false;
      }
      else
      {
        itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as bool");
      }

      tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape());
      out.scalar<bool>()() = val;
      dict.second = out;
    }
  }
  else
  {
    itkGenericExceptionMacro("The following expression is not valid: "
        << "\n\t" << expression
        << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true.");
  }

  return dict;
}

} // end namespace tf
} // end namespace otb
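
// Illustrative usage sketch (assumption, not part of the original file): parse a
// "name=value" expression into a (name, scalar tensor) pair. The expressions below are
// hypothetical examples of the three supported types.
//
//   auto training = otb::tf::ExpressionToTensor("is_training=false"); // DT_BOOL scalar
//   auto dropout  = otb::tf::ExpressionToTensor("dropout=0.2");       // DT_FLOAT scalar
//   auto nfeat    = otb::tf::ExpressionToTensor("nfeat=14");          // DT_INT32 scalar
//   // training.first == "is_training", training.second.scalar<bool>()() == false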