// Cresson Remi authored c7b82461
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowCopyUtils.h"
namespace otb {
namespace tf {
//
// Display a TensorShape
//
// Returns the dimension sizes of 'shp' as a brace-enclosed, comma-separated
// list, e.g. "{1, 16, 16, 4}".
// A shape without any dimension (scalar) is printed as "{}": dim_size() is
// only called for dimensions that actually exist.
//
std::string PrintTensorShape(const tensorflow::TensorShape & shp)
{
  std::stringstream s;
  const int nDims = shp.dims();
  s << "{";
  for (int d = 0 ; d < nDims ; d++)
  {
    // Separator goes before every dimension except the first one
    if (d > 0)
      s << ", ";
    s << shp.dim_size(d);
  }
  s << "}";
  return s.str();
}
//
// Display infos about a tensor
//
// Returns a one-line description made of the tensor shape (formatted by
// PrintTensorShape) followed by the tensor data type.
//
std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
{
  std::stringstream description;
  description << "Tensor "
              << "shape is " << PrintTensorShape(tensor.shape())
              << " data type is " << tensor.dtype();
  return description.str();
}
//
// Create a tensor with the right datatype
//
// The tensor has the given shape, and its data type is the one returned by
// GetTensorflowDataType<>() for the internal pixel type of TImage.
//
template<class TImage>
tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
{
  typedef typename TImage::InternalPixelType ValueType;
  return tensorflow::Tensor(GetTensorflowDataType<ValueType>(), shape);
}
//
// Populate a tensor with the buffered region of a vector image using std::copy
// Warning: tensor datatype must be consistent with the image value type
//
template<class TImage>
void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
{
size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
std::copy_n(bufferedimagePtr->GetBufferPointer(),
n_elem,
out_tensor.flat<typename TImage::InternalPixelType>().data());
}
//
// Recopy a VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands})
//
// Each pixel of 'region' is written at tensor(elemIdx, y, x, band), where
// (x, y) is the pixel index relative to the start index of the region and
// the value is converted to TValueType by the assignment.
// This function itself does not check that the region fits into the tensor.
//
template<class TImage, class TValueType=typename TImage::InternalPixelType>
void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
                               tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
{
  itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region);
  const unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel();
  auto tMap = tensor.tensor<TValueType, 4>();

  // The region start does not change during the iteration: read it once
  const typename TImage::IndexType regionStart = region.GetIndex();

  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
  {
    // Fetch the index and the pixel once per pixel (not once per band)
    const typename TImage::IndexType pixelIndex = inIt.GetIndex();
    const int y = pixelIndex[1] - regionStart[1];
    const int x = pixelIndex[0] - regionStart[0];
    const auto & pixel = inIt.Get();

    for (unsigned int band = 0 ; band < nBands ; band++)
      tMap(elemIdx, y, x, band) = pixel[band];
  }
}
//
// Type-agnostic version of the 'RecopyImageRegionToTensor' function
// TODO: add some numeric types
//
template<class TImage>
void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
{
tensorflow::DataType dt = tensor.dtype();
if (dt == tensorflow::DT_FLOAT)
RecopyImageRegionToTensor<TImage, float>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_DOUBLE)
RecopyImageRegionToTensor<TImage, double>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_UINT64)
RecopyImageRegionToTensor<TImage, unsigned long long int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_INT64)
RecopyImageRegionToTensor<TImage, long long int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_UINT32)
RecopyImageRegionToTensor<TImage, unsigned int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_INT32)
RecopyImageRegionToTensor<TImage, int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_UINT16)
RecopyImageRegionToTensor<TImage, unsigned short int> (inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_INT16)
RecopyImageRegionToTensor<TImage, short int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_UINT8)
RecopyImageRegionToTensor<TImage, unsigned char> (inputPtr, region, tensor, elemIdx);
else
itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !");
}
//
// Sample a centered patch (from index)
//
// The patch region starts at 'centerIndex' minus half of 'patchSize'
// (integer division) along the first two dimensions, and has size
// 'patchSize'. Its content is copied into 'tensor' at position 'elemIdx'
// by RecopyImageRegionToTensorWithCast.
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize,
                         tensorflow::Tensor & tensor, unsigned int elemIdx)
{
  typename TImage::IndexType patchStart;
  for (unsigned int dim = 0 ; dim < 2 ; dim++)
    patchStart[dim] = centerIndex[dim] - patchSize[dim] / 2;

  typename TImage::RegionType patchRegion(patchStart, patchSize);
  RecopyImageRegionToTensorWithCast<TImage>(inputPtr, patchRegion, tensor, elemIdx);
}
//
// Sample a centered patch (from coordinates)
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize,
tensorflow::Tensor & tensor, unsigned int elemIdx)
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
{
// Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands}
// Get the index of the center
typename TImage::IndexType centerIndex;
inputPtr->TransformPhysicalPointToIndex(centerCoord, centerIndex);
SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx);
}
// Return the number of channels that the output tensor will occupy in the output image
//
// shape {}           --> 1 (a scalar)
// shape {n}          --> 1 (e.g. a label)
// shape {n, c}       --> c (e.g. a vector)
// shape {x, y, c}    --> c (e.g. a patch)
// shape {n, x, y, c} --> c (e.g. some patches)
//
// For tensors with at least two dimensions, this is the size of the last
// dimension. Tensors with zero or one dimension count as a single channel,
// so dim_size() is never called with a negative dimension index.
//
tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor)
{
  // Bind by reference: no need to copy the shape to read it
  const tensorflow::TensorShape & shape = tensor.shape();
  const int nDims = shape.dims();
  if (nDims <= 1)
    return 1;
  return shape.dim_size(nDims - 1);
}
//
// Copy a tensor into the image region
// TODO: Enable to change mapping from source tensor to image to make it more generic
//
// Right now, only the following output tensor shapes can be processed:
// shape {n} --> 1 (e.g. a label)
// shape {n, c} --> c (e.g. a vector)
// shape {x, y, c} --> c (e.g. a multichannel image)
//
template<class TImage, class TValueType>
void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion,
typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset)
{
// Flatten the tensor
auto tFlat = tensor.flat<TValueType>();
// Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)')
const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor);
// Number of columns (size x of the buffer)
const tensorflow::int64 nCols = bufferRegion.GetSize(0);
// Check the tensor size vs the outputRegion size
const tensorflow::int64 nElmT = tensor.NumElements();
const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C;
if (nElmI != nElmT)
{
itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT <<
" but image outputRegion has " << nElmI <<
" values to fill.\nBuffer region:\n" << bufferRegion <<
"\nNumber of components: " << outputDimSize_C <<
"\nTensor shape:\n " << PrintTensorShape(tensor.shape()) <<
"\nPlease check the input(s) field of view (FOV), " <<
"the output field of expression (FOE), and the " <<
"output spacing scale if you run the model in fully " <<
"convolutional mode (how many strides in your model?)");
}
// Iterate over the image
typename itk::ImageRegionIterator<TImage> outIt(outputPtr, outputRegion);
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{
const int x = outIt.GetIndex()[0] - bufferRegion.GetIndex(0);
const int y = outIt.GetIndex()[1] - bufferRegion.GetIndex(1);