diff --git a/include/otbTensorflowSampler.hxx b/include/otbTensorflowSampler.hxx index 966a37969c43ffdb9b7df32ea21dc8a0c7330dd2..420962b219bad8bdec75671c83b11551515734a9 100644 --- a/include/otbTensorflowSampler.hxx +++ b/include/otbTensorflowSampler.hxx @@ -179,57 +179,60 @@ TensorflowSampler<TInputImage, TVectorData>::Update() if (!itVector.Get()->IsRoot() && !itVector.Get()->IsDocument() && !itVector.Get()->IsFolder()) { DataNodePointer currentGeometry = itVector.Get(); - PointType point = currentGeometry->GetPoint(); - - // Get the label value - labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field)); - - bool hasBeenSampled = true; - for (unsigned int i = 0; i < nbInputs; i++) + if (!currentGeometry->HasField(m_Field)) { - // Get input - ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); + PointType point = currentGeometry->GetPoint(); - // Try to sample the image - if (!tf::SampleImage<ImageType>(inputPtr, m_OutputPatchImages[i], point, count, m_PatchSizes[i])) - { - // If not, reject this sample - hasBeenSampled = false; - } - // If NoData is provided, check if the sampled patch contains a no-data value - if (m_NoDataValues.count(i) > 0 && hasBeenSampled) + // Get the label value + labelPix[0] = static_cast<InternalPixelType>(currentGeometry->GetFieldAsInt(m_Field)); + + bool hasBeenSampled = true; + for (unsigned int i = 0; i < nbInputs; i++) { - IndexType outIndex; - outIndex[0] = 0; - outIndex[1] = count * m_PatchSizes[i][1]; - RegionType region(outIndex, m_PatchSizes[i]); + // Get input + ImagePointerType inputPtr = const_cast<ImageType *>(this->GetInput(i)); - IteratorType it(m_OutputPatchImages[i], region); - for (it.GoToBegin(); !it.IsAtEnd(); ++it) + // Try to sample the image + if (!tf::SampleImage<ImageType>(inputPtr, m_OutputPatchImages[i], point, count, m_PatchSizes[i])) + { + // If not, reject this sample + hasBeenSampled = false; + } + // If NoData is provided, check if the sampled patch contains a no-data value + if (m_NoDataValues.count(i) > 0 && hasBeenSampled) { - PixelType pix = it.Get(); - for (unsigned int band = 0; band < pix.Size(); band++) - if (pix[band] == m_NoDataValues[i]) - hasBeenSampled = false; + IndexType outIndex; + outIndex[0] = 0; + outIndex[1] = count * m_PatchSizes[i][1]; + RegionType region(outIndex, m_PatchSizes[i]); + + IteratorType it(m_OutputPatchImages[i], region); + for (it.GoToBegin(); !it.IsAtEnd(); ++it) + { + PixelType pix = it.Get(); + for (unsigned int band = 0; band < pix.Size(); band++) + if (pix[band] == m_NoDataValues[i]) + hasBeenSampled = false; + } } + } // Next input + if (hasBeenSampled) + { + // Fill label + labelIndex[1] = count; + m_OutputLabelImage->SetPixel(labelIndex, labelPix); + + // update count + count++; + } + else + { + rejected++; } - } // Next input - if (hasBeenSampled) - { - // Fill label - labelIndex[1] = count; - m_OutputLabelImage->SetPixel(labelIndex, labelPix); - // update count - count++; + // Update progress + progress.CompletedPixel(); } - else - { - rejected++; - } - - // Update progress - progress.CompletedPixel(); } ++itVector;