/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowCopyUtils.h"
namespace otb {
namespace tf {
//
// Display a TensorShape
//
std::string PrintTensorShape(const tensorflow::TensorShape & shp)
{
std::stringstream s;
const unsigned int nDims = shp.dims();
s << "{";
// Guard against rank-0 (scalar) shapes, for which dim_size(0) is not valid
if (nDims > 0)
{
s << shp.dim_size(0);
for (unsigned int d = 1 ; d < nDims ; d++)
s << ", " << shp.dim_size(d);
}
s << "}";
return s.str();
}
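//
// For instance, a tensor shape with dimensions 2, 256, 256 and 4 is printed
// as "{2, 256, 256, 4}".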
//
// Display information about a tensor
//
std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
{
std::stringstream s;
s << "Tensor ";
// Show dims
s << "shape is " << PrintTensorShape(tensor.shape());
// Data type
s << " data type is " << tensor.dtype();
return s.str();
}
//
// Create a tensor with the proper data type (deduced from the image internal pixel type)
//
template<class TImage>
tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
{
tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>();
tensorflow::Tensor out_tensor(ts_dt, shape);
return out_tensor;
}
//
// Populate a tensor with the buffered region of a vector image using std::copy_n
// Warning: the tensor data type must be consistent with the image value type
//
template<class TImage>
void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
{
size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
std::copy_n(bufferedimagePtr->GetBufferPointer(),
n_elem,
out_tensor.flat<typename TImage::InternalPixelType>().data());
}
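//
// Example usage of CreateTensor + PopulateTensorFromBufferedVectorImage
// (a minimal sketch, not part of the original code: it assumes a 4-band
// otb::VectorImage<float, 2> named "img" whose buffered region is 256 x 256
// pixels):
//
//   using ImageType = otb::VectorImage<float, 2>;
//   tensorflow::TensorShape shape({1, 256, 256, 4});
//   tensorflow::Tensor t = CreateTensor<ImageType>(shape);
//   PopulateTensorFromBufferedVectorImage<ImageType>(img, t);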
//
// Recopy a VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands})
//
template<class TImage, class TValueType=typename TImage::InternalPixelType>
void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
{
typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region);
unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel();
auto tMap = tensor.tensor<TValueType, 4>();
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{
const int y = inIt.GetIndex()[1] - region.GetIndex()[1];
const int x = inIt.GetIndex()[0] - region.GetIndex()[0];
for (unsigned int band = 0 ; band < nBands ; band++)
tMap(elemIdx, y, x, band) = inIt.Get()[band];
}
}
//
// Type-agnostic version of the 'RecopyImageRegionToTensor' function
// TODO: add some numeric types
//
template<class TImage>
void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
{
tensorflow::DataType dt = tensor.dtype();
if (dt == tensorflow::DT_FLOAT)
RecopyImageRegionToTensor<TImage, float> (inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_DOUBLE)
RecopyImageRegionToTensor<TImage, double> (inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_INT64)
RecopyImageRegionToTensor<TImage, long long int>(inputPtr, region, tensor, elemIdx);
else if (dt == tensorflow::DT_INT32)
RecopyImageRegionToTensor<TImage, int> (inputPtr, region, tensor, elemIdx);
else
itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !");
}
//
// Sample a centered patch (from index)
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize,
tensorflow::Tensor & tensor, unsigned int elemIdx)
{
typename TImage::IndexType regionStart;
regionStart[0] = centerIndex[0] - patchSize[0] / 2;
regionStart[1] = centerIndex[1] - patchSize[1] / 2;
typename TImage::RegionType patchRegion(regionStart, patchSize);
RecopyImageRegionToTensorWithCast<TImage>(inputPtr, patchRegion, tensor, elemIdx);
}
//
// Sample a centered patch (from coordinates)
//
template<class TImage>
void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize,
tensorflow::Tensor & tensor, unsigned int elemIdx)
{
// Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands}
// Get the index of the center
typename TImage::IndexType centerIndex;
inputPtr->TransformPhysicalPointToIndex(centerCoord, centerIndex);
SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx);
}
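//
// Example usage of SampleCenteredPatch (a sketch, not part of the original
// code: "img", "points" and the 16 x 16 patch size are assumptions). It fills
// one 16 x 16 x 4 patch per point into a single batch tensor:
//
//   using ImageType = otb::VectorImage<float, 2>;
//   // "points" is a std::vector<ImageType::PointType> of patch centers
//   ImageType::SizeType patchSize;
//   patchSize.Fill(16);
//   tensorflow::TensorShape shape({static_cast<long long>(points.size()), 16, 16, 4});
//   tensorflow::Tensor batch = CreateTensor<ImageType>(shape);
//   for (unsigned int i = 0 ; i < points.size() ; i++)
//     SampleCenteredPatch<ImageType>(img, points[i], patchSize, batch, i);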
//
// Return the number of channels that the output tensor will occupy in the output image
//
// shape {n} --> 1 (e.g. a label)
// shape {n, c} --> c (e.g. a vector)
// shape {x, y, c} --> c (e.g. a patch)
// shape {n, x, y, c} --> c (e.g. some patches)
//
tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor)
{
const tensorflow::TensorShape shape = tensor.shape();
const int nDims = shape.dims();
if (nDims == 1)
return 1;
return shape.dim_size(nDims - 1);
}
//
// Copy a tensor into the image region
// TODO: Allow changing the mapping from source tensor to image, to make it more generic
//
// Right now, only the following output tensor shapes can be processed:
// shape {n} --> 1 (e.g. a label)
// shape {n, c} --> c (e.g. a vector)
// shape {x, y, c} --> c (e.g. a multichannel image)
//
template<class TImage, class TValueType>
void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion,
typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset)
{
// Flatten the tensor
auto tFlat = tensor.flat<TValueType>();
// Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)')
const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor);
// Number of columns (size x of the buffer)
const tensorflow::int64 nCols = bufferRegion.GetSize(0);
// Check the tensor size vs the outputRegion size
const tensorflow::int64 nElmT = tensor.NumElements();
const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C;
if (nElmI != nElmT)
{
itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT <<
" but image outputRegion has " << nElmI <<
" values to fill.\nBuffer region:\n" << bufferRegion <<
"\nNumber of components: " << outputDimSize_C <<
"\nTensor shape:\n " << PrintTensorShape(tensor.shape()) <<
"\nPlease check the input(s) field of view (FOV), " <<
"the output field of expression (FOE), and the " <<
"output spacing scale if you run the model in fully " <<
"convolutional mode (how many strides in your model?)");
}
// Iterate over the image
typename itk::ImageRegionIterator<TImage> outIt(outputPtr, outputRegion);
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{
const int x = outIt.GetIndex()[0] - bufferRegion.GetIndex(0);
const int y = outIt.GetIndex()[1] - bufferRegion.GetIndex(1);
// TODO: it could be useful to make the tensor-->image mapping configurable here,
// e.g. use a lambda for the "pos" calculation
const int pos = outputDimSize_C * (y * nCols + x);
for (unsigned int c = 0 ; c < outputDimSize_C ; c++)
outIt.Get()[channelOffset + c] = tFlat(pos + c);
}
// Update the offset
channelOffset += outputDimSize_C;
}
//
// Type-agnostic version of the 'CopyTensorToImageRegion' function
// TODO: add some numeric types
//
template<class TImage>
void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion,
typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset)
{
tensorflow::DataType dt = tensor.dtype();
if (dt == tensorflow::DT_FLOAT)
CopyTensorToImageRegion<TImage, float> (tensor, bufferRegion, outputPtr, region, channelOffset);
else if (dt == tensorflow::DT_DOUBLE)
CopyTensorToImageRegion<TImage, double> (tensor, bufferRegion, outputPtr, region, channelOffset);
else if (dt == tensorflow::DT_INT64)
CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset);
else if (dt == tensorflow::DT_INT32)
CopyTensorToImageRegion<TImage, int> (tensor, bufferRegion, outputPtr, region, channelOffset);
else
itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !");
}
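//
// Example usage of CopyTensorToImageRegion (a sketch, not part of the original
// code: "outTensor", "outImg", "region" and "nClasses" are assumptions). The
// DT_FLOAT tensor "outTensor" is expected to hold one nClasses-long vector per
// pixel of "region", and "outImg" (an ImageType::Pointer) must already be
// allocated with at least nClasses bands:
//
//   int channelOffset = 0;
//   CopyTensorToImageRegion<ImageType>(outTensor, region, outImg, region, channelOffset);
//   // channelOffset now equals nClasses; a second tensor could be written
//   // into the following bands of the same image.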
//
// Compare two strings (case-insensitive)
//
bool iequals(const std::string& a, const std::string& b)
{
return std::equal(a.begin(), a.end(),
b.begin(), b.end(),
[](char cha, char chb) {
return tolower(cha) == tolower(chb);
});
}
//
// Convert an expression into a dict
//
// The following types are supported:
// - bool
// - int
// - float
//
// e.g. is_training=true, dropout=0.2, nfeat=14
//
std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression)
{
std::pair<std::string, tensorflow::Tensor> dict;
std::size_t found = expression.find("=");
if (found != std::string::npos)
{
// Find name and value
std::string name = expression.substr(0, found);
std::string value = expression.substr(found+1);
dict.first = name;
// Find type
const bool found_dot = value.find(".") != std::string::npos;
const bool is_digit = value.find_first_not_of("0123456789.") == std::string::npos;
if (is_digit)
{
if (found_dot)
{
// FLOAT
try
{
float val = std::stof(value);
tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape());
out.scalar<float>()() = val;
dict.second = out;
}
catch(...)
{
itkGenericExceptionMacro("Error parsing name="
<< name << " with value=" << value << " as float");
}
}
else
{
// INT
try
{
int val = std::stoi(value);
tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape());
out.scalar<int>()() = val;
dict.second = out;
}
catch(...)
{
itkGenericExceptionMacro("Error parsing name="
<< name << " with value=" << value << " as int");
}
}
}
else
{
// BOOL
bool val = true;
if (iequals(value, "true"))
{
val = true;
}
else if (iequals(value, "false"))
{
val = false;
}
else
{
itkGenericExceptionMacro("Error parsing name="
<< name << " with value=" << value << " as bool");
}
tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape());
out.scalar<bool>()() = val;
dict.second = out;
}
}
else
{
itkGenericExceptionMacro("The following expression is not valid: "
<< "\n\t" << expression
<< ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true.");
}
return dict;
}
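//
// Example usage of ExpressionToTensor (a sketch, not part of the original code):
//
//   auto dict = ExpressionToTensor("is_training=false");
//   // dict.first  == "is_training"
//   // dict.second is a DT_BOOL scalar tensorflow::Tensor holding "false"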
} // end namespace tf
} // end namespace otb