An error occurred while loading the file. Please try again.
-
Guillaume Perréal authored
Permet de tout représenter sans afficher les boxes des namespaces.
61295671
/*=========================================================================
Copyright (c) 2018-2019 IRSTEA
Copyright (c) 2020-2021 INRAE
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowSampler_txx
#define otbTensorflowSampler_txx
#include "otbTensorflowSampler.h"
namespace otb
{
/**
 * Constructor: both sample counters start at zero.
 */
template <class TInputImage, class TVectorData>
TensorflowSampler<TInputImage, TVectorData>::TensorflowSampler()
{
  // Nothing has been sampled yet, so nothing is accepted nor rejected
  this->m_NumberOfRejectedSamples = 0;
  this->m_NumberOfAcceptedSamples = 0;
}
/**
 * Append an input image together with its patch size and a no-data value.
 *
 * \param input     image appended to the process object inputs
 * \param patchSize patch size appended to m_PatchSizes
 * \param nodataval value stored in m_NoDataValues, keyed by the position of
 *                  the patch size that has just been appended
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::PushBackInputWithPatchSize(const ImageType * input,
                                                                        SizeType &        patchSize,
                                                                        InternalPixelType nodataval)
{
  // The ProcessObject API takes a non-const pointer
  ImageType * mutableInput = const_cast<ImageType *>(input);
  this->ProcessObject::PushBackInput(mutableInput);

  m_PatchSizes.push_back(patchSize);

  // Key of the no-data value = position of the last stored patch size
  const unsigned int lastPosition = m_PatchSizes.size() - 1;
  m_NoDataValues[lastPosition] = nodataval;
}
/**
 * Append an input image together with its patch size.
 * No no-data value is recorded for this input.
 *
 * \param input     image appended to the process object inputs
 * \param patchSize patch size appended to m_PatchSizes
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::PushBackInputWithPatchSize(const ImageType * input,
                                                                        SizeType &        patchSize)
{
  // The ProcessObject API takes a non-const pointer
  ImageType * mutableInput = const_cast<ImageType *>(input);
  this->ProcessObject::PushBackInput(mutableInput);

  m_PatchSizes.push_back(patchSize);
}
/**
 * Return the input stored at the given position, cast to the image type.
 * Raises an exception when no input at all has been set.
 *
 * \param index position of the requested input
 */
template <class TInputImage, class TVectorData>
const TInputImage *
TensorflowSampler<TInputImage, TVectorData>::GetInput(unsigned int index)
{
  const bool hasNoInput = (this->GetNumberOfInputs() < 1);
  if (hasNoInput)
  {
    itkExceptionMacro("Input not set");
  }

  const ImageType * typedInput = static_cast<const ImageType *>(this->ProcessObject::GetInput(index));
  return typedInput;
}
/**
 * Resize an image given a patch size and a number of samples.
 *
 * The image is replaced by the output of an extract-ROI filter whose
 * extraction region has a size of patchSize[0] along dimension 0 and
 * patchSize[1] * nbSamples along dimension 1.
 *
 * \param image     image to resize; reassigned to the filter output
 * \param patchSize size of one patch
 * \param nbSamples number of patches to keep
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::ResizeImage(ImagePointerType & image,
                                                         SizeType &         patchSize,
                                                         unsigned int       nbSamples)
{
  // New image region
  RegionType region;
  region.SetSize(0, patchSize[0]);
  region.SetSize(1, patchSize[1] * nbSamples);

  // Resize
  ExtractROIMultiFilterPointerType resizer = ExtractROIMultiFilterType::New();
  resizer->SetInput(image);
  resizer->SetExtractionRegion(region);
  resizer->Update();

  // Assign
  image = resizer->GetOutput();
}
/**
 * Create and allocate a new image able to hold a given number of patches.
 *
 * The region of the new image has a size of patchSize[0] along dimension 0
 * and patchSize[1] * nbSamples along dimension 1.
 *
 * \param image        receives the newly created image
 * \param patchSize    size of one patch
 * \param nbSamples    number of patches the image must hold
 * \param nbComponents number of components per pixel
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::AllocateImage(ImagePointerType & image,
                                                           SizeType &         patchSize,
                                                           unsigned int       nbSamples,
                                                           unsigned int       nbComponents)
{
  // Region covering all the patches
  RegionType wholeRegion;
  wholeRegion.SetSize(0, patchSize[0]);
  wholeRegion.SetSize(1, patchSize[1] * nbSamples);

  // Create the image, describe its layout, then allocate its buffer
  image = ImageType::New();
  image->SetNumberOfComponentsPerPixel(nbComponents);
  image->SetRegions(wholeRegion);
  image->Allocate();
}
/**
 * Do the work.
 *
 * 1. Checks that there are as many inputs as patch sizes.
 * 2. Counts the geometries of the input vector data that carry the field
 *    m_Field (a warning is emitted for each geometry that does not), and
 *    raises an exception when there is none.
 * 3. Allocates the label image and one patch image per input, each sized
 *    for the number of geometries counted.
 * 4. For each geometry carrying m_Field, tries to sample every input at the
 *    geometry point. The sample is rejected when the sampling of one input
 *    fails, or when a no-data value is registered for an input and one band
 *    of one pixel of its sampled patch equals that value.
 * 5. Resizes the output images to the number of accepted samples and stores
 *    the accepted/rejected counters.
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>::Update()
{
  // Check number of inputs
  if (this->GetNumberOfInputs() != m_PatchSizes.size())
  {
    itkExceptionMacro("Number of inputs and patches sizes are not the same");
  }

  // Count points: only geometries carrying the label field are counted
  unsigned int     nTotal = 0;
  unsigned int     geomId = 0;
  TreeIteratorType itVector(m_InputVectorData->GetDataTree());
  itVector.GoToBegin();
  while (!itVector.IsAtEnd())
  {
    if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder())
    {
      const DataNodePointer currentGeometry = itVector.Get();
      if (!currentGeometry->HasField(m_Field))
      {
        itkWarningMacro("Field \"" << m_Field << "\" not found in geometry #" << geomId);
      }
      else
      {
        nTotal++;
      }
      geomId++;
    }
    ++itVector;
  } // next feature

  // Check number
  if (nTotal == 0)
  {
    itkExceptionMacro("There is no geometry to sample. Geometries must be points.");
  }

  // Allocate label image (one 1x1 single-component patch per counted geometry)
  SizeType labelPatchSize;
  labelPatchSize.Fill(1);
  AllocateImage(m_OutputLabelImage, labelPatchSize, nTotal, 1);

  // Allocate patches image
  const unsigned int nbInputs = this->GetNumberOfInputs();
  m_OutputPatchImages.clear();
  m_OutputPatchImages.reserve(nbInputs);
  for (unsigned int i = 0; i < nbInputs; i++)
  {
    ImagePointerType newImage;
    AllocateImage(newImage, m_PatchSizes[i], nTotal, GetInput(i)->GetNumberOfComponentsPerPixel());
    newImage->SetSignedSpacing(this->GetInput(i)->GetSignedSpacing());
    m_OutputPatchImages.push_back(newImage);
  }

  itk::ProgressReporter progress(this, 0, nTotal);

  // Iterate on the vector data
  itVector.GoToBegin();
  unsigned long count = 0;
  unsigned long rejected = 0;
  IndexType     labelIndex;
  labelIndex[0] = 0;
  PixelType labelPix;
  labelPix.SetSize(1);
  while (!itVector.IsAtEnd())
  {
    if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder())
    {
      DataNodePointer currentGeometry = itVector.Get();

      // Only the geometries carrying the label field are sampled: these are
      // the ones counted above, for which the output images were allocated.
      if (currentGeometry->HasField(m_Field))
      {
        PointType point = currentGeometry->GetPoint();

        // Get the label value
        labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field));

        bool hasBeenSampled = true;
        for (unsigned int i = 0; i < nbInputs; i++)
        {
          // Get input
          ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i));

          // Try to sample the image
          // NOTE(review): presumably writes the patch at position `count` of the output patch image - confirm
          if (!tf::SampleImage<ImageType>(inputPtr, m_OutputPatchImages[i], point, count, m_PatchSizes[i]))
          {
            // If not, reject this sample
            hasBeenSampled = false;
          }

          // If NoData is provided, check if the sampled patch contains a no-data value
          if (m_NoDataValues.count(i) > 0 && hasBeenSampled)
          {
            // Region of the output patch image holding the patch number `count`
            IndexType outIndex;
            outIndex[0] = 0;
            outIndex[1] = count * m_PatchSizes[i][1];
            RegionType region(outIndex, m_PatchSizes[i]);

            IteratorType it(m_OutputPatchImages[i], region);
            for (it.GoToBegin(); !it.IsAtEnd(); ++it)
            {
              PixelType pix = it.Get();
              for (unsigned int band = 0; band < pix.Size(); band++)
              {
                if (pix[band] == m_NoDataValues[i])
                {
                  hasBeenSampled = false;
                }
              }
            }
          }
        } // Next input

        if (hasBeenSampled)
        {
          // Fill label
          labelIndex[1] = count;
          m_OutputLabelImage->SetPixel(labelIndex, labelPix);

          // update count
          count++;
        }
        else
        {
          rejected++;
        }

        // Update progress
        progress.CompletedPixel();
      }
    }
    ++itVector;
  } // next feature

  // Resize output images
  ResizeImage(m_OutputLabelImage, labelPatchSize, count);
  for (unsigned int i = 0; i < nbInputs; i++)
  {
    ResizeImage(m_OutputPatchImages[i], m_PatchSizes[i], count);
  }

  // Update number of samples produced
  m_NumberOfAcceptedSamples = count;
  m_NumberOfRejectedSamples = rejected;
}
} // end namespace otb
#endif