// otbTensorflowSampler.hxx
/*=========================================================================
     Copyright (c) 2018-2019 IRSTEA
     Copyright (c) 2020-2021 INRAE
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowSampler_txx
#define otbTensorflowSampler_txx
#include "otbTensorflowSampler.h"
namespace otb
template <class TInputImage, class TVectorData>
TensorflowSampler<TInputImage, TVectorData>::TensorflowSampler()
  m_NumberOfAcceptedSamples = 0;
  m_NumberOfRejectedSamples = 0;
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::PushBackInputWithPatchSize(const ImageType * input,
                                                                        SizeType &        patchSize,
                                                                        InternalPixelType nodataval)
  this->ProcessObject::PushBackInput(const_cast<ImageType *>(input));
  m_PatchSizes.push_back(patchSize);
  unsigned int index = m_PatchSizes.size() -1 ;
  m_NoDataValues[index] = nodataval;
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::PushBackInputWithPatchSize(const ImageType * input,
                                                                        SizeType &        patchSize)
  this->ProcessObject::PushBackInput(const_cast<ImageType *>(input));
  m_PatchSizes.push_back(patchSize);
template <class TInputImage, class TVectorData>
const TInputImage *
TensorflowSampler<TInputImage, TVectorData>::GetInput(unsigned int index)
  if (this->GetNumberOfInputs() < 1)
    itkExceptionMacro("Input not set");
  return static_cast<const ImageType *>(this->ProcessObject::GetInput(index));
/**
 * Resize an image given a patch size and a number of samples
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::ResizeImage(ImagePointerType & image,
                                                         SizeType &         patchSize,
                                                         unsigned int       nbSamples)
  // New image region
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
RegionType region; region.SetSize(0, patchSize[0]); region.SetSize(1, patchSize[1] * nbSamples); // Resize ExtractROIMultiFilterPointerType resizer = ExtractROIMultiFilterType::New(); resizer->SetInput(image); resizer->SetExtractionRegion(region); resizer->Update(); // Assign image = resizer->GetOutput(); } /** * Allocate an image given a patch size and a number of samples */ template <class TInputImage, class TVectorData> void TensorflowSampler<TInputImage, TVectorData>::AllocateImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents) { // Image region RegionType region; region.SetSize(0, patchSize[0]); region.SetSize(1, patchSize[1] * nbSamples); // Allocate label image image = ImageType::New(); image->SetNumberOfComponentsPerPixel(nbComponents); image->SetRegions(region); image->Allocate(); } /** * Do the work */ template <class TInputImage, class TVectorData> void TensorflowSampler<TInputImage, TVectorData>::Update() { // Check number of inputs if (this->GetNumberOfInputs() != m_PatchSizes.size()) { itkExceptionMacro("Number of inputs and patches sizes are not the same"); } // Count points unsigned int nTotal = 0; unsigned int geomId = 0; TreeIteratorType itVector(m_InputVectorData->GetDataTree()); itVector.GoToBegin(); while (!itVector.IsAtEnd()) { if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder()) { const DataNodePointer currentGeometry = itVector.Get(); if (!currentGeometry->HasField(m_Field)) { itkWarningMacro("Field \"" << m_Field << "\" not found in geometry #" << geomId); } else { nTotal++; } geomId++; }
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
++itVector; } // next feature // Check number if (nTotal == 0) { itkExceptionMacro("There is no geometry to sample. Geometries must be points.") } // Allocate label image SizeType labelPatchSize; labelPatchSize.Fill(1); AllocateImage(m_OutputLabelImage, labelPatchSize, nTotal, 1); // Allocate patches image const unsigned int nbInputs = this->GetNumberOfInputs(); m_OutputPatchImages.clear(); m_OutputPatchImages.reserve(nbInputs); for (unsigned int i = 0; i < nbInputs; i++) { ImagePointerType newImage; AllocateImage(newImage, m_PatchSizes[i], nTotal, GetInput(i)->GetNumberOfComponentsPerPixel()); newImage->SetSignedSpacing(this->GetInput(i)->GetSignedSpacing()); m_OutputPatchImages.push_back(newImage); } itk::ProgressReporter progress(this, 0, nTotal); // Iterate on the vector data itVector.GoToBegin(); unsigned long count = 0; unsigned long rejected = 0; IndexType labelIndex; labelIndex[0] = 0; PixelType labelPix; labelPix.SetSize(1); while (!itVector.IsAtEnd()) { if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder()) { DataNodePointer currentGeometry = itVector.Get(); if (!currentGeometry->HasField(m_Field)) { PointType point = currentGeometry->GetPoint(); // Get the label value labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field)); bool hasBeenSampled = true; for (unsigned int i = 0; i < nbInputs; i++) { // Get input ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); // Try to sample the image if (!tf::SampleImage<ImageType>(inputPtr, m_OutputPatchImages[i], point, count, m_PatchSizes[i])) { // If not, reject this sample hasBeenSampled = false; } // If NoData is provided, check if the sampled patch contains a no-data value if (m_NoDataValues.count(i) > 0 && hasBeenSampled) { IndexType outIndex; outIndex[0] = 0; outIndex[1] = count * m_PatchSizes[i][1]; RegionType region(outIndex, m_PatchSizes[i]); IteratorType it(m_OutputPatchImages[i], region); for (it.GoToBegin(); 
!it.IsAtEnd(); ++it)
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
{ PixelType pix = it.Get(); for (unsigned int band = 0; band < pix.Size(); band++) if (pix[band] == m_NoDataValues[i]) hasBeenSampled = false; } } } // Next input if (hasBeenSampled) { // Fill label labelIndex[1] = count; m_OutputLabelImage->SetPixel(labelIndex, labelPix); // update count count++; } else { rejected++; } // Update progress progress.CompletedPixel(); } } ++itVector; } // next feature // Resize output images ResizeImage(m_OutputLabelImage, labelPatchSize, count); for (unsigned int i = 0; i < nbInputs; i++) { ResizeImage(m_OutputPatchImages[i], m_PatchSizes[i], count); } // Update number of samples produced m_NumberOfAcceptedSamples = count; m_NumberOfRejectedSamples = rejected; } } // end namespace otb #endif