diff --git a/app/otbPatchesExtraction.cxx b/app/otbPatchesExtraction.cxx index 64aa24e8b2f6a5998cb814ac5c91440dba70604a..b1b8ed714a422626dddde52ae6340da02cd1f31f 100644 --- a/app/otbPatchesExtraction.cxx +++ b/app/otbPatchesExtraction.cxx @@ -61,6 +61,9 @@ public: std::string m_KeyOut; // Key of output samples image std::string m_KeyPszX; // Key for samples sizes X std::string m_KeyPszY; // Key for samples sizes Y + std::string m_KeyNoData; // Key for no-data value + + FloatVectorImageType::InternalPixelType m_NoDataValue; // No data value }; @@ -77,7 +80,7 @@ public: // Create keys and descriptions std::stringstream ss_group_key, ss_desc_group, ss_key_in, ss_key_out, ss_desc_in, - ss_desc_out, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y; + ss_desc_out, ss_key_dims_x, ss_desc_dims_x, ss_key_dims_y, ss_desc_dims_y, ss_key_nodata, ss_desc_nodata; ss_group_key << "source" << inputNumber; ss_desc_group << "Parameters for source " << inputNumber; ss_key_out << ss_group_key.str() << ".out"; @@ -88,6 +91,8 @@ public: ss_desc_dims_x << "X patch size for image " << inputNumber; ss_key_dims_y << ss_group_key.str() << ".patchsizey"; ss_desc_dims_y << "Y patch size for image " << inputNumber; + ss_key_dims_y << ss_group_key.str() << ".nodata"; + ss_desc_dims_y << "No-data value for image (used only if \"usenodata\" is on)" << inputNumber; // Populate group AddParameter(ParameterType_Group, ss_group_key.str(), ss_desc_group.str()); @@ -97,6 +102,8 @@ public: SetMinimumParameterIntValue (ss_key_dims_x.str(), 1); AddParameter(ParameterType_Int, ss_key_dims_y.str(), ss_desc_dims_y.str()); SetMinimumParameterIntValue (ss_key_dims_y.str(), 1); + AddParameter(ParameterType_Float, ss_key_nodata.str(), ss_desc_nodata.str()); + SetDefaultParameterFloat (ss_key_nodata.str(), 0); // Add a new bundle SourceBundle bundle; @@ -104,6 +111,7 @@ public: bundle.m_KeyOut = ss_key_out.str(); bundle.m_KeyPszX = ss_key_dims_x.str(); bundle.m_KeyPszY = ss_key_dims_y.str(); + bundle.m_KeyNoData = ss_key_nodata.str(); m_Bundles.push_back(bundle); @@ -123,6 +131,9 @@ public: // Patch size bundle.m_PatchSize[0] = GetParameterInt(bundle.m_KeyPszX); bundle.m_PatchSize[1] = GetParameterInt(bundle.m_KeyPszY); + + // No data value + bundle.m_NoDataValue = GetParameterFloat(bundle.m_KeyNoData); } } @@ -166,8 +177,6 @@ public: // No data parameters AddParameter(ParameterType_Bool, "usenodata", "Reject samples that have no-data value"); MandatoryOff ("usenodata"); - AddParameter(ParameterType_Float, "nodataval", "No data value (used only if \"usenodata\" is on)"); - SetDefaultParameterFloat( "nodataval", 0.0); // Output label AddParameter(ParameterType_OutputImage, "outlabels", "output labels"); @@ -201,13 +210,10 @@ public: { otbAppLogINFO("Rejecting samples that have at least one no-data value"); sampler->SetRejectPatchesWithNodata(true); - float ndval = GetParameterFloat("nodataval"); - otbAppLogINFO("No-data value: " << ndval); - sampler->SetNodataValue(ndval); } for (auto& bundle: m_Bundles) { - sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize); + sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize, bundle.m_NoDataValue); } // Run the filter diff --git a/include/otbTensorflowSampler.h b/include/otbTensorflowSampler.h index 274a28307aaaf48d02cbde875d814b978c8529dc..eb2645e73075b34c3f404c65e7aae874096f1849 100644 --- a/include/otbTensorflowSampler.h +++ b/include/otbTensorflowSampler.h @@ -106,12 +106,10 @@ public: itkGetConstMacro(InputVectorData, VectorDataPointer); /** Set / get image */ - virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize); + virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize, InternalPixelType nodataval); const ImageType* GetInput(unsigned int index); /** Set / get no-data related parameters */ - itkSetMacro(NodataValue, InternalPixelType); - itkGetMacro(NodataValue, InternalPixelType); itkSetMacro(RejectPatchesWithNodata, bool); itkGetMacro(RejectPatchesWithNodata, bool); @@ -146,7 +144,7 @@ private: unsigned long m_NumberOfRejectedSamples; // No data stuff - InternalPixelType m_NodataValue; + std::vector<InternalPixelType> m_NoDataValues; bool m_RejectPatchesWithNodata; }; // end class diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index f14a66af91629355a9867c750e164d66efb022d1..a99286c98434fd13e79bab0b5b0b2e55860988f6 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -23,16 +23,16 @@ TensorflowSampler<TInputImage, TVectorData> m_NumberOfAcceptedSamples = 0; m_NumberOfRejectedSamples = 0; m_RejectPatchesWithNodata = false; - m_NodataValue = 0; } template <class TInputImage, class TVectorData> void TensorflowSampler<TInputImage, TVectorData> -::PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize) +::PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize, InternalPixelType nodataval) { this->ProcessObject::PushBackInput(const_cast<ImageType*>(input)); m_PatchSizes.push_back(patchSize); + m_NoDataValues.push_back(nodataval); } template <class TInputImage, class TVectorData> @@ -198,7 +198,7 @@ TensorflowSampler<TInputImage, TVectorData> { PixelType pix = it.Get(); for (unsigned int band = 0 ; band < pix.Size() ; band++) - if (pix[band] == m_NodataValue) + if (pix[band] == m_NoDataValues[i]) hasBeenSampled = false; }