/*=========================================================================

  Copyright (c) Remi Cresson (IRSTEA). All rights reserved.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include "otbTensorflowCopyUtils.h"

namespace otb {
namespace tf {

//
// Display a TensorShape
//
// Returns the dimensions formatted as "{d0, d1, ...}".
// A rank-0 (scalar) shape is printed as "{}": it has no dimension 0 to query.
//
std::string PrintTensorShape(const tensorflow::TensorShape & shp)
{
  std::stringstream s;
  const int nDims = shp.dims();
  s << "{";
  for (int d = 0 ; d < nDims ; d++)
  {
    // Separator goes before every dimension but the first one
    if (d > 0)
      s << ", ";
    s << shp.dim_size(d);
  }
  s << "}" ;
  return s.str();
}

//
// Display infos about a tensor
//
// Returns a one-line description made of the tensor shape (see 'PrintTensorShape(...)')
// followed by its data type, as streamed by tensor.dtype().
//
std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
{
  std::stringstream s;
  s << "Tensor ";
  // Show dims
  s << "shape is " << PrintTensorShape(tensor.shape());
  // Data type
  s << " data type is " << tensor.dtype();
  return s.str();
}

//
// Create a tensor with the good datatype
//
// The data type is the one returned by GetTensorflowDataType<...>() for the
// internal pixel type of TImage; the tensor is built with the given shape.
//
template<class TImage>
tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
{
  tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>();
  tensorflow::Tensor out_tensor(ts_dt, shape);
  return out_tensor;
}

//
// Populate a tensor with the buffered region of a vector image using std::copy
// Warning: tensor datatype must be consistent with the image value type
//
// The number of copied values is (components per pixel) x (pixels of the buffered region).
// An exception is raised when the tensor holds fewer elements than that, since the
// copy would otherwise write past the end of the tensor buffer.
//
template<class TImage>
void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
{
  const size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
      bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();

  // Guard against overrunning the destination buffer
  const size_t n_tensor = static_cast<size_t>(out_tensor.NumElements());
  if (n_elem > n_tensor)
  {
    itkGenericExceptionMacro("Cannot copy " << n_elem << " image values into a tensor holding only "
        << n_tensor << " elements. Tensor shape: " << PrintTensorShape(out_tensor.shape()));
  }

  std::copy_n(bufferedimagePtr->GetBufferPointer(),
      n_elem,
      out_tensor.flat<typename TImage::InternalPixelType>().data());
}

//
// Recopy a VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands})
//
template<class TImage, class TValueType=typename TImage::InternalPixelType> void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension { typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region); unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); auto tMap = tensor.tensor<TValueType, 4>(); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { const int y = inIt.GetIndex()[1] - region.GetIndex()[1]; const int x = inIt.GetIndex()[0] - region.GetIndex()[0]; for (unsigned int band = 0 ; band < nBands ; band++) tMap(elemIdx, y, x, band) = inIt.Get()[band]; } } // // Type-agnostic version of the 'RecopyImageRegionToTensor' function // TODO: add some numeric types // template<class TImage> void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) RecopyImageRegionToTensor<TImage, float> (inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_DOUBLE) RecopyImageRegionToTensor<TImage, double> (inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT64) RecopyImageRegionToTensor<TImage, long long int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT32) RecopyImageRegionToTensor<TImage, int> (inputPtr, region, tensor, elemIdx); else itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); } // // Sample a centered patch (from index) // template<class TImage> void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx) { typename TImage::IndexType regionStart; 
regionStart[0] = centerIndex[0] - patchSize[0] / 2; regionStart[1] = centerIndex[1] - patchSize[1] / 2; typename TImage::RegionType patchRegion(regionStart, patchSize); RecopyImageRegionToTensorWithCast<TImage>(inputPtr, patchRegion, tensor, elemIdx); } // // Sample a centered patch (from coordinates) // template<class TImage> void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, tensorflow::Tensor & tensor, unsigned int elemIdx) { // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands} // Get the index of the center typename TImage::IndexType centerIndex; inputPtr->TransformPhysicalPointToIndex(centerCoord, centerIndex); SampleCenteredPatch<TImage>(inputPtr, centerIndex, patchSize, tensor, elemIdx); } // Return the number of channels that the output tensor will occupy in the output image // // shape {n} --> 1 (e.g. a label) // shape {n, c} --> c (e.g. a vector) // shape {x, y, c} --> c (e.g. a patch) // shape {n, x, y, c} --> c (e.g. some patches) // tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) { const tensorflow::TensorShape shape = tensor.shape(); const int nDims = shape.dims(); if (nDims == 1) return 1; return shape.dim_size(nDims - 1); } // // Copy a tensor into the image region // TODO: Enable to change mapping from source tensor to image to make it more generic // // Right now, only the following output tensor shapes can be processed: // shape {n} --> 1 (e.g. a label) // shape {n, c} --> c (e.g. a vector) // shape {x, y, c} --> c (e.g. 
a multichannel image) // template<class TImage, class TValueType> void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset) { // Flatten the tensor auto tFlat = tensor.flat<TValueType>(); // Get the size of the last component of the tensor (see 'GetNumberOfChannelsForOutputTensor(...)') const tensorflow::int64 outputDimSize_C = GetNumberOfChannelsForOutputTensor(tensor); // Number of columns (size x of the buffer) const tensorflow::int64 nCols = bufferRegion.GetSize(0); // Check the tensor size vs the outputRegion size const tensorflow::int64 nElmT = tensor.NumElements(); const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C; if (nElmI != nElmT) { itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT << " but image outputRegion has " << nElmI << " values to fill.\nBuffer region:\n" << bufferRegion << "\nNumber of components: " << outputDimSize_C << "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) << "\nPlease check the input(s) field of view (FOV), " << "the output field of expression (FOE), and the " << "output spacing scale if you run the model in fully " << "convolutional mode (how many strides in your model?)"); } // Iterate over the image typename itk::ImageRegionIterator<TImage> outIt(outputPtr, outputRegion); for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt) { const int x = outIt.GetIndex()[0] - bufferRegion.GetIndex(0); const int y = outIt.GetIndex()[1] - bufferRegion.GetIndex(1); // TODO: it could be useful to change the tensor-->image mapping here. 
// e.g use a lambda for "pos" calculation const int pos = outputDimSize_C * (y * nCols + x); for (unsigned int c = 0 ; c < outputDimSize_C ; c++) outIt.Get()[channelOffset + c] = tFlat( pos + c); } // Update the offset channelOffset += outputDimSize_C; } // // Type-agnostic version of the 'CopyTensorToImageRegion' function // TODO: add some numeric types // template<class TImage> void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset) { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) CopyTensorToImageRegion<TImage, float> (tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_DOUBLE) CopyTensorToImageRegion<TImage, double> (tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT64) CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT32) CopyTensorToImageRegion<TImage, int> (tensor, bufferRegion, outputPtr, region, channelOffset); else itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); } // // Compare two string lowercase // bool iequals(const std::string& a, const std::string& b) { return std::equal(a.begin(), a.end(), b.begin(), b.end(), [](char cha, char chb) { return tolower(cha) == tolower(chb); }); } // Convert an expression into a dict // // Following types are supported: // -bool // -int // -float // // e.g. 
is_training=true, droptout=0.2, nfeat=14 std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression) { std::pair<std::string, tensorflow::Tensor> dict; std::size_t found = expression.find("="); if (found != std::string::npos) { // Find name and value std::string name = expression.substr(0, found); std::string value = expression.substr(found+1); dict.first = name; // Find type std::size_t found_dot = value.find(".") != std::string::npos; std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos; if (is_digit) { if (found_dot) { // FLOAT try { float val = std::stof(value); tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape()); out.scalar<float>()() = val; dict.second = out; } catch(...) { itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as float"); } } else { // INT try { int val = std::stoi(value); tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape()); out.scalar<int>()() = val; dict.second = out; } catch(...) { itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as int"); } } } else { // BOOL bool val = true; if (iequals(value, "true")) { val = true; } else if (iequals(value, "false")) { val = false; } else { itkGenericExceptionMacro("Error parsing name=" << name << " with value=" << value << " as bool"); } tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape()); out.scalar<bool>()() = val; dict.second = out; } } else { itkGenericExceptionMacro("The following expression is not valid: " << "\n\t" << expression << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true."); } return dict; } } // end namespace tf } // end namespace otb