diff --git a/app/otbPatchesExtraction.cxx b/app/otbPatchesExtraction.cxx index 79f5711d3df3ae6fc152db242975f44e33bf541a..2f9cee23781984f9f9749b7cd3c360274c5230fd 100644 --- a/app/otbPatchesExtraction.cxx +++ b/app/otbPatchesExtraction.cxx @@ -163,6 +163,12 @@ public: // Input vector data AddParameter(ParameterType_InputVectorData, "vec", "Positions of the samples (must be in the same projection as input image)"); + // No data parameters + AddParameter(ParameterType_Bool, "usenodata", "Reject samples that have no-data value"); + MandatoryOff ("usenodata"); + AddParameter(ParameterType_Float, "nodataval", "No data value (used only if \"usenodata\" is on)"); + SetDefaultParameterFloat( "nodataval", 0.0); + // Output label AddParameter(ParameterType_OutputImage, "outlabels", "output labels"); SetDefaultOutputPixelType ("outlabels", ImagePixelType_uint8); @@ -191,6 +197,8 @@ public: SamplerType::Pointer sampler = SamplerType::New(); sampler->SetInputVectorData(GetParameterVectorData("vec")); sampler->SetField(GetParameterAsString("field")); + sampler->SetRejectPatchesWithNodata(GetParameterInt("usenodata")==1); + sampler->SetNodataValue(GetParameterFloat("nodataval")); for (auto& bundle: m_Bundles) { sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize); diff --git a/app/otbPatchesSelection.cxx b/app/otbPatchesSelection.cxx index d32950f6227fd5c96547c3f3e54c5ea25dd69570..f84fc9a639294a4cfc26b8856bb3f70c4ec3d635 100644 --- a/app/otbPatchesSelection.cxx +++ b/app/otbPatchesSelection.cxx @@ -27,6 +27,7 @@ // image utils #include "otbTensorflowCommon.h" #include "otbTensorflowSamplingUtils.h" +#include "itkImageRegionConstIteratorWithOnlyIndex.h" // Functor to retrieve nodata template<class TPixel, class OutputPixel> @@ -122,6 +123,8 @@ public: AddParameter(ParameterType_Float, "nodata", "nodata value"); MandatoryOn ("nodata"); SetDefaultParameterFloat ("nodata", 0); + AddParameter(ParameterType_Bool, "nocheck", "If on, no check on the validity of patches is performed"); + MandatoryOff ("nocheck"); // Grid AddParameter(ParameterType_Group, "grid", "grid settings"); @@ -238,6 +241,18 @@ public: */ template<typename TLambda> void Apply(TLambda lambda) + { + if (GetParameterInt("nockeck")==1) + ApplyFast(lambda); + else + ApplyWithCheck(lambda); + } + + /* + * Apply the given function at each sampling location, checking if the patch is valid or not + */ + template<typename TLambda> + void ApplyWithCheck(TLambda lambda) { // Explicit streaming over the morphed mask, based on the RAM parameter @@ -297,6 +312,49 @@ public: } } + /* + * Apply the given function at each sampling location, without checking the valid pixels under + */ + template<typename TLambda> + void ApplyFast(TLambda lambda) + { + + FloatVectorImageType::Pointer inputImage = GetParameterFloatVectorImage("in"); + FloatVectorImageType::RegionType entireRegion = inputImage->GetLargestPossibleRegion(); + entireRegion.ShrinkByRadius(m_Radius); + FloatVectorImageType::IndexType start; + start[0] = m_Radius[0] + 1; + start[1] = m_Radius[1] + 1; + FloatVectorImageType::IndexType pos; + pos.Fill(0); + FloatVectorImageType::IndexValueType step = GetParameterInt("grid.step"); + + typedef itk::ImageRegionConstIteratorWithOnlyIndex<FloatVectorImageType> IteratorType; + IteratorType inIt (inputImage, entireRegion); + for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) + { + FloatVectorImageType::IndexType idx = inIt.GetIndex(); + idx[0] -= start[0]; + idx[1] -= start[1]; + if (idx[0] % step == 0 && idx[1] % step == 0) + { + // Update grid position + pos[0] = idx[0] / step; + pos[1] = idx[1] / step; + + // Compute coordinates + FloatVectorImageType::PointType geo; + inputImage->TransformIndexToPhysicalPoint(inIt.GetIndex(), geo); + DataNodeType::PointType point; + point[0] = geo[0]; + point[1] = geo[1]; + + // Lambda call + lambda(pos, geo); + } + } + } + /* * Allocate a std::vector of sample bundle */ @@ -595,6 +653,12 @@ public: SampleChessboard(); } + else if (GetParameterAsString("strategy") == "chessboardfast") + { + otbAppLogINFO("Sampling at regular interval in space (\"Chessboard\" like) without checking image content"); + + SampleChessboardFast(); + } else if (GetParameterAsString("strategy") == "balanced") { otbAppLogINFO("Sampling with balancing strategy"); diff --git a/include/otbTensorflowSampler.h b/include/otbTensorflowSampler.h index 0fe60193f464fafe9e09d481ae72048f132fcc8f..d86b769d530eb9741797c58a7be270d945ff0b86 100644 --- a/include/otbTensorflowSampler.h +++ b/include/otbTensorflowSampler.h @@ -24,6 +24,9 @@ // Tree iterator #include "itkPreOrderTreeIterator.h" +// Image iterator +#include "itkImageRegionConstIterator.h" + namespace otb { @@ -80,6 +83,7 @@ public: ExtractROIMultiFilterPointerType; typedef typename std::vector<ImagePointerType> ImagePointerListType; typedef typename std::vector<SizeType> SizeListType; + typedef typename itk::ImageRegionConstIterator IteratorType; /** Vector data typedefs */ typedef TVectorData VectorDataType; @@ -104,6 +108,12 @@ public: virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize); const ImageType* GetInput(unsigned int index); + /** Set / get no-data related parameters */ + itkSetMacro(NodataValue, InternalPixelType); + itkGetMacro(NodataValue, InternalPixelType); + itkSetMacro(RejectPatchesWithNodata, bool); + itkGetMacro(RejectPatchesWithNodata, bool); + /** Do the real work */ virtual void Update(); @@ -134,6 +144,10 @@ private: unsigned long m_NumberOfAcceptedSamples; unsigned long m_NumberOfRejectedSamples; + // No data stuff + InternalPixelType m_NodataValue; + bool m_RejectPatchesWithNodata; + }; // end class } // end namespace otb diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index b7d4dd82a0f4e96a69c5469aff863188c99d3f8f..1a830c2b151ec7cc49c0fc2039afd70f26af8215 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -181,6 +181,29 @@ TensorflowSampler<TInputImage, TVectorData> // If not, reject this sample hasBeenSampled = false; } + // Check if it contains no-data values + if (m_RejectPatchesWithNodata && hasBeenSampled) + { + IndexType outIndex; + outIndex[0] = 0; + outIndex[1] = count * m_PatchSizes[i][1]; + RegionType region(outIndex, m_PatchSizes); + + IteratorType it(m_OutputPatchImages[i], region); + for (it.GoToBegin(); !it.IsAtEnd(); ++it) + { + PixelType pix = it.Get(); + for (int i; i<pix.Size(); i++) + if (pix[i] == m_NodataValue) + { + hasBeenSampled = false; + break; + } + if (hasBeenSampled) + break; + } + + } } // Next input if (hasBeenSampled) {