// otbImageDimensionalityReduction.cxx
/*=========================================================================
  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$
  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "itkUnaryFunctorImageFilter.h"
#include "otbChangeLabelImageFilter.h"
#include "otbStandardWriterWatcher.h"
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleVectorImageFilter.h"
#include "ImageDimensionalityReductionFilter.h"
#include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h"
#include "DimensionalityReductionModelFactory.h"
namespace otb
namespace Functor
/**
 * simple affine function : y = ax+b
template<class TInput, class TOutput>
class AffineFunctor
public:
  typedef double InternalType;
  // constructor
  AffineFunctor() : m_A(1.0),m_B(0.0) {}
  // destructor
  virtual ~AffineFunctor() {}
  void SetA(InternalType a)
    m_A = a;
  void SetB(InternalType b)
    m_B = b;
  inline TOutput operator()(const TInput & x) const
    return static_cast<TOutput>( static_cast<InternalType>(x)*m_A + m_B);
private:
  InternalType m_A;
  InternalType m_B;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
namespace Wrapper { class ImageDimensionalityReduction : public Application { public: /** Standard class typedefs. */ typedef ImageDimensionalityReduction Self; typedef Application Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self); itkTypeMacro(ImageDimensionalityReduction, otb::Application); /** Filters typedef */ typedef UInt8ImageType MaskImageType; typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType; typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader; typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType; typedef itk::UnaryFunctorImageFilter< FloatImageType, FloatImageType, otb::Functor::AffineFunctor<float,float> > OutputRescalerType; typedef otb::ImageDimensionalityReductionFilter<FloatVectorImageType, FloatVectorImageType, MaskImageType> DimensionalityReductionFilterType; typedef DimensionalityReductionFilterType::Pointer DimensionalityReductionFilterPointerType; typedef DimensionalityReductionFilterType::ModelType ModelType; typedef ModelType::Pointer ModelPointerType; typedef DimensionalityReductionFilterType::ValueType ValueType; typedef DimensionalityReductionFilterType::LabelType LabelType; typedef otb::DimensionalityReductionModelFactory<ValueType, LabelType> DimensionalityReductionModelFactoryType; protected: ~ImageDimensionalityReduction() ITK_OVERRIDE { DimensionalityReductionModelFactoryType::CleanFactories(); } private: void DoInit() ITK_OVERRIDE { SetName("DimensionalityReduction"); SetDescription("Performs dimensionality reduction of the input image according to a dimensionality reduction model file."); // Documentation SetDocName("DimensionalityReduction"); SetDocLongDescription("This application reduces the dimension of an input" " image, based on a machine learning model file produced by" " the TrainDimensionalityReduction application. 
Pixels of the " "output image will contain the reduced values from" "the model. The input pixels" " can be optionally centered and reduced according " "to the statistics file produced by the " "ComputeImagesStatistics application. "); SetDocLimitations("The input image must contain the feature bands used for" " the model training. " "If a statistics file was used during training by the " "Training application, it is mandatory to use the same " "statistics file for reduction."); SetDocAuthors("OTB-Team"); SetDocSeeAlso("TrainDimensionalityReduction, ComputeImagesStatistics"); AddDocTag(Tags::Learning); AddParameter(ParameterType_InputImage, "in", "Input Image"); SetParameterDescription( "in", "The input image to predict.");
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
AddParameter(ParameterType_InputImage, "mask", "Input Mask"); SetParameterDescription( "mask", "The mask allow restricting " "classification of the input image to the area where mask pixel values " "are greater than 0."); MandatoryOff("mask"); AddParameter(ParameterType_InputFilename, "model", "Model file"); SetParameterDescription("model", "A dimensionality reduction model file (produced by " "TrainRegression application)."); AddParameter(ParameterType_InputFilename, "imstat", "Statistics file"); SetParameterDescription("imstat", "A XML file containing mean and standard" " deviation to center and reduce samples before prediction " "(produced by ComputeImagesStatistics application). If this file contains" "one more bands than the sample size, the last stat of last band will be" "applied to expand the output predicted value"); MandatoryOff("imstat"); AddParameter(ParameterType_OutputImage, "out", "Output Image"); SetParameterDescription( "out", "Output image containing reduced values"); AddRAMParameter(); // Doc example parameter settings SetDocExampleParameterValue("in", "QB_1_ortho.tif"); SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml"); SetDocExampleParameterValue("model", "clsvmModelQB1.model"); SetDocExampleParameterValue("out", "ReducedImageQB1.tif"); } void DoUpdateParameters() ITK_OVERRIDE { // Nothing to do here : all parameters are independent } void DoExecute() ITK_OVERRIDE { // Load input image FloatVectorImageType::Pointer inImage = GetParameterImage("in"); inImage->UpdateOutputInformation(); unsigned int nbFeatures = inImage->GetNumberOfComponentsPerPixel(); // Load DR model using a factory otbAppLogINFO("Loading model"); m_Model = DimensionalityReductionModelFactoryType::CreateDimensionalityReductionModel(GetParameterString("model"), DimensionalityReductionModelFactoryType::ReadMode); if (m_Model.IsNull()) { otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); } 
m_Model->Load(GetParameterString("model")); otbAppLogINFO("Model loaded"); // Classify m_ClassificationFilter = DimensionalityReductionFilterType::New(); m_ClassificationFilter->SetModel(m_Model); FloatVectorImageType::Pointer outputImage = m_ClassificationFilter->GetOutput(); // Normalize input image if asked if(IsParameterEnabled("imstat") ) { otbAppLogINFO("Input image normalization activated."); // Normalize input image (optional) StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); MeasurementType meanMeasurementVector; MeasurementType stddevMeasurementVector;
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
m_Rescaler = RescalerType::New(); // Load input image statistics statisticsReader->SetFileName(GetParameterString("imstat")); meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); otbAppLogINFO( "mean used: " << meanMeasurementVector ); otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector ); if (meanMeasurementVector.Size() != nbFeatures) { otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size()); } // Rescale vector image m_Rescaler->SetScale(stddevMeasurementVector); m_Rescaler->SetShift(meanMeasurementVector); m_Rescaler->SetInput(inImage); m_ClassificationFilter->SetInput(m_Rescaler->GetOutput()); } else { otbAppLogINFO("Input image normalization deactivated."); m_ClassificationFilter->SetInput(inImage); } if(IsParameterEnabled("mask")) { otbAppLogINFO("Using input mask"); // Load mask image and cast into LabeledImageType MaskImageType::Pointer inMask = GetParameterUInt8Image("mask"); m_ClassificationFilter->SetInputMask(inMask); } SetParameterOutputImage<FloatVectorImageType>("out", outputImage); } DimensionalityReductionFilterType::Pointer m_ClassificationFilter; ModelPointerType m_Model; RescalerType::Pointer m_Rescaler; OutputRescalerType::Pointer m_OutRescaler; }; } } OTB_APPLICATION_EXPORT(otb::Wrapper::ImageDimensionalityReduction)