Commit 738ec4bf authored by remi.cresson's avatar remi.cresson
Browse files

Merge branch 'patchesselection' into develop

parents 1f40367d 60510f83
......@@ -22,10 +22,18 @@ if(OTB_USE_TENSORFLOW)
endif()
# Tensorflow-independent APPS
OTB_CREATE_APPLICATION(NAME PatchesSelection
SOURCES otbPatchesSelection.cxx
LINK_LIBRARIES ${OTBCOMMON_LIBRARIES}
)
OTB_CREATE_APPLICATION(NAME PatchesExtraction
SOURCES otbPatchesExtraction.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES}
)
OTB_CREATE_APPLICATION(NAME PatchesManipulation
SOURCES otbPatchesManipulation.cxx
LINK_LIBRARIES ${OTBCOMMON_LIBRARIES}
)
OTB_CREATE_APPLICATION(NAME LabelImageSampleSelection
SOURCES otbLabelImageSampleSelection.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES}
......
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Image
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
// image utils
#include "otbTensorflowCommon.h"
namespace otb
{
namespace Wrapper
{
class PatchesManipulation : public Application
{
public:
/** Standard class typedefs. */
typedef PatchesManipulation Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(PatchesManipulation, Application);
typedef otb::RAMDrivenStrippedStreamingManager<FloatVectorImageType> StreamingManagerType;
void DoUpdateParameters()
{
}
void DoInit()
{
// Documentation
SetName("PatchesManipulation");
SetDocName("PatchesManipulation");
SetDescription("This application enables to edit patches images.");
SetDocLongDescription("This application provide various operations for patches "
"images edition. ");
SetDocAuthors("Remi Cresson");
// Patches size
AddParameter(ParameterType_Group, "patches", "grid settings");
AddParameter(ParameterType_Int, "patches.sizex", "Patch size X");
SetMinimumParameterIntValue ("patches.sizex", 1);
MandatoryOn ("patches.sizex");
AddParameter(ParameterType_Int, "patches.sizey", "Patch size Y");
SetMinimumParameterIntValue ("patches.sizey", 1);
MandatoryOn ("patches.sizey");
// Operation
AddParameter(ParameterType_Choice, "op", "Operation");
AddChoice("op.merge", "Merge multiple patches images into one");
AddParameter(ParameterType_InputImageList, "op.merge.il", "patches images to merge");
// Output
AddParameter(ParameterType_OutputImage, "out", "Output patches image");
AddRAMParameter();
}
void CheckPatchesDimensions(FloatVectorImageType::Pointer in1, FloatVectorImageType::Pointer in2)
{
FloatVectorImageType::SizeType size1 = in1->GetLargestPossibleRegion().GetSize();
FloatVectorImageType::SizeType size2 = in2->GetLargestPossibleRegion().GetSize();
unsigned int nbands1 = in1->GetNumberOfComponentsPerPixel();
unsigned int nbands2 = in2->GetNumberOfComponentsPerPixel();
if (nbands1 != nbands2)
otbAppLogFATAL("Patches must have the same number of channels");
if (static_cast<int>(size1[0]) != GetParameterInt("patches.sizex"))
otbAppLogFATAL("Input patches image width not consistent with patch size x");
if (size1[1] % GetParameterInt("patches.sizey") != 0)
otbAppLogFATAL("Input patches image height is " << size1[1] << " which is not a multiple of " << GetParameterInt("patches.sizey"));
if (size2[1] % GetParameterInt("patches.sizey") != 0)
otbAppLogFATAL("Patches image height is " << size2[1] << " which is not a multiple of " << GetParameterInt("patches.sizey"));
if (size2[0] != size1[0])
otbAppLogFATAL("Input patches images must have the same width!");
}
/*
* Merge two patches images into one
* TODO:
* Use ImageToImage to create a filter that do this in a streamable way
*/
void MergePatches()
{
std::string key = "op.merge.il";
FloatVectorImageListType::Pointer imagesList = this->GetParameterImageList(key);
unsigned int nImgs = imagesList->Size();
otbAppLogINFO("Number of patches images: " << nImgs);
// Check patches consistency and count rows
FloatVectorImageType::IndexValueType nrows = imagesList->GetNthElement(0)->GetLargestPossibleRegion().GetSize(1);
FloatVectorImageType::Pointer img0 = imagesList->GetNthElement(0);
for (unsigned int i = 1; i < nImgs ; i++)
{
FloatVectorImageType::Pointer img = imagesList->GetNthElement(i);
CheckPatchesDimensions(img0, img);
nrows += img->GetLargestPossibleRegion().GetSize(1);
}
// Allocate output image
FloatVectorImageType::RegionType outRegion;
outRegion.GetModifiableIndex().Fill(0);
outRegion.GetModifiableSize()[0] = GetParameterInt("patches.sizex");
outRegion.GetModifiableSize()[1] = nrows;
m_Out = FloatVectorImageType::New();
m_Out->SetRegions(outRegion);
m_Out->SetNumberOfComponentsPerPixel(img0->GetNumberOfComponentsPerPixel());
otbAppLogINFO("Allocating output image of " << outRegion.GetSize() <<
" pixels with " << img0->GetNumberOfComponentsPerPixel() << " channels");
m_Out->Allocate();
// Read input images
itk::ImageRegionIterator<FloatVectorImageType> outIt(m_Out, outRegion);
outIt.GoToBegin();
for (unsigned int i = 0; i < nImgs ; i++)
{
// Get current image
FloatVectorImageType::Pointer img = imagesList->GetNthElement(i);
FloatVectorImageType::RegionType region = img->GetLargestPossibleRegion();
otbAppLogINFO("Processing input image " << (i+1) << "/" << nImgs);
// Recopy
tf::PropagateRequestedRegion<FloatVectorImageType>(img, region);
itk::ImageRegionConstIterator<FloatVectorImageType> inIt(img, region);
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt, ++outIt)
outIt.Set(inIt.Get());
// Release data bulk
img->PrepareForNewData();
}
SetParameterOutputImage("out", m_Out);
}
void DoExecute()
{
if (GetParameterAsString("op").compare("merge") == 0)
{
otbAppLogINFO("Operation is merge");
MergePatches();
}
else
otbAppLogFATAL("Please select an existing operation");
}
private:
FloatVectorImageType::Pointer m_Out;
}; // end of class
} // end namespace wrapper
} // end namespace otb
OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesManipulation )
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Image
#include "itkImageRegionConstIterator.h"
#include "itkUnaryFunctorImageFilter.h"
#include "itkFlatStructuringElement.h"
#include "itkBinaryErodeImageFilter.h"
#include "otbStreamingResampleImageFilter.h"
// image utils
#include "otbTensorflowCommon.h"
#include "otbTensorflowSamplingUtils.h"
// Functor to retrieve nodata
template<class TPixel, class OutputPixel>
class IsNoData
{
public:
IsNoData(){}
~IsNoData(){}
inline OutputPixel operator()( const TPixel & A ) const
{
for (unsigned int band = 0 ; band < A.Size() ; band++)
{
if (A[band] != m_NoDataValue)
return 1;
}
return 0;
}
void SetNoDataValue(typename TPixel::ValueType value)
{
m_NoDataValue = value;
}
private:
typename TPixel::ValueType m_NoDataValue;
};
namespace otb
{
namespace Wrapper
{
class PatchesSelection : public Application
{
public:
/** Standard class typedefs. */
typedef PatchesSelection Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(PatchesSelection, Application);
/** Vector data typedefs */
typedef VectorDataType::DataTreeType DataTreeType;
typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType;
typedef VectorDataType::DataNodeType DataNodeType;
typedef DataNodeType::Pointer DataNodePointer;
typedef DataNodeType::PointType DataNodePointType;
/** typedefs */
typedef IsNoData<FloatVectorImageType::PixelType, UInt8ImageType::PixelType > IsNoDataFunctorType;
typedef itk::UnaryFunctorImageFilter<FloatVectorImageType, UInt8ImageType, IsNoDataFunctorType> IsNoDataFilterType;
typedef itk::FlatStructuringElement<2> StructuringType;
typedef StructuringType::RadiusType RadiusType;
typedef itk::BinaryErodeImageFilter<UInt8ImageType, UInt8ImageType, StructuringType> MorphoFilterType;
typedef otb::StreamingResampleImageFilter<UInt8ImageType,UInt8ImageType> PadFilterType;
typedef tf::Distribution<UInt8ImageType> DistributionType;
void DoUpdateParameters()
{
}
void DoInit()
{
// Documentation
SetName("PatchesSelection");
SetDocName("PatchesSelection");
SetDescription("This application generate points sampled at regular interval over "
"the input image region. The grid size and spacing can be configured.");
SetDocLongDescription("This application produces a vector data containing "
"a set of points centered on the patches lying in the valid regions of the input image. ");
SetDocAuthors("Remi Cresson");
// Input image
AddParameter(ParameterType_InputImage, "in", "input image");
// Input no-data value
AddParameter(ParameterType_Float, "nodata", "nodata value");
MandatoryOn ("nodata");
SetDefaultParameterFloat ("nodata", 0);
// Grid
AddParameter(ParameterType_Group, "grid", "grid settings");
AddParameter(ParameterType_Int, "grid.step", "step between patches");
SetMinimumParameterIntValue ("grid.step", 1);
AddParameter(ParameterType_Int, "grid.psize", "patches size");
SetMinimumParameterIntValue ("grid.psize", 1);
// Strategy
AddParameter(ParameterType_Choice, "strategy", "Selection strategy for validation/training patches");
AddChoice("strategy.chessboard", "fifty fifty, like a chess board");
AddChoice("strategy.balanced", "you can chose the degree of spatial randomness vs class balance");
AddParameter(ParameterType_Float, "strategy.balanced.sp", "Spatial proportion: between 0 and 1, "
"indicating the amount of randomly sampled data in space");
SetMinimumParameterFloatValue ("strategy.balanced.sp", 0);
SetMaximumParameterFloatValue ("strategy.balanced.sp", 1);
SetDefaultParameterFloat ("strategy.balanced.sp", 0.25);
AddParameter(ParameterType_Int, "strategy.balanced.nclasses", "Number of classes");
SetMinimumParameterIntValue ("strategy.balanced.nclasses", 2);
MandatoryOn ("strategy.balanced.nclasses");
AddParameter(ParameterType_InputImage, "strategy.balanced.labelimage", "input label image");
MandatoryOn ("strategy.balanced.labelimage");
// Output points
AddParameter(ParameterType_OutputVectorData, "outtrain", "output set of points (training)");
AddParameter(ParameterType_OutputVectorData, "outvalid", "output set of points (validation)");
AddRAMParameter();
}
class SampleBundle
{
public:
SampleBundle(){}
SampleBundle(unsigned int nClasses){
dist = DistributionType(nClasses);
id = 0;
(void) point;
black = true;
(void) index;
}
~SampleBundle(){}
SampleBundle(const SampleBundle & other){
dist = other.GetDistribution();
id = other.GetSampleID();
point = other.GetPosition();
black = other.GetBlack();
index = other.GetIndex();
}
DistributionType GetDistribution() const
{
return dist;
}
DistributionType& GetModifiableDistribution()
{
return dist;
}
unsigned int GetSampleID() const
{
return id;
}
unsigned int& GetModifiableSampleID()
{
return id;
}
DataNodePointType GetPosition() const
{
return point;
}
DataNodePointType& GetModifiablePosition()
{
return point;
}
bool& GetModifiableBlack()
{
return black;
}
bool GetBlack() const
{
return black;
}
UInt8ImageType::IndexType& GetModifiableIndex()
{
return index;
}
UInt8ImageType::IndexType GetIndex() const
{
return index;
}
private:
DistributionType dist;
unsigned int id;
DataNodePointType point;
bool black;
UInt8ImageType::IndexType index;
};
/*
* Apply the given function at each sampling location
*/
template<typename TLambda>
void Apply(TLambda lambda)
{
// Explicit streaming over the morphed mask, based on the RAM parameter
typedef otb::RAMDrivenStrippedStreamingManager<UInt8ImageType> StreamingManagerType;
StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New();
m_StreamingManager->SetAvailableRAMInMB(GetParameterInt("ram"));
UInt8ImageType::Pointer inputImage = m_MorphoFilter->GetOutput();
UInt8ImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion();
entireRegion.ShrinkByRadius(m_Radius);
m_StreamingManager->PrepareStreaming(inputImage, entireRegion );
UInt8ImageType::IndexType start;
start[0] = m_Radius[0] + 1;
start[1] = m_Radius[1] + 1;
int m_NumberOfDivisions = m_StreamingManager->GetNumberOfSplits();
UInt8ImageType::IndexType pos;
UInt8ImageType::IndexValueType step = GetParameterInt("grid.step");
pos.Fill(0);
for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++)
{
otbAppLogINFO("Processing split " << (m_CurrentDivision + 1) << "/" << m_NumberOfDivisions);
UInt8ImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision);
tf::PropagateRequestedRegion<UInt8ImageType>(inputImage, streamRegion);
itk::ImageRegionConstIterator<UInt8ImageType> inIt (inputImage, streamRegion);
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{
UInt8ImageType::IndexType idx = inIt.GetIndex();
idx[0] -= start[0];
idx[1] -= start[1];
if (idx[0] % step == 0 && idx[1] % step == 0)
{
UInt8ImageType::InternalPixelType pixVal = inIt.Get();
if (pixVal == 1)
{
// Update grid position
pos[0] = idx[0] / step;
pos[1] = idx[1] / step;
// Compute coordinates
UInt8ImageType::PointType geo;
m_MorphoFilter->GetOutput()->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo);
DataNodeType::PointType point;
point[0] = geo[0];
point[1] = geo[1];
// Lambda call
lambda(pos, geo);
}
}
}
}
}
/*
* Allocate a std::vector of sample bundle
*/
std::vector<SampleBundle>
AllocateSamples(unsigned int nbOfClasses = 2)
{
// Nb of samples (maximum)
const UInt8ImageType::RegionType entireRegion = m_MorphoFilter->GetOutput()->GetLargestPossibleRegion();
const unsigned int maxNbOfCols = std::ceil(entireRegion.GetSize(0)/GetParameterInt("grid.step")) + 1;
const unsigned int maxNbOfRows = std::ceil(entireRegion.GetSize(1)/GetParameterInt("grid.step")) + 1;
unsigned int maxNbOfSamples = 1;
maxNbOfSamples *= maxNbOfCols;
maxNbOfSamples *= maxNbOfRows;
// Nb of classes
SampleBundle initSB(nbOfClasses);
std::vector<SampleBundle> bundles(maxNbOfSamples, initSB);
return bundles;
}
void SetBlackOrWhiteBundle(SampleBundle & bundle, unsigned int & count,
const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo)
{
// Black or white
bool black = ((pos[0] + pos[1]) % 2 == 0);
bundle.GetModifiableSampleID() = count;
bundle.GetModifiablePosition() = geo;
bundle.GetModifiableBlack() = black;
bundle.GetModifiableIndex() = pos;
count++;
}
/*
* Samples are placed at regular intervals
*/
void SampleChessboard()
{
std::vector<SampleBundle> bundles = AllocateSamples();
unsigned int count = 0;
auto lambda = [this, &count, &bundles]
(const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) {
SetBlackOrWhiteBundle(bundles[count], count, pos, geo);
};
Apply(lambda);
bundles.resize(count);
// Export training/validation samples
PopulateVectorData(bundles);
}
void SampleBalanced()
{
// 1. Compute distribution of all samples
otbAppLogINFO("Computing samples distribution...");
std::vector<SampleBundle> bundles = AllocateSamples(GetParameterInt("strategy.balanced.nclasses"));
// Patch size
UInt8ImageType::SizeType patchSize;
patchSize.Fill(GetParameterInt("grid.psize"));
unsigned int count = 0;
auto lambda = [this, &bundles, &patchSize, &count]
(const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) {
// Update this sample distribution
if (tf::UpdateDistributionFromPatch<UInt8ImageType>(GetParameterUInt8Image("strategy.balanced.labelimage"),
geo, patchSize, bundles[count].GetModifiableDistribution()))
{
SetBlackOrWhiteBundle(bundles[count], count, pos, geo);
}
};
Apply(lambda);
bundles.resize(count);
otbAppLogINFO("Total number of candidates: " << count );
// 2. Seed = spatially random samples
otbAppLogINFO("Spatial sampling proportion " << GetParameterFloat("strategy.balanced.sp"));
const int samplingStep = static_cast<int>(1.0 / std::sqrt(GetParameterFloat("strategy.balanced.sp")));
otbAppLogINFO("Spatial sampling step " << samplingStep);
float step = 0;
std::vector<SampleBundle> seed(count);
std::vector<SampleBundle> candidates(count);
unsigned int seedCount = 0;
unsigned int candidatesCount = 0;
for (auto& d: bundles)
{
if (d.GetIndex()[0] % samplingStep + d.GetIndex()[1] % samplingStep == 0)
{
seed[seedCount] = d;
seedCount++;
}
else
{
candidates[candidatesCount] = d;
candidatesCount++;
}
step++;
}