// remi cresson authored 4bb2789b
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowSampler_txx
#define otbTensorflowSampler_txx
#include "otbTensorflowSampler.h"
namespace otb
{
/**
 * Default constructor. Intentionally empty: no member is set here.
 */
template <class TInputImage, class TVectorData>
TensorflowSampler<TInputImage, TVectorData>
::TensorflowSampler()
{
  // Nothing to do
}
/**
 * Append an input image and the size of the patches to extract from it.
 *
 * The image is appended to the ProcessObject inputs and the patch size to
 * m_PatchSizes, in the same call. As long as inputs are only added through
 * this method, the i-th patch size therefore pairs with the i-th input
 * (Update() relies on this pairing and checks that both counts match).
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>
::PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize)
{
  // ProcessObject::PushBackInput takes a non-const data object
  ImageType * mutableInput = const_cast<ImageType *>(input);
  this->ProcessObject::PushBackInput(mutableInput);

  // Record the patch size at the same position as the input
  m_PatchSizes.push_back(patchSize);
}
/**
 * Return the input image stored at position `index`.
 *
 * @param index position of the input, in the order inputs were pushed.
 * @return the input, cast to the filter's image type.
 * Throws an itk::ExceptionObject when `index` does not designate an
 * existing input. This includes the case where no input has been set.
 */
template <class TInputImage, class TVectorData>
const TInputImage*
TensorflowSampler<TInputImage, TVectorData>
::GetInput(unsigned int index)
{
  // Validate the requested index itself, not only that some input exists:
  // otherwise an out-of-range index would be forwarded to
  // ProcessObject::GetInput and callers would dereference an invalid input.
  if (index >= this->GetNumberOfInputs())
    {
    itkExceptionMacro("Input #" << index << " not set (number of inputs: " << this->GetNumberOfInputs() << ")");
    }
  return static_cast<const ImageType*>(this->ProcessObject::GetInput(index));
}
/**
 * Crop `image` so that it holds exactly `nbSamples` patches of size
 * `patchSize` stacked along the second dimension.
 *
 * The kept region is patchSize[0] x (patchSize[1] * nbSamples) pixels,
 * starting at the index of a default-constructed RegionType. The crop is
 * done with an ExtractROIMultiFilterType filter, which is updated
 * immediately. On return, `image` points to the filter's output.
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>
::ResizeImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples)
{
  // Region to keep: one patch wide, nbSamples patches tall
  RegionType keptRegion;
  keptRegion.SetSize(0, patchSize[0]);
  keptRegion.SetSize(1, patchSize[1] * nbSamples);

  // Extract that region from the current image
  ExtractROIMultiFilterPointerType extractor = ExtractROIMultiFilterType::New();
  extractor->SetInput(image);
  extractor->SetExtractionRegion(keptRegion);
  extractor->Update();

  // Hand the cropped image back to the caller
  image = extractor->GetOutput();
}
/**
 * Create a new image sized to hold `nbSamples` patches of size `patchSize`
 * stacked along the second dimension.
 *
 * @param image        receives the newly created image.
 * @param patchSize    size of one patch.
 * @param nbSamples    number of patches to stack.
 * @param nbComponents number of components per pixel.
 *
 * The image region is patchSize[0] x (patchSize[1] * nbSamples) pixels.
 * Its buffer is allocated, but no fill value is written here.
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>
::AllocateImage(ImagePointerType & image, SizeType & patchSize, unsigned int nbSamples, unsigned int nbComponents)
{
  // Region covering every stacked patch
  RegionType stackedRegion;
  stackedRegion.SetSize(0, patchSize[0]);
  stackedRegion.SetSize(1, patchSize[1] * nbSamples);

  // Create the image, configure its layout, then allocate its buffer
  image = ImageType::New();
  image->SetNumberOfComponentsPerPixel(nbComponents);
  image->SetRegions(stackedRegion);
  image->Allocate();
}
/**
 * Do the work: extract one patch per input image at every labelled point
 * geometry of m_InputVectorData.
 *
 * Only geometry nodes are visited (root, document and folder nodes are
 * skipped). Of those, only nodes that carry the field m_Field are used;
 * the others get a warning and are ignored. For each node used:
 *  - every input i is sampled with tf::SampleImage at the node's point,
 *    using patch size m_PatchSizes[i];
 *  - if sampling succeeded for all inputs, the field value (read as an
 *    integer) is written in m_OutputLabelImage and the sample is accepted;
 *    otherwise it is counted as rejected.
 *
 * Output images are first allocated for the maximum number of samples.
 * They are then cropped to the number of accepted samples, which is
 * stored in m_NumberOfAcceptedSamples. The rejected count is stored in
 * m_NumberOfRejectedSamples.
 *
 * Throws if the number of inputs differs from the number of patch sizes,
 * or if no geometry carries m_Field.
 */
template <class TInputImage, class TVectorData>
void
TensorflowSampler<TInputImage, TVectorData>
::Update()
{
  // Check number of inputs
  if (this->GetNumberOfInputs() != m_PatchSizes.size())
    {
    itkExceptionMacro("Number of inputs and patches sizes are not the same");
    }

  // Count the geometries that carry the label field: this is the maximum
  // number of samples, used to size the output images.
  unsigned int nTotal = 0;
  unsigned int geomId = 0;
  TreeIteratorType itVector(m_InputVectorData->GetDataTree());
  itVector.GoToBegin();
  while (!itVector.IsAtEnd())
    {
    if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder())
      {
      const DataNodePointer currentGeometry = itVector.Get();
      if (!currentGeometry->HasField(m_Field))
        {
        itkWarningMacro("Field \"" << m_Field << "\" not found in geometry #" << geomId);
        }
      else
        {
        nTotal++;
        }
      geomId++;
      }
    ++itVector;
    } // next feature

  // Check number
  if (nTotal == 0)
    {
    itkExceptionMacro("There is no geometry to sample. Geometries must be points.");
    }

  // Allocate label image: one single-component pixel per sample
  SizeType labelPatchSize;
  labelPatchSize.Fill(1);
  AllocateImage(m_OutputLabelImage, labelPatchSize, nTotal, 1);

  // Allocate patches images: one per input, with that input's patch size
  // and number of components
  const unsigned int nbInputs = this->GetNumberOfInputs();
  m_OutputPatchImages.clear();
  m_OutputPatchImages.reserve(nbInputs);
  for (unsigned int i = 0 ; i < nbInputs ; i++)
    {
    ImagePointerType newImage;
    AllocateImage(newImage, m_PatchSizes[i], nTotal, GetInput(i)->GetNumberOfComponentsPerPixel());
    m_OutputPatchImages.push_back(newImage);
    }

  itk::ProgressReporter progress(this, 0, nTotal);

  // Iterate on the vector data
  itVector.GoToBegin();
  unsigned long count = 0;
  unsigned long rejected = 0;
  IndexType labelIndex;
  labelIndex[0] = 0;
  PixelType labelPix;
  labelPix.SetSize(1);
  while (!itVector.IsAtEnd())
    {
    // Visit exactly the geometries counted in nTotal (same filter as the
    // counting pass): each sample is written at row "count" and the output
    // images only have room for nTotal samples. Geometries without the
    // field were already reported above.
    if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder()
        && itVector.Get()->HasField(m_Field))
      {
      DataNodePointer currentGeometry = itVector.Get();
      PointType point = currentGeometry->GetPoint();

      // Get the label value
      labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field));

      bool hasBeenSampled = true;
      for (unsigned int i = 0 ; i < nbInputs ; i++)
        {
        // Get input
        ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i));

        // Try to sample the image
        if (!tf::SampleImage<ImageType>(inputPtr, m_OutputPatchImages[i], point, count, m_PatchSizes[i]))
          {
          // If not, reject this sample
          hasBeenSampled = false;
          }
        } // Next input

      if (hasBeenSampled)
        {
        // Fill label
        labelIndex[1] = count;
        m_OutputLabelImage->SetPixel(labelIndex, labelPix);

        // Update count
        count++;
        }
      else
        {
        // Row "count" is not kept: it is overwritten by the next accepted
        // sample or cropped away below
        rejected++;
        }

      // Update progress
      progress.CompletedPixel();
      }
    ++itVector;
    } // next feature

  // Crop output images to the accepted samples.
  // NOTE(review): when every sample is rejected, count is 0 and the images
  // are cropped to an empty region — confirm the extract filter accepts it.
  ResizeImage(m_OutputLabelImage, labelPatchSize, count);
  for (unsigned int i = 0 ; i < nbInputs ; i++)
    {
    ResizeImage(m_OutputPatchImages[i], m_PatchSizes[i], count);
    }

  // Update number of samples produced
  m_NumberOfAcceptedSamples = count;
  m_NumberOfRejectedSamples = rejected;
}
} // end namespace otb
#endif