// otbLabelImageSampleSelection.cxx
/*=========================================================================
     Copyright (c) 2018-2019 IRSTEA
     Copyright (c) 2020-2021 INRAE
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "itkFixedArray.h"
#include "itkObjectFactory.h"
#include "otbWrapperApplicationFactory.h"

// Application engine
#include "otbStandardFilterWatcher.h"
#include "itkFixedArray.h"
#include "vnl/vnl_vector.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionConstIterator.h"

// image utils
#include "otbTensorflowCommon.h"

// Standard library
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <limits>
namespace otb
namespace Wrapper
class LabelImageSampleSelection : public Application
public:
  /** Standard class typedefs. */
  typedef LabelImageSampleSelection           Self;
  typedef Application                         Superclass;
  typedef itk::SmartPointer<Self>             Pointer;
  typedef itk::SmartPointer<const Self>       ConstPointer;
  /** Standard macro */
  itkNewMacro(Self);
  itkTypeMacro(LabelImageSampleSelection, Application);
  /** Vector data typedefs */
  typedef VectorDataType::DataTreeType                 DataTreeType;
  typedef itk::PreOrderTreeIterator<DataTreeType>      TreeIteratorType;
  typedef VectorDataType::DataNodeType                 DataNodeType;
  typedef DataNodeType::Pointer                        DataNodePointer;
  /** typedefs */
  typedef Int16ImageType                               LabelImageType;
  typedef unsigned int                                 IndexValueType;
  void DoUpdateParameters()
   * Display the percentage
  void ShowProgress(unsigned int count, unsigned int total, unsigned int step = 1000)
    if (count % step == 0)
      std::cout << std::setprecision(3) << "\r" << (100.0 * count / (float) total) << "%      " << std::flush;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
void ShowProgressDone() { std::cout << "\rDone " << std::flush; std::cout << std::endl; } void DoInit() { // Documentation SetName("LabelImageSampleSelection"); SetDescription("This application extracts points from an input label image. " "This application is like \"SampleSelection\", but uses an input label " "image, rather than an input vector data."); SetDocLongDescription("This application produces a vector data containing " "a set of points centered on the pixels of the input label image. " "The user can control the number of points. The default strategy consists " "in producing the same number of points in each class. If one class has a " "smaller number of points than requested, this one is adjusted."); SetDocAuthors("Remi Cresson"); // Input terrain truth AddParameter(ParameterType_InputImage, "inref", "input terrain truth"); // Strategy AddParameter(ParameterType_Choice, "strategy", "Sampling strategy"); AddChoice("strategy.constant","Set the same samples counts for all classes"); SetParameterDescription("strategy.constant","Set the same samples counts for all classes"); AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes"); SetParameterDescription("strategy.constant.nb", "Number of samples for all classes"); SetMinimumParameterIntValue("strategy.constant.nb",1); SetDefaultParameterInt("strategy.constant.nb",1000); AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions."); SetParameterDescription("strategy.total","Set the total number of samples to generate, and use class proportions."); AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate"); SetParameterDescription("strategy.total.v","The number of samples to generate"); SetMinimumParameterIntValue("strategy.total.v",1); SetDefaultParameterInt("strategy.total.v",1000); AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); 
SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); AddChoice("strategy.all","Take all samples"); SetParameterDescription("strategy.all","Take all samples"); // Default strategy : smallest SetParameterString("strategy","constant"); // Input no-data value AddParameter(ParameterType_Int, "nodata", "nodata value"); MandatoryOn ("nodata"); SetDefaultParameterInt ("nodata", -1); // Padding AddParameter(ParameterType_Int, "pad", "padding, in pixels"); SetDefaultParameterInt ("pad", 0); MandatoryOff ("pad"); // Output points AddParameter(ParameterType_OutputVectorData, "outvec", "output set of points"); // Some example SetDocExampleParameterValue("inref", "rasterized_terrain_truth.tif"); SetDocExampleParameterValue("outvec", "terrain_truth_points_sel.sqlite");
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
AddRAMParameter(); } void DoExecute() { // Count the number of pixels in each class const LabelImageType::InternalPixelType MAX_NB_OF_CLASSES = itk::NumericTraits<LabelImageType::InternalPixelType>::max();; LabelImageType::InternalPixelType class_begin = MAX_NB_OF_CLASSES; LabelImageType::InternalPixelType class_end = 0; vnl_vector<IndexValueType> tmp_number_of_samples(MAX_NB_OF_CLASSES, 0); otbAppLogINFO("Computing number of pixels in each class"); // Explicit streaming over the input target image, based on the RAM parameter typedef otb::RAMDrivenStrippedStreamingManager<FloatVectorImageType> StreamingManagerType; StreamingManagerType::Pointer m_StreamingManager = StreamingManagerType::New(); m_StreamingManager->SetAvailableRAMInMB(GetParameterInt("ram")); // We pad the image, if this is requested by the user LabelImageType::Pointer inputImage = GetParameterInt16Image("inref"); LabelImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion(); entireRegion.ShrinkByRadius(GetParameterInt("pad")); m_StreamingManager->PrepareStreaming(inputImage, entireRegion ); // Get nodata value const LabelImageType::InternalPixelType nodata = GetParameterInt("nodata"); // First iteration to count the objects in each class int m_NumberOfDivisions = m_StreamingManager->GetNumberOfSplits(); for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++) { LabelImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<LabelImageType>(inputImage, streamRegion); itk::ImageRegionConstIterator<LabelImageType> inIt (inputImage, streamRegion); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { LabelImageType::InternalPixelType pixVal = inIt.Get(); if (pixVal != nodata) { // Update min and max value if (pixVal > class_end) class_end = pixVal; if (pixVal < class_begin) class_begin = pixVal; tmp_number_of_samples(pixVal)++; } } ShowProgress(m_CurrentDivision, m_NumberOfDivisions, 1); } 
ShowProgressDone(); // Number of classes const LabelImageType::InternalPixelType number_of_classes = class_end - class_begin + 1; // Number of samples in each class (counted) vnl_vector<IndexValueType> number_of_samples = tmp_number_of_samples.extract(number_of_classes, class_begin); // Number of samples in each class (target) vnl_vector<IndexValueType> target_number_of_samples(number_of_classes, 0); otbAppLogINFO( "Number of classes: " << number_of_classes << " starting from " << class_begin << " to " << class_end << " (no-data is " << nodata << ")"); otbAppLogINFO( "Number of pixels in each class: " << number_of_samples );
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
// Check the smallest number of samples amongst classes IndexValueType min_elem_in_class = itk::NumericTraits<IndexValueType>::max(); for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) min_elem_in_class = std::min(min_elem_in_class, number_of_samples[classIdx]); // If one class is empty, throw an error if (min_elem_in_class == 0) { otbAppLogFATAL("There is at least one class with no sample!") } // Sampling step for each classes vnl_vector<IndexValueType> step_for_class(number_of_classes, 0); // Compute the sampling step for each classes, depending on the chosen strategy switch (this->GetParameterInt("strategy")) { // constant case 0: { // Set the target number of samples in each class target_number_of_samples.fill(GetParameterInt("strategy.constant.nb")); // re adjust the number of samples to select in each class if (min_elem_in_class < target_number_of_samples[0]) { otbAppLogWARNING("Smallest class has " << min_elem_in_class << " samples but a number of " << target_number_of_samples[0] << " is given. Using " << min_elem_in_class); target_number_of_samples.fill( min_elem_in_class ); } // Compute the sampling step for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; } break; // total case 1: { // Compute the sampling step IndexValueType step = number_of_samples.sum() / this->GetParameterInt("strategy.total.v"); if (step == 0) { otbAppLogWARNING("The number of samples available is smaller than the required number of samples. 
" << "Setting sampling step to 1."); step = 1; } step_for_class.fill(step); // Compute the target number of samples for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) target_number_of_samples[classIdx] = number_of_samples[classIdx] / step; } break; // smallest case 2: { // Set the target number of samples to the smallest class target_number_of_samples.fill( min_elem_in_class ); // Compute the sampling step for (LabelImageType::InternalPixelType classIdx = 0 ; classIdx < number_of_classes ; classIdx++) step_for_class[classIdx] = number_of_samples[classIdx] / target_number_of_samples[classIdx]; }
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
break; // All case 3: { // Easy step_for_class.fill(1); target_number_of_samples = number_of_samples; } break; default: otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy")); break; } // Print quick summary otbAppLogINFO("Sampling summary:"); otbAppLogINFO("\tClass\tStep\tTot"); for (LabelImageType::InternalPixelType i = 0 ; i < number_of_classes ; i++) { vnl_vector<int> tmp (3,0); tmp[0] = i + class_begin; tmp[1] = step_for_class[i]; tmp[2] = target_number_of_samples[i]; otbAppLogINFO("\t" << tmp); } // Create a new vector data // TODO: how to pre-allocate the datatree? m_OutVectorData = VectorDataType::New(); DataTreeType::Pointer tree = m_OutVectorData->GetDataTree(); DataNodePointer root = tree->GetRoot()->Get(); DataNodePointer document = DataNodeType::New(); document->SetNodeType(DOCUMENT); tree->Add(document, root); // Duno if this makes sense? m_OutVectorData->SetProjectionRef(inputImage->GetProjectionRef()); m_OutVectorData->SetOrigin(inputImage->GetOrigin()); m_OutVectorData->SetSpacing(inputImage->GetSpacing()); // Second iteration, to prepare the samples vnl_vector<IndexValueType> sampledCount(number_of_classes, 0); vnl_vector<IndexValueType> iteratorCount(number_of_classes, 0); IndexValueType n_tot = 0; const IndexValueType target_n_tot = target_number_of_samples.sum(); for (int m_CurrentDivision = 0; m_CurrentDivision < m_NumberOfDivisions; m_CurrentDivision++) { LabelImageType::RegionType streamRegion = m_StreamingManager->GetSplit(m_CurrentDivision); tf::PropagateRequestedRegion<LabelImageType>(inputImage, streamRegion); itk::ImageRegionConstIterator<LabelImageType> inIt (inputImage, streamRegion); for (inIt.GoToBegin() ; !inIt.IsAtEnd() ; ++inIt) { LabelImageType::InternalPixelType classVal = inIt.Get(); if (classVal != nodata) { classVal -= class_begin; // Update the current position iteratorCount[classVal]++; // Every Xi samples (Xi is the step for class i) if (iteratorCount[classVal] % ((int) step_for_class[classVal]) 
== 0 && sampledCount[classVal] < target_number_of_samples[classVal]) { // Add this sample sampledCount[classVal]++; n_tot++;
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
ShowProgress(n_tot, target_n_tot); // Create a point LabelImageType::PointType geo; inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo); DataNodeType::PointType point; point[0] = geo[0]; point[1] = geo[1]; // Add point to the VectorData tree DataNodePointer newDataNode = DataNodeType::New(); newDataNode->SetPoint(point); newDataNode->SetFieldAsInt("class", static_cast<int>(classVal)); tree->Add(newDataNode, document); } // sample this one } } // next pixel } // next streaming region ShowProgressDone(); otbAppLogINFO( "Number of samples in each class: " << sampledCount ); otbAppLogINFO( "Writing output vector data"); SetParameterOutputVectorData("outvec", m_OutVectorData); } private: VectorDataType::Pointer m_OutVectorData; }; // end of class } // end namespace wrapper } // end namespace otb OTB_APPLICATION_EXPORT(otb::Wrapper::LabelImageSampleSelection)