otbTensorflowCopyUtils.cxx 12.21 KiB
/*=========================================================================
  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowCopyUtils.h"
namespace otb {
namespace tf {
// Display a TensorShape
// Output format is "{d0, d1, ..., dn}". A rank-0 (scalar) shape has no
// dimensions and is printed as "{}" (dim_size(0) must not be queried then).
std::string PrintTensorShape(const tensorflow::TensorShape & shp)
{
  std::stringstream s;
  const int nDims = shp.dims();
  s << "{";
  for (int d = 0 ; d < nDims ; d++)
  {
    if (d > 0)
      s << ", ";
    s << shp.dim_size(d);
  }
  s << "}";
  return s.str();
}
// Display infos about a tensor
// Builds a one-line summary of the tensor: its shape (as formatted by
// PrintTensorShape) followed by its data type (streamed as tensor.dtype()).
std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
{
  std::stringstream infos;
  infos << "Tensor "
        << "shape is " << PrintTensorShape(tensor.shape())
        << " data type is " << tensor.dtype();
  return infos.str();
}
// Create a tensor with the good datatype
// The tensor data type is derived from the image's internal pixel type;
// the tensor's storage is allocated for 'shape' but its values are not set here.
template<class TImage>
tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
{
  return tensorflow::Tensor(GetTensorflowDataType<typename TImage::InternalPixelType>(), shape);
}
// Populate a tensor with the buffered region of a vector image using std::copy
// Warning: tensor datatype must be consistent with the image value type
template<class TImage>
void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
  size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
      bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
  std::copy_n(bufferedimagePtr->GetBufferPointer(),
      n_elem,
      out_tensor.flat<typename TImage::InternalPixelType>().data());
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
// Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands}) // template<class TImage, class TValueType=typename TImage::InternalPixelType> void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension { typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region); unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); auto tMap = tensor.tensor<TValueType, 4>(); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { const int y = inIt.GetIndex()[1] - region.GetIndex()[1]; const int x = inIt.GetIndex()[0] - region.GetIndex()[0]; for (unsigned int band = 0 ; band < nBands ; band++) tMap(elemIdx, y, x, band) = inIt.Get()[band]; } } // // Type-agnostic version of the 'RecopyImageRegionToTensor' function // TODO: add some numeric types // template<class TImage> void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) RecopyImageRegionToTensor<TImage, float>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_DOUBLE) RecopyImageRegionToTensor<TImage, double>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT64) RecopyImageRegionToTensor<TImage, unsigned long long int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT64) RecopyImageRegionToTensor<TImage, long long int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT32) RecopyImageRegionToTensor<TImage, unsigned int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT32) RecopyImageRegionToTensor<TImage, int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT16) 
RecopyImageRegionToTensor<TImage, unsigned short int> (inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT16) RecopyImageRegionToTensor<TImage, short int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT8) RecopyImageRegionToTensor<TImage, unsigned char> (inputPtr, region, tensor, elemIdx); else itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); } // // Sample a centered patch (from index) // template<class TImage> void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx) { typename TImage::IndexType regionStart; regionStart[0] = centerIndex[0] - patchSize[0] / 2; regionStart[1] = centerIndex[1] - patchSize[1] / 2; typename TImage::RegionType patchRegion(regionStart, patchSize); RecopyImageRegionToTensorWithCast<TImage>(inputPtr, patchRegion, tensor, elemIdx); } // // Sample a centered patch (from coordinates) // template<class TImage> void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx)
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
{ // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands} // Get the index of the center typename TImage::IndexType centerIndex; inputPtr->TransformPhysicalPointToIndex(centerCoord, centerIndex); SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx); } // Return the number of channels that the output tensor will occupy in the output image // // shape {n} --> 1 (e.g. a label) // shape {n, c} --> c (e.g. a vector) // shape {x, y, c} --> c (e.g. a patch) // shape {n, x, y, c} --> c (e.g. some patches) // tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) { const tensorflow::TensorShape shape = tensor.shape(); const int nDims = shape.dims(); if (nDims == 1) return 1; return shape.dim_size(nDims - 1); } // // Copy a tensor into the image region // TODO: Enable to change mapping from source tensor to image to make it more generic // // Right now, only the following output tensor shapes can be processed: // shape {n} --> 1 (e.g. a label) // shape {n, c} --> c (e.g. a vector) // shape {x, y, c} --> c (e.g. 
a multichannel image) // template<class TImage, class TValueType> void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset) { // Flatten the tensor auto tFlat = tensor.flat<TValueType>(); // Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)') const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor); // Number of columns (size x of the buffer) const tensorflow::int64 nCols = bufferRegion.GetSize(0); // Check the tensor size vs the outputRegion size const tensorflow::int64 nElmT = tensor.NumElements(); const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C; if (nElmI != nElmT) { itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT << " but image outputRegion has " << nElmI << " values to fill.\nBuffer region:\n" << bufferRegion << "\nNumber of components: " << outputDimSize_C << "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) << "\nPlease check the input(s) field of view (FOV), " << "the output field of expression (FOE), and the " << "output spacing scale if you run the model in fully " << "convolutional mode (how many strides in your model?)"); } // Iterate over the image typename itk::ImageRegionIterator<TImage> outIt(outputPtr, outputRegion); for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt) { const int x = outIt.GetIndex()[0] - bufferRegion.GetIndex(0); const int y = outIt.GetIndex()[1] - bufferRegion.GetIndex(1);