From 2ed36bce71991f3ab3458d753c0bd9ce30eeef8e Mon Sep 17 00:00:00 2001
From: remicres <remi.cresson@teledetection.fr>
Date: Wed, 24 Aug 2016 12:25:11 +0000
Subject: [PATCH] ADD: criterion can be chosen amongst bs/ed/fls

---
 app/otbLSGRM.cxx | 182 +++++++++++++++++++++++++++++++----------------
 1 file changed, 122 insertions(+), 60 deletions(-)

diff --git a/app/otbLSGRM.cxx b/app/otbLSGRM.cxx
index a6618e7..2844ba7 100644
--- a/app/otbLSGRM.cxx
+++ b/app/otbLSGRM.cxx
@@ -12,7 +12,9 @@
 
 // GRM
 #include <iostream>
-#include "lsrmBaatzSegmenter.h"
+#include "lsgrmBaatzSegmenter.h"
+#include "lsgrmSpringSegmenter.h"
+#include "lsgrmFullLambdaScheduleSegmenter.h"
 #include "lsgrmController.h"
 
 // system tools
@@ -28,7 +30,7 @@ class LSGRM : public Application
 {
 public:
   /** Standard class typedefs. */
-  typedef LSGRM                        Self;
+  typedef LSGRM                         Self;
   typedef Application                   Superclass;
   typedef itk::SmartPointer<Self>       Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
@@ -37,6 +39,15 @@ public:
   itkNewMacro(Self);
   itkTypeMacro(LSGRM, Application);
 
+  /** Useful typedefs */
+  typedef otb::VectorImage<float, 2>                    ImageType;
+  typedef lsgrm::BaatzSegmenter<ImageType>              BaatzSegmenterType;
+  typedef lsgrm::SpringSegmenter<ImageType>             SpringSegmenterType;
+  typedef lsgrm::FullLambdaScheduleSegmenter<ImageType> FLSSegmenterType;
+  typedef lsgrm::Controller<BaatzSegmenterType>         BaatzControllerType;
+  typedef lsgrm::Controller<SpringSegmenterType>        SpringControllerType;
+  typedef lsgrm::Controller<FLSSegmenterType>           FLSControllerType;
+
 private:
 
   /* Tiling mode choice */
@@ -47,6 +58,14 @@ private:
     TILING_NONE
   };
 
+  /* Criterion choice */
+  enum Criterion
+  {
+    CRITERION_BAATZ,
+    CRITERION_SPRING,
+    CRITERION_FLS
+  };
+
   void DoInit()
   {
     SetName("GenericRegionMerging");
@@ -54,35 +73,33 @@ private:
         "(GRM) and provides currently 3 homogeneity criteria: Euclidean Distance, "
         "Full Lambda Schedule and Baatz & Schape criterion.");
 
+    // Input and Output images
     AddParameter(ParameterType_InputImage, "in", "Input Image");
     AddParameter(ParameterType_OutputImage, "out", "Ouput Label Image");
     SetDefaultOutputPixelType("out", ImagePixelType_uint32);
 
-    //  AddParameter(ParameterType_Choice, "criterion", "Homogeneity criterion to use");
-    //  AddChoice("criterion.bs", "Baatz & Schape");
-    //  AddChoice("criterion.ed", "Euclidean Distance");
-    //  AddChoice("criterion.fls", "Full Lambda Schedule");
+    // Criterion choice
+    AddParameter(ParameterType_Choice, "criterion", "Homogeneity criterion to use");
+    AddChoice("criterion.bs", "Baatz & Schape");
+    AddChoice("criterion.ed", "Euclidean Distance");
+    AddChoice("criterion.fls", "Full Lambda Schedule");
 
+    // Generic parameters
     AddParameter(ParameterType_Float, "threshold", "Threshold for the criterion");
-
     AddParameter(ParameterType_Int, "niter", "Maximum number of iterations");
     SetDefaultParameterInt("niter", 75);
     MandatoryOff("niter");
 
-    AddParameter(ParameterType_Int, "speed", "Activate it to boost the segmentation speed");
-    SetDefaultParameterInt("speed", 0);
-    MandatoryOff("speed");
-
-    // For Baatz & Schape
-    AddParameter(ParameterType_Float, "cw", "Weight for the spectral homogeneity");
-    SetDefaultParameterFloat("cw", 0.5);
-    MandatoryOff("cw");
-    AddParameter(ParameterType_Float, "sw", "Weight for the spatial homogeneity");
-    SetDefaultParameterFloat("sw", 0.5);
-    MandatoryOff("sw");
+    // Specific parameters for Baatz & Schape
+    AddParameter(ParameterType_Float, "criterion.bs.cw", "Weight for the spectral homogeneity");
+    SetDefaultParameterFloat("criterion.bs.cw", 0.5);
+    MandatoryOff("criterion.bs.cw");
+    AddParameter(ParameterType_Float, "criterion.bs.sw", "Weight for the spatial homogeneity");
+    SetDefaultParameterFloat("criterion.bs.sw", 0.5);
+    MandatoryOff("criterion.bs.sw");
 
     // For large scale
-    AddParameter(ParameterType_Directory, "tmpdir", "temporary directory for tiles and graphs");
+    AddParameter(ParameterType_Directory, "tmpdir", "Directory for temporary files");
     AddParameter(ParameterType_Choice, "tiling", "Tiling layout for the large scale segmentation");
     AddChoice("tiling.auto", "Automatic tiling layout");
     AddChoice("tiling.user", "User tiling layout");
@@ -97,6 +114,7 @@ private:
   {
   }
 
+
   /*
    * Return a prefix for temporary files
    */
@@ -122,66 +140,112 @@ private:
     return prefix;
   }
 
-  void DoExecute()
+  /*
+   * This function sets the generic parameters of a controller and runs the segmentation
+   */
+  template<class TController>
+  UInt32ImageType::Pointer SetGenericParametersAndRunSegmentation(typename TController::Pointer& controller)
   {
-    /*
-        To add:
-        internal memory available
-        If we have to do the image division
-        if we have to clean up the directory
-        the output directory in case the global graph cannot fit in memory
+    // Set input image
+    controller->SetInputImage(GetParameterFloatVectorImage("in"));
 
-     */
-
-    using ImageType = otb::VectorImage<float, 2>;
-    using SegmenterType = lsgrm::BaatzSegmenter<ImageType>;
-    using ControllerType = lsgrm::Controller<SegmenterType>;
+    // Set threshold
+    float thres = GetParameterFloat("threshold");
+    controller->SetThreshold(thres*thres);
 
-    ImageType::Pointer inputImage = GetParameterFloatVectorImage("in");
+    // Set number of iterations
+    controller->SetNumberOfIterations(GetParameterInt("niter"));
 
-    ControllerType controller;
-    controller.SetInputImage(inputImage);
-    controller.SetTemporaryFilesPrefix(this->GetTemporaryFilesPrefix());
+    // Set temporary files prefix
+    controller->SetTemporaryFilesPrefix(this->GetTemporaryFilesPrefix());
 
-    // Tiling mode
+    // Switch tiling mode
     int inputTilingMode = GetParameterInt("tiling");
     if (inputTilingMode == TILING_AUTO)
       {
-      controller.SetTilingModeAuto();
+      // Automatic mode
+      controller->SetTilingModeAuto();
       }
     else if (inputTilingMode == TILING_USER)
       {
-      // User
-      controller.SetTilingModeUser();
-      controller.SetTileWidth(GetParameterInt("tiling.user.sizex"));
-      controller.SetTileHeight(GetParameterInt("tiling.user.sizey"));
-      controller.SetNumberOfFirstIterations(GetParameterInt("tiling.user.nfirstiter"));
-      controller.SetInternalMemoryAvailable(GetParameterInt("tiling.user.memory"));
+      // User mode
+      controller->SetTilingModeUser();
+      controller->SetTileWidth(GetParameterInt("tiling.user.sizex"));
+      controller->SetTileHeight(GetParameterInt("tiling.user.sizey"));
+      controller->SetNumberOfFirstIterations(GetParameterInt("tiling.user.nfirstiter"));
+      controller->SetInternalMemoryAvailable(GetParameterInt("tiling.user.memory"));
       }
     else if (inputTilingMode == TILING_NONE)
       {
-      controller.SetTilingModeNone();
+      // None mode
+      controller->SetTilingModeNone();
       }
     else
       {
-      otbAppLogFATAL("Unknow input tiling mode!");
+      otbAppLogFATAL("Unknow tiling mode!");
       }
 
-    // Specific parameters
-    grm::BaatzParam params;
-    params.m_SpectralWeight = GetParameterFloat("cw");
-    params.m_ShapeWeight = GetParameterFloat("sw");
-    controller.SetSpecificParameters(params);
-    float thres = GetParameterFloat("threshold");
-    controller.SetThreshold(thres*thres);
-    controller.SetNumberOfIterations(GetParameterInt("niter"));
-
     // Run the segmentation
-    controller.SetDebug(true);
-    controller.RunSegmentation();
+    controller->RunSegmentation();
 
-    // Output images
-    UInt32ImageType::Pointer labelImage = controller.GetLabeledClusteredOutput();;
+    // Get temporary files list
+    m_TemporaryFilesList = controller->GetTemporaryFilesList();
+
+    // Return the label image
+    return controller->GetLabeledClusteredOutput();
+  }
+
+
+
+  void DoExecute()
+  {
+    /*
+        To add:
+        the output directory in case the global graph cannot fit in memory
+     */
+
+    // Input image
+    ImageType::Pointer inputImage = GetParameterFloatVectorImage("in");
+
+    // Output image
+    UInt32ImageType::Pointer labelImage = UInt32ImageType::New();
+
+    // Switch criterion
+    int inputCriterion = GetParameterInt("criterion");
+    if (inputCriterion == CRITERION_BAATZ)
+      {
+      // Baatz controller
+      BaatzControllerType::Pointer baatzController = BaatzControllerType::New();
+
+      // Specific parameters
+      grm::BaatzParam params;
+      params.m_SpectralWeight = GetParameterFloat("criterion.bs.cw");
+      params.m_ShapeWeight = GetParameterFloat("criterion.bs.sw");
+      baatzController->SetSpecificParameters(params);
+
+      // Run segmentation
+      labelImage = SetGenericParametersAndRunSegmentation<BaatzControllerType>(baatzController);
+      }
+    else if (inputCriterion == CRITERION_SPRING)
+      {
+      // Spring controller
+      SpringControllerType::Pointer springController = SpringControllerType::New();
+
+      // Run segmentation
+      labelImage = SetGenericParametersAndRunSegmentation<SpringControllerType>(springController);
+      }
+    else if (inputCriterion == CRITERION_FLS)
+      {
+      // Full Lambda Schedule (FLS) controller
+      FLSControllerType::Pointer flsController = FLSControllerType::New();
+
+      // Run segmentation
+      labelImage = SetGenericParametersAndRunSegmentation<FLSControllerType>(flsController);
+      }
+    else
+      {
+      otbAppLogFATAL("Unknow criterion!")
+      }
 
     // Set output image projection, origin and spacing for labelImage
     labelImage->SetProjectionRef(inputImage->GetProjectionRef());
@@ -189,8 +253,6 @@ private:
     labelImage->SetSpacing(inputImage->GetSpacing());
     SetParameterOutputImage<UInt32ImageType>("out", labelImage);
 
-    // Get temporary files list
-    m_TemporaryFilesList = controller.GetTemporaryFilesList();
 
   }
 
-- 
GitLab