From 6c8abad77707954334900adf4c84328f794eaa4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr> Date: Mon, 29 Apr 2019 18:01:38 +0200 Subject: [PATCH] ENH: template specialization, use GetFieldAsInteger instead of double is the classification case --- .../include/otbTrainVectorBase.h | 6 ++++++ .../include/otbTrainVectorBase.hxx | 19 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 477b798d58..52675e935f 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -186,6 +186,12 @@ protected: void DoUpdateParameters() override; void DoExecute() override; +private: + /** + * Get the field of the input feature corresponding to the input field + */ + inline TOutputValue GetFeatureField(const ogr::Feature & feature, int field); + }; } diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx index 8de3e03623..879f9b8db5 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx @@ -245,6 +245,23 @@ TrainVectorBase<TInputValue, TOutputValue> return measurement; } +// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer +template <> +inline int +TrainVectorBase<float, int> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature.ogr().GetFieldAsInteger( fieldIndex )); +} + +template <class TInputValue, class TOutputValue> +inline TOutputValue +TrainVectorBase<TInputValue, TOutputValue> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature.ogr().GetFieldAsDouble( fieldIndex )); +} + template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel TrainVectorBase<TInputValue, TOutputValue> @@ -306,7 +323,7 @@ TrainVectorBase<TInputValue, TOutputValue> input->PushBack( mv ); if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet()) - target->PushBack( feature.ogr().GetFieldAsDouble( cFieldIndex ) ); + target->PushBack(GetFeatureField(feature,cFieldIndex)); else target->PushBack( 0. ); -- GitLab