diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 477b798d58b550cba0d80347042c57a68cd01a2d..52675e935fd425deb13c91caad4e40a3d7412e3d 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -186,6 +186,12 @@ protected: void DoUpdateParameters() override; void DoExecute() override; +private: + /** + * Get the field of the input feature corresponding to the input field + */ + inline TOutputValue GetFeatureField(const ogr::Feature & feature, int field); + }; } diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx index 8de3e03623b48470f0d9454b0e6b67df877bdd74..879f9b8db50124026db448f98f28437ddd144640 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.hxx @@ -245,6 +245,23 @@ TrainVectorBase<TInputValue, TOutputValue> return measurement; } +// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer +template <> +inline int +TrainVectorBase<float, int> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature.ogr().GetFieldAsInteger( fieldIndex )); +} + +template <class TInputValue, class TOutputValue> +inline TOutputValue +TrainVectorBase<TInputValue, TOutputValue> +::GetFeatureField(const ogr::Feature & feature, int fieldIndex) +{ + return(feature.ogr().GetFieldAsDouble( fieldIndex )); +} + template <class TInputValue, class TOutputValue> typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel TrainVectorBase<TInputValue, TOutputValue> @@ -306,7 +323,7 @@ TrainVectorBase<TInputValue, TOutputValue> input->PushBack( mv ); if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet()) - target->PushBack( feature.ogr().GetFieldAsDouble( cFieldIndex ) ); + target->PushBack(GetFeatureField(feature,cFieldIndex)); else target->PushBack( 0. );