Commit c5dda8c2 authored by Cresson Remi's avatar Cresson Remi
Browse files

ENH: new sample selection/extraction strategy (use nodata in selection or extraction)

parent e8ddeec9
......@@ -163,6 +163,12 @@ public:
// Input vector data
AddParameter(ParameterType_InputVectorData, "vec", "Positions of the samples (must be in the same projection as input image)");
// No data parameters
AddParameter(ParameterType_Bool, "usenodata", "Reject samples that have no-data value");
MandatoryOff ("usenodata");
AddParameter(ParameterType_Float, "nodataval", "No data value (used only if \"usenodata\" is on)");
SetDefaultParameterFloat( "nodataval", 0.0);
// Output label
AddParameter(ParameterType_OutputImage, "outlabels", "output labels");
SetDefaultOutputPixelType ("outlabels", ImagePixelType_uint8);
......@@ -191,6 +197,8 @@ public:
SamplerType::Pointer sampler = SamplerType::New();
sampler->SetInputVectorData(GetParameterVectorData("vec"));
sampler->SetField(GetParameterAsString("field"));
sampler->SetRejectPatchesWithNodata(GetParameterInt("usenodata")==1);
sampler->SetNodataValue(GetParameterFloat("nodataval"));
for (auto& bundle: m_Bundles)
{
sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize);
......
......@@ -27,6 +27,7 @@
// image utils
#include "otbTensorflowCommon.h"
#include "otbTensorflowSamplingUtils.h"
#include "itkImageRegionConstIteratorWithOnlyIndex.h"
// Functor to retrieve nodata
template<class TPixel, class OutputPixel>
......@@ -122,6 +123,8 @@ public:
AddParameter(ParameterType_Float, "nodata", "nodata value");
MandatoryOn ("nodata");
SetDefaultParameterFloat ("nodata", 0);
AddParameter(ParameterType_Bool, "nocheck", "If on, no check on the validity of patches is performed");
MandatoryOff ("nocheck");
// Grid
AddParameter(ParameterType_Group, "grid", "grid settings");
......@@ -238,6 +241,18 @@ public:
*/
template<typename TLambda>
void Apply(TLambda lambda)
{
if (GetParameterInt("nockeck")==1)
ApplyFast(lambda);
else
ApplyWithCheck(lambda);
}
/*
* Apply the given function at each sampling location, checking if the patch is valid or not
*/
template<typename TLambda>
void ApplyWithCheck(TLambda lambda)
{
// Explicit streaming over the morphed mask, based on the RAM parameter
......@@ -297,6 +312,49 @@ public:
}
}
/*
* Apply the given function at each sampling location, without checking the valid pixels under
*/
template<typename TLambda>
void ApplyFast(TLambda lambda)
{
FloatVectorImageType::Pointer inputImage = GetParameterFloatVectorImage("in");
FloatVectorImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion();
entireRegion.ShrinkByRadius(m_Radius);
FloatVectorImageType::IndexType start;
start[0] = m_Radius[0] + 1;
start[1] = m_Radius[1] + 1;
FloatVectorImageType::IndexType pos;
pos.Fill(0);
FloatVectorImageType::IndexValueType step = GetParameterInt("grid.step");
typedef itk::ImageRegionConstIteratorWithOnlyIndex<FloatVectorImageType> IteratorType;
IteratorType inIt (inputImage, entireRegion);
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{
FloatVectorImageType::IndexType idx = inIt.GetIndex();
idx[0] -= start[0];
idx[1] -= start[1];
if (idx[0] % step == 0 && idx[1] % step == 0)
{
// Update grid position
pos[0] = idx[0] / step;
pos[1] = idx[1] / step;
// Compute coordinates
FloatVectorImageType::PointType geo;
inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo);
DataNodeType::PointType point;
point[0] = geo[0];
point[1] = geo[1];
// Lambda call
lambda(pos, geo);
}
}
}
/*
* Allocate a std::vector of sample bundle
*/
......@@ -595,6 +653,12 @@ public:
SampleChessboard();
}
else if (GetParameterAsString("strategy") == "chessboardfast")
{
otbAppLogINFO("Sampling at regular interval in space (\"Chessboard\" like) without checking image content");
SampleChessboardFast();
}
else if (GetParameterAsString("strategy") == "balanced")
{
otbAppLogINFO("Sampling with balancing strategy");
......
......@@ -24,6 +24,9 @@
// Tree iterator
#include "itkPreOrderTreeIterator.h"
// Image iterator
#include "itkImageRegionConstIterator.h"
namespace otb
{
......@@ -80,6 +83,7 @@ public:
ExtractROIMultiFilterPointerType;
typedef typename std::vector<ImagePointerType> ImagePointerListType;
typedef typename std::vector<SizeType> SizeListType;
typedef typename itk::ImageRegionConstIterator IteratorType;
/** Vector data typedefs */
typedef TVectorData VectorDataType;
......@@ -104,6 +108,12 @@ public:
virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize);
const ImageType* GetInput(unsigned int index);
/** Set / get no-data related parameters */
itkSetMacro(NodataValue, InternalPixelType);
itkGetMacro(NodataValue, InternalPixelType);
itkSetMacro(RejectPatchesWithNodata, bool);
itkGetMacro(RejectPatchesWithNodata, bool);
/** Do the real work */
virtual void Update();
......@@ -134,6 +144,10 @@ private:
unsigned long m_NumberOfAcceptedSamples;
unsigned long m_NumberOfRejectedSamples;
// No data stuff
InternalPixelType m_NodataValue;
bool m_RejectPatchesWithNodata;
}; // end class
} // end namespace otb
......
......@@ -181,6 +181,29 @@ TensorflowSampler<TInputImage, TVectorData>
// If not, reject this sample
hasBeenSampled = false;
}
// Check if it contains no-data values
if (m_RejectPatchesWithNodata && hasBeenSampled)
{
IndexType outIndex;
outIndex[0] = 0;
outIndex[1] = count * m_PatchSizes[i][1];
RegionType region(outIndex, m_PatchSizes);
IteratorType it(m_OutputPatchImages[i], region);
for (it.GoToBegin(); !it.IsAtEnd(); ++it)
{
PixelType pix = it.Get();
for (int i; i<pix.Size(); i++)
if (pix[i] == m_NodataValue)
{
hasBeenSampled = false;
break;
}
if (hasBeenSampled)
break;
}
}
} // Next input
if (hasBeenSampled)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment