otbPatchesSelection.cxx 26.85 KiB
/*=========================================================================
     Copyright (c) 2018-2019 IRSTEA
     Copyright (c) 2020-2022 INRAE
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"
// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
// Image
#include "itkImageRegionConstIterator.h"
#include "itkUnaryFunctorImageFilter.h"
#include "itkFlatStructuringElement.h"
#include "itkBinaryErodeImageFilter.h"
#include "otbStreamingResampleImageFilter.h"
#include "itkNearestNeighborInterpolateImageFunction.h"
#include "itkMaskImageFilter.h"
// Image utils
#include "otbTensorflowCommon.h"
#include "otbTensorflowSamplingUtils.h"
#include "itkImageRegionConstIteratorWithOnlyIndex.h"
// Math
#include <random>
#include <limits>
namespace otb
namespace Wrapper
// Functor to retrieve nodata
template<class TPixel, class OutputPixel>
class IsNoData
public:
  IsNoData(){}
  ~IsNoData(){}
  inline OutputPixel operator()( const TPixel & A ) const
    for (unsigned int band = 0 ; band < A.Size() ; band++)
      if (A[band] != m_NoDataValue)
        return 1;
    return 0;
  void SetNoDataValue(typename TPixel::ValueType value)
    m_NoDataValue = value;
private:
  typename TPixel::ValueType m_NoDataValue;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
class PatchesSelection : public Application { public: /** Standard class typedefs. */ typedef PatchesSelection Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(PatchesSelection, Application); /** Vector data typedefs */ typedef VectorDataType::DataTreeType DataTreeType; typedef itk::PreOrderTreeIterator<DataTreeType> TreeIteratorType; typedef VectorDataType::DataNodeType DataNodeType; typedef DataNodeType::Pointer DataNodePointer; typedef DataNodeType::PointType DataNodePointType; /** typedefs */ typedef IsNoData<FloatVectorImageType::PixelType, UInt8ImageType::PixelType > IsNoDataFunctorType; typedef itk::UnaryFunctorImageFilter<FloatVectorImageType, UInt8ImageType, IsNoDataFunctorType> IsNoDataFilterType; typedef itk::FlatStructuringElement<2> StructuringType; typedef StructuringType::RadiusType RadiusType; typedef itk::BinaryErodeImageFilter<UInt8ImageType, UInt8ImageType, StructuringType> MorphoFilterType; typedef otb::StreamingResampleImageFilter<UInt8ImageType,UInt8ImageType> PadFilterType; typedef itk::NearestNeighborInterpolateImageFunction<UInt8ImageType> NNInterpolatorType; typedef tf::Distribution<UInt8ImageType> DistributionType; typedef itk::MaskImageFilter<UInt8ImageType, UInt8ImageType, UInt8ImageType> MaskImageFilterType; void DoInit() { // Documentation SetName("PatchesSelection"); SetDescription("This application generate points sampled at regular interval over " "the input image region. The selection strategy, grid size and step, " " can be configured."); SetDocLongDescription("This application produces a vector data containing " "a set of points centered on the selected patches."); SetDocAuthors("Remi Cresson"); // Input image AddParameter(ParameterType_InputImage, "in", "input image"); AddParameter(ParameterType_InputImage, "mask", "input mask"); MandatoryOff("mask"); // Input no-data value AddParameter(ParameterType_Float, "nodata", "nodata value"); MandatoryOff ("nodata"); // Grid AddParameter(ParameterType_Group, "grid", "grid settings"); AddParameter(ParameterType_Int, "grid.step", "step between patches"); SetMinimumParameterIntValue ("grid.step", 1); AddParameter(ParameterType_Int, "grid.psize", "patches size"); SetMinimumParameterIntValue ("grid.psize", 1); AddParameter(ParameterType_Int, "grid.offsetx", "offset of the grid (x axis)"); SetDefaultParameterInt ("grid.offsetx", 0); AddParameter(ParameterType_Int, "grid.offsety", "offset of the grid (y axis)"); SetDefaultParameterInt ("grid.offsety", 0); // Strategy
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
AddParameter(ParameterType_Choice, "strategy", "Selection strategy for validation/training patches"); // Chess board AddChoice("strategy.chessboard", "Fifty fifty with chess-board-like layout. Only \"outtrain\" and " "\"outvalid\" output parameters are used."); // Split AddChoice("strategy.split", "The traditional training/validation/test split. The \"outtrain\", " "\"outvalid\" and \"outtest\" output parameters are used."); AddParameter(ParameterType_Bool, "strategy.split.random", "If false, samples will always be from " "the same group"); MandatoryOff ("strategy.split.random"); AddParameter(ParameterType_Float, "strategy.split.trainprop", "Proportion of training population."); SetMinimumParameterFloatValue ("strategy.split.trainprop", 0.0); SetDefaultParameterFloat ("strategy.split.trainprop", 50.0); AddParameter(ParameterType_Float, "strategy.split.validprop", "Proportion of validation population."); SetMinimumParameterFloatValue ("strategy.split.validprop", 0.0); SetDefaultParameterFloat ("strategy.split.validprop", 25.0); AddParameter(ParameterType_Float, "strategy.split.testprop", "Proportion of test population."); SetMinimumParameterFloatValue ("strategy.split.testprop", 0.0); SetDefaultParameterFloat ("strategy.split.testprop", 25.0); // All AddChoice("strategy.all", "All locations. Only the \"outtrain\" output parameter is used."); // Balanced (experimental) AddChoice("strategy.balanced", "you can chose the degree of spatial randomness vs class balance"); AddParameter(ParameterType_Float, "strategy.balanced.sp", "Spatial proportion: between 0 and 1, " "indicating the amount of randomly sampled data in space"); SetMinimumParameterFloatValue ("strategy.balanced.sp", 0); SetMaximumParameterFloatValue ("strategy.balanced.sp", 1); SetDefaultParameterFloat ("strategy.balanced.sp", 0.25); AddParameter(ParameterType_Int, "strategy.balanced.nclasses", "Number of classes"); SetMinimumParameterIntValue ("strategy.balanced.nclasses", 2); MandatoryOn ("strategy.balanced.nclasses"); AddParameter(ParameterType_InputImage, "strategy.balanced.labelimage", "input label image"); MandatoryOn ("strategy.balanced.labelimage"); // Output points AddParameter(ParameterType_OutputVectorData, "outtrain", "output set of points (training)"); AddParameter(ParameterType_OutputVectorData, "outvalid", "output set of points (validation)"); MandatoryOff("outvalid"); AddParameter(ParameterType_OutputVectorData, "outtest", "output set of points (test)"); MandatoryOff("outtest"); AddRAMParameter(); } class SampleBundle { public: SampleBundle(){} explicit SampleBundle(unsigned int nClasses): dist(DistributionType(nClasses)), id(0), group(true){ (void) point; (void) index; } ~SampleBundle(){} SampleBundle(const SampleBundle & other): dist(other.GetDistribution()), id(other.GetSampleID()), point(other.GetPosition()), group(other.GetGroup()), index(other.GetIndex()) {} DistributionType GetDistribution() const { return dist; } DistributionType& GetModifiableDistribution() { return dist; } unsigned int GetSampleID() const
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
{ return id; } unsigned int& GetModifiableSampleID() { return id; } DataNodePointType GetPosition() const { return point; } DataNodePointType& GetModifiablePosition() { return point; } int& GetModifiableGroup() { return group; } int GetGroup() const { return group; } UInt8ImageType::IndexType& GetModifiableIndex() { return index; } UInt8ImageType::IndexType GetIndex() const { return index; } private: DistributionType dist; unsigned int id; DataNodePointType point; int group; UInt8ImageType::IndexType index; }; /* * Apply the given function at each sampling location, checking if the patch is valid or not */ template<typename TLambda> void Apply(TLambda lambda) { int userOffX = GetParameterInt("grid.offsetx"); int userOffY = GetParameterInt("grid.offsety"); // Tell if the patch size is odd or even const bool isEven = GetParameterInt("grid.psize") % 2 == 0; otbAppLogINFO("Patch size is even: " << isEven); // Explicit streaming over the morphed mask, based on the RAM parameter typedef otb::RAMDrivenStrippedStreamingManager<UInt8ImageType> StreamingManagerType; StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New(); m_StreamingManager->SetAvailableRAMInMB(GetParameterInt("ram")); UInt8ImageType::Pointer inputImage; bool readInput = true; if (!HasValue("nodata"))
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
{ otbAppLogINFO("No value specified for no-data. Input image pixels no-data values will not be checked."); if (HasValue("mask")) { otbAppLogINFO("Using the provided \"mask\" parameter."); inputImage = GetParameterUInt8Image("mask"); } else { // This is just a hack to not trigger the whole morpho/pad pipeline inputImage = m_NoDataFilter->GetOutput(); readInput = false; } } else { inputImage = m_MorphoFilter->GetOutput(); // Offset update because the morpho filter pads the input image with 1 pixel border userOffX += 1; userOffY += 1; } UInt8ImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion(); entireRegion.ShrinkByRadius(m_Radius); m_StreamingManager->PrepareStreaming(inputImage, entireRegion ); UInt8ImageType::IndexType start; start[0] = m_Radius[0] + 1; start[1] = m_Radius[1] + 1; int m_NumberOfDivisions = m_StreamingManager->GetNumberOfSplits(); UInt8ImageType::IndexType pos; UInt8ImageType::IndexValueType step = GetParameterInt("grid.step"); pos.Fill(0); // Offset update userOffX %= step ; userOffY %= step ; for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++) { otbAppLogINFO("Processing split " << (m_CurrentDivision + 1) << "/" << m_NumberOfDivisions); UInt8ImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<UInt8ImageType>(inputImage, streamRegion); itk::ImageRegionConstIterator<UInt8ImageType> inIt (inputImage, streamRegion); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { UInt8ImageType::IndexType idx = inIt.GetIndex(); idx[0] -= start[0]; idx[1] -= start[1]; if (idx[0] % step == userOffX && idx[1] % step == userOffY) { UInt8ImageType::InternalPixelType pixVal = 1; if (readInput) pixVal = inIt.Get(); if (pixVal == 1) { // Update grid position pos[0] = idx[0] / step; pos[1] = idx[1] / step; // Compute coordinates UInt8ImageType::PointType geo; inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo); // Update geo if we want the corner or the center if (isEven)
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
{ geo[0] -= 0.5 * std::abs(inputImage->GetSpacing()[0]); geo[1] -= 0.5 * std::abs(inputImage->GetSpacing()[1]); } // Lambda call lambda(pos, geo); } } } } } /* * Allocate a std::vector of sample bundle */ std::vector<SampleBundle> AllocateSamples(unsigned int nbOfClasses = 2) { // Nb of samples (maximum) const UInt8ImageType::RegionType entireRegion = m_NoDataFilter->GetOutput()->GetLargestPossibleRegion(); const unsigned int maxNbOfCols = std::ceil(entireRegion.GetSize(0)/GetParameterInt("grid.step")) + 1; const unsigned int maxNbOfRows = std::ceil(entireRegion.GetSize(1)/GetParameterInt("grid.step")) + 1; unsigned int maxNbOfSamples = 1; maxNbOfSamples *= maxNbOfCols; maxNbOfSamples *= maxNbOfRows; // Nb of classes SampleBundle initSB(nbOfClasses); std::vector<SampleBundle> bundles(maxNbOfSamples, initSB); return bundles; } void SetBlackOrWhiteBundle(SampleBundle & bundle, unsigned int & count, const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { // Black or white int black = (pos[0] + pos[1]) % 2; bundle.GetModifiableSampleID() = count; bundle.GetModifiablePosition() = geo; bundle.GetModifiableGroup() = black; bundle.GetModifiableIndex() = pos; count++; } /* * Samples are placed at regular intervals with the same layout as a chessboard, * in two groups (A: black, B: white) */ void SampleChessboard() { std::vector<SampleBundle> bundles = AllocateSamples(); unsigned int count = 0; auto lambda = [this, &count, &bundles] (const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { SetBlackOrWhiteBundle(bundles[count], count, pos, geo); }; Apply(lambda); bundles.resize(count); // Export training/validation samples PopulateVectorData(bundles); }
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
void SetSplitBundle(SampleBundle & bundle, unsigned int & count, const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo, const std::vector<int> & groups) { bundle.GetModifiableGroup() = groups[count]; bundle.GetModifiableSampleID() = count; bundle.GetModifiablePosition() = geo; bundle.GetModifiableIndex() = pos; count++; } /* * Samples are split in training/validation/test groups */ void SampleSplit(float trp, float vp, float tp) { std::vector<SampleBundle> bundles = AllocateSamples(); // Populate groups unsigned int nbSamples = bundles.size(); float tot = (trp + vp + tp); std::vector<float> props = {trp, vp, tp}; std::vector<float> incs, counts; for (auto& prop: props) { if (prop > 0) { incs.push_back(tot / prop); counts.push_back(.0); } else { incs.push_back(.0); counts.push_back((float) nbSamples); } } std::vector<int> groups; for (unsigned int i = 0; i < nbSamples; i++) { // Find the group with the less samples auto it = std::min_element(std::begin(counts), std::end(counts)); auto idx = std::distance(std::begin(counts), it); assert (idx > 0); // Assign the group number, and update counts groups.push_back(idx); counts[idx] += incs[idx]; } if (GetParameterInt("strategy.split.random") > 0) { // Shuffle groups auto rng = std::default_random_engine {}; std::shuffle(std::begin(groups), std::end(groups), rng); } unsigned int count = 0; auto lambda = [this, &count, &bundles, &groups] (const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { SetSplitBundle(bundles[count], count, pos, geo, groups); }; Apply(lambda); bundles.resize(count); // Export training/validation samples PopulateVectorData(bundles); }
491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
void SampleBalanced() { // 1. Compute distribution of all samples otbAppLogINFO("Computing samples distribution..."); std::vector<SampleBundle> bundles = AllocateSamples(GetParameterInt("strategy.balanced.nclasses")); // Patch size UInt8ImageType::SizeType patchSize; patchSize.Fill(GetParameterInt("grid.psize")); unsigned int count = 0; auto lambda = [this, &bundles, &patchSize, &count] (const UInt8ImageType::IndexType & pos, const UInt8ImageType::PointType & geo) { // Update this sample distribution if (tf::UpdateDistributionFromPatch<UInt8ImageType>(GetParameterUInt8Image("strategy.balanced.labelimage"), geo, patchSize, bundles[count].GetModifiableDistribution())) { SetBlackOrWhiteBundle(bundles[count], count, pos, geo); } }; Apply(lambda); bundles.resize(count); otbAppLogINFO("Total number of candidates: " << count ); // 2. Seed = spatially random samples otbAppLogINFO("Spatial sampling proportion " << GetParameterFloat("strategy.balanced.sp")); const int samplingStep = static_cast<int>(1.0 / std::sqrt(GetParameterFloat("strategy.balanced.sp"))); otbAppLogINFO("Spatial sampling step " << samplingStep); float step = 0; std::vector<SampleBundle> seed(count); std::vector<SampleBundle> candidates(count); unsigned int seedCount = 0; unsigned int candidatesCount = 0; for (auto& d: bundles) { if (d.GetIndex()[0] % samplingStep + d.GetIndex()[1] % samplingStep == 0) { seed[seedCount] = d; seedCount++; } else { candidates[candidatesCount] = d; candidatesCount++; } step++; } seed.resize(seedCount); candidates.resize(candidatesCount); otbAppLogINFO("Spatial seed has " << seedCount << " samples"); unsigned int nbToRemove = static_cast<unsigned int>(seedCount - GetParameterFloat("strategy.balanced.sp") * count); otbAppLogINFO("Adjust spatial seed removing " << nbToRemove << " samples"); float removalRate = static_cast<float>(seedCount) / static_cast<float>(nbToRemove); float removalStep = 0; auto removeSamples = [&removalStep, &removalRate](SampleBundle & b) -> bool {
561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
(void) b; bool ret = false; if (removalStep >= removalRate) { removalStep = fmod(removalStep, removalRate); ret = true; } else ret = false; removalStep++; return ret;; }; auto iterator = std::remove_if(seed.begin(), seed.end(), removeSamples); seed.erase(iterator, seed.end()); otbAppLogINFO("Spatial seed size : " << seed.size()); // 3. Compute seed distribution const unsigned int nbOfClasses = GetParameterInt("strategy.balanced.nclasses"); DistributionType seedDist(nbOfClasses); for (auto& d: seed) seedDist.Update(d.GetDistribution()); otbAppLogINFO("Spatial seed distribution: " << seedDist.ToString()); // 4. Select other samples to feed the seed otbAppLogINFO("Balance seed candidates size: " << candidates.size()); // Sort by cos auto comparator = [&seedDist](const SampleBundle & a, const SampleBundle & b) -> bool{ return a.GetDistribution().Cosinus(seedDist) > b.GetDistribution().Cosinus(seedDist); }; sort(candidates.begin(), candidates.end(), comparator); DistributionType idealDist(nbOfClasses, 1.0 / std::sqrt(static_cast<float>(nbOfClasses))); float minCos = 0; unsigned int samplesAdded = 0; seed.resize(seed.size()+candidates.size(), SampleBundle(nbOfClasses)); while(candidates.size() > 0) { // Get the less correlated sample SampleBundle candidate = candidates.back(); // Update distribution seedDist.Update(candidate.GetDistribution()); // Compute cos of the updated distribution float idealCos = seedDist.Cosinus(idealDist); if (idealCos > minCos) { minCos = idealCos; seed[seedCount] = candidate; seedCount++; candidates.pop_back(); samplesAdded++; } else { break; } } seed.resize(seedCount); otbAppLogINFO("Final samples number: " << seed.size() << " (" << samplesAdded << " samples added)"); otbAppLogINFO("Final samples distribution: " << seedDist.ToString()); // 5. Export training/validation samples PopulateVectorData(seed);
631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700
} void PopulateVectorData(const std::vector<SampleBundle> & samples) { // Get data tree DataTreeType::Pointer treeTrain = m_OutVectorDataTrain->GetDataTree(); DataTreeType::Pointer treeValid = m_OutVectorDataValid->GetDataTree(); DataTreeType::Pointer treeTest = m_OutVectorDataTest->GetDataTree(); DataNodePointer rootTrain = treeTrain->GetRoot()->Get(); DataNodePointer rootValid = treeValid->GetRoot()->Get(); DataNodePointer rootTest = treeTest->GetRoot()->Get(); DataNodePointer documentTrain = DataNodeType::New(); DataNodePointer documentValid = DataNodeType::New(); DataNodePointer documentTest = DataNodeType::New(); documentTrain->SetNodeType(DOCUMENT); documentValid->SetNodeType(DOCUMENT); documentTest->SetNodeType(DOCUMENT); treeTrain->Add(documentTrain, rootTrain); treeValid->Add(documentValid, rootValid); treeTest->Add(documentTest, rootTest); unsigned int id = 0; for (const auto& sample: samples) { // Add point to the VectorData tree DataNodePointer newDataNode = DataNodeType::New(); newDataNode->SetPoint(sample.GetPosition()); newDataNode->SetFieldAsInt("id", id); id++; // select this sample if (sample.GetGroup() == 0) { // Train treeTrain->Add(newDataNode, documentTrain); } else if (sample.GetGroup() == 1) { // Valid treeValid->Add(newDataNode, documentValid); } else if (sample.GetGroup() == 2) { // Test treeTest->Add(newDataNode, documentTest); } } } void DoExecute() { otbAppLogINFO("Grid step : " << this->GetParameterInt("grid.step")); otbAppLogINFO("Patch size : " << this->GetParameterInt("grid.psize")); // Compute no-data mask m_NoDataFilter = IsNoDataFilterType::New(); float nodataValue = std::numeric_limits<float>::quiet_NaN(); if (HasValue("nodata")) { nodataValue = GetParameterFloat("nodata"); } m_NoDataFilter->GetFunctor().SetNoDataValue(nodataValue); m_NoDataFilter->SetInput(GetParameterFloatVectorImage("in")); m_NoDataFilter->UpdateOutputInformation(); UInt8ImageType::Pointer src = m_NoDataFilter->GetOutput(); // If mask available, use it if (HasValue("mask")) {
701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
if (GetParameterUInt8Image("mask")->GetLargestPossibleRegion().GetSize() != GetParameterFloatVectorImage("in")->GetLargestPossibleRegion().GetSize()) otbAppLogFATAL("Mask must have the same size as the input image!"); m_MaskImageFilter = MaskImageFilterType::New(); m_MaskImageFilter->SetInput(m_NoDataFilter->GetOutput()); m_MaskImageFilter->SetMaskImage(GetParameterUInt8Image("mask")); m_MaskImageFilter->UpdateOutputInformation(); src = m_MaskImageFilter->GetOutput(); } // Padding 1 pixel UInt8ImageType::SizeType size = src->GetLargestPossibleRegion().GetSize(); size[0] += 2; size[1] += 2; UInt8ImageType::SpacingType spacing = src->GetSignedSpacing(); UInt8ImageType::PointType origin = src->GetOrigin(); origin[0] -= spacing[0]; origin[1] -= spacing[1]; m_PadFilter = PadFilterType::New(); NNInterpolatorType::Pointer nnInterpolator = NNInterpolatorType::New(); m_PadFilter->SetInterpolator(nnInterpolator); m_PadFilter->SetInput( src ); m_PadFilter->SetOutputOrigin(origin); m_PadFilter->SetOutputSpacing(spacing); m_PadFilter->SetOutputSize(size); m_PadFilter->SetEdgePaddingValue( 0 ); m_PadFilter->UpdateOutputInformation(); // Morpho m_Radius[0] = this->GetParameterInt("grid.psize") / 2; m_Radius[1] = this->GetParameterInt("grid.psize") / 2; StructuringType se = StructuringType::Box(m_Radius); m_MorphoFilter = MorphoFilterType::New(); m_MorphoFilter->SetKernel(se); m_MorphoFilter->SetInput(m_PadFilter->GetOutput()); m_MorphoFilter->SetForegroundValue(1); m_MorphoFilter->SetBackgroundValue(0); m_MorphoFilter->UpdateOutputInformation(); // Prepare output vector data m_OutVectorDataTrain = VectorDataType::New(); m_OutVectorDataValid = VectorDataType::New(); m_OutVectorDataTest = VectorDataType::New(); m_OutVectorDataTrain->SetProjectionRef(m_MorphoFilter->GetOutput()->GetProjectionRef()); m_OutVectorDataValid->SetProjectionRef(m_MorphoFilter->GetOutput()->GetProjectionRef()); m_OutVectorDataTest->SetProjectionRef(m_MorphoFilter->GetOutput()->GetProjectionRef()); if (GetParameterAsString("strategy") == "chessboard") { otbAppLogINFO("Sampling at regular interval in space (\"Chessboard\" like)"); SampleChessboard(); if (HasValue("outtest")) { otbAppLogWARNING("The \"outtest\" parameter is unused with the \"chessboard\" sampling strategy.") } } else if (GetParameterAsString("strategy") == "balanced") { otbAppLogINFO("Sampling with balancing strategy"); SampleBalanced(); } else if (GetParameterAsString("strategy") == "split") { otbAppLogINFO("Sampling with split strategy (Train/Validation/test)"); float vp = .0; float tp = .0; if (HasValue("outvalid"))
771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827
{ vp = GetParameterFloat("strategy.split.validprop"); } if (HasValue("outtest")) { tp = GetParameterFloat("strategy.split.testprop"); } SampleSplit(GetParameterFloat("strategy.split.trainprop"), vp, tp); } else if (GetParameterAsString("strategy") == "all") { otbAppLogINFO("Sampling all locations (only \"outtrain\" output parameter will be used"); SampleSplit(1.0, .0, .0); if (HasValue("outtest") || HasValue("outvalid")) { otbAppLogWARNING("The \"outvalid\" and \"outtest\" parameters are unused with the \"all\" sampling strategy.") } } otbAppLogINFO( "Writing output samples positions"); SetParameterOutputVectorData("outtrain", m_OutVectorDataTrain); if (HasValue("outvalid") && GetParameterAsString("strategy") != "all") { SetParameterOutputVectorData("outvalid", m_OutVectorDataValid); } if (HasValue("outtest") && GetParameterAsString("strategy") == "split") { SetParameterOutputVectorData("outtest", m_OutVectorDataTest); } } void DoUpdateParameters() { } private: RadiusType m_Radius; IsNoDataFilterType::Pointer m_NoDataFilter; PadFilterType::Pointer m_PadFilter; MorphoFilterType::Pointer m_MorphoFilter; VectorDataType::Pointer m_OutVectorDataTrain; VectorDataType::Pointer m_OutVectorDataValid; VectorDataType::Pointer m_OutVectorDataTest; MaskImageFilterType::Pointer m_MaskImageFilter; }; // end of class } // end namespace wrapper } // end namespace otb OTB_APPLICATION_EXPORT( otb::Wrapper::PatchesSelection )