From 6c8abad77707954334900adf4c84328f794eaa4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr>
Date: Mon, 29 Apr 2019 18:01:38 +0200
Subject: [PATCH] ENH: template specialization, use GetFieldAsInteger instead
 of double is the classification case

---
 .../include/otbTrainVectorBase.h              |  6 ++++++
 .../include/otbTrainVectorBase.hxx            | 19 ++++++++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h
index 477b798d58..52675e935f 100644
--- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h
+++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h
@@ -186,6 +186,12 @@ protected:
   void DoUpdateParameters() override;
   void DoExecute() override;
 
+private:
+  /**
+   * Get the field of the input feature corresponding to the input field
+   */
+  inline TOutputValue GetFeatureField(const ogr::Feature & feature, int field);
+
 };
 
 }
diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx
index 8de3e03623..879f9b8db5 100644
--- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx
+++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx
@@ -245,6 +245,23 @@ TrainVectorBase<TInputValue, TOutputValue>
   return measurement;
 }
 
+// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer
+template <>
+inline int
+TrainVectorBase<float, int>
+::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
+{
+  return(feature.ogr().GetFieldAsInteger( fieldIndex ));
+}
+
+template <class TInputValue, class TOutputValue>
+inline TOutputValue
+TrainVectorBase<TInputValue, TOutputValue>
+::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
+{
+  return(feature.ogr().GetFieldAsDouble( fieldIndex ));
+}
+
 template <class TInputValue, class TOutputValue>
 typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
 TrainVectorBase<TInputValue, TOutputValue>
@@ -306,7 +323,7 @@ TrainVectorBase<TInputValue, TOutputValue>
         input->PushBack( mv );
 
         if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet())
-          target->PushBack( feature.ogr().GetFieldAsDouble( cFieldIndex ) );
+          target->PushBack(GetFeatureField(feature,cFieldIndex));
         else
           target->PushBack( 0. );
 
-- 
GitLab