Commit dc7fc35c authored by Cresson Remi's avatar Cresson Remi
Browse files

Merge branch 'develop' of gitlab-ssh.irstea.fr:remi.cresson/otbtf

parents 0aa20d20 e3494494
......@@ -22,6 +22,10 @@ if(OTB_USE_TENSORFLOW)
endif()
# Tensorflow-independent APPS
OTB_CREATE_APPLICATION(NAME PatchesSelection
SOURCES otbPatchesSelection.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES}
)
OTB_CREATE_APPLICATION(NAME PatchesExtraction
SOURCES otbPatchesExtraction.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES}
......
......@@ -163,6 +163,12 @@ public:
// Input vector data
AddParameter(ParameterType_InputVectorData, "vec", "Positions of the samples (must be in the same projection as input image)");
// No data parameters
AddParameter(ParameterType_Bool, "usenodata", "Reject samples that have no-data value");
MandatoryOff ("usenodata");
AddParameter(ParameterType_Float, "nodataval", "No data value (used only if \"usenodata\" is on)");
SetDefaultParameterFloat( "nodataval", 0.0);
// Output label
AddParameter(ParameterType_OutputImage, "outlabels", "output labels");
SetDefaultOutputPixelType ("outlabels", ImagePixelType_uint8);
......@@ -191,6 +197,8 @@ public:
SamplerType::Pointer sampler = SamplerType::New();
sampler->SetInputVectorData(GetParameterVectorData("vec"));
sampler->SetField(GetParameterAsString("field"));
sampler->SetRejectPatchesWithNodata(GetParameterInt("usenodata")==1);
sampler->SetNodataValue(GetParameterFloat("nodataval"));
for (auto& bundle: m_Bundles)
{
sampler->PushBackInputWithPatchSize(bundle.m_ImageSource.Get(), bundle.m_PatchSize);
......
This diff is collapsed.
......@@ -24,6 +24,9 @@
// Tree iterator
#include "itkPreOrderTreeIterator.h"
// Image iterator
#include "itkImageRegionConstIterator.h"
namespace otb
{
......@@ -80,6 +83,8 @@ public:
ExtractROIMultiFilterPointerType;
typedef typename std::vector<ImagePointerType> ImagePointerListType;
typedef typename std::vector<SizeType> SizeListType;
typedef typename itk::ImageRegionConstIterator<ImageType>
IteratorType;
/** Vector data typedefs */
typedef TVectorData VectorDataType;
......@@ -104,6 +109,12 @@ public:
virtual void PushBackInputWithPatchSize(const ImageType *input, SizeType & patchSize);
const ImageType* GetInput(unsigned int index);
/** Set / get no-data related parameters */
itkSetMacro(NodataValue, InternalPixelType);
itkGetMacro(NodataValue, InternalPixelType);
itkSetMacro(RejectPatchesWithNodata, bool);
itkGetMacro(RejectPatchesWithNodata, bool);
/** Do the real work */
virtual void Update();
......@@ -134,6 +145,10 @@ private:
unsigned long m_NumberOfAcceptedSamples;
unsigned long m_NumberOfRejectedSamples;
// No data stuff
InternalPixelType m_NodataValue;
bool m_RejectPatchesWithNodata;
}; // end class
} // end namespace otb
......
......@@ -181,6 +181,29 @@ TensorflowSampler<TInputImage, TVectorData>
// If not, reject this sample
hasBeenSampled = false;
}
// Check if it contains no-data values
if (m_RejectPatchesWithNodata && hasBeenSampled)
{
IndexType outIndex;
outIndex[0] = 0;
outIndex[1] = count * m_PatchSizes[i][1];
RegionType region(outIndex, m_PatchSizes[i]);
IteratorType it(m_OutputPatchImages[i], region);
for (it.GoToBegin(); !it.IsAtEnd(); ++it)
{
PixelType pix = it.Get();
for (int i; i<pix.Size(); i++)
if (pix[i] == m_NodataValue)
{
hasBeenSampled = false;
break;
}
if (hasBeenSampled)
break;
}
}
} // Next input
if (hasBeenSampled)
{
......
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbTensorflowSamplingUtils.h"
namespace otb
{
namespace tf
{
//
// Update the distribution of the patch located at the specified location
//
template<class TImage, class TDistribution>
bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr,
typename TImage::PointType point, typename TImage::SizeType patchSize,
TDistribution & dist)
{
typename TImage::IndexType index;
bool canTransform = inPtr->TransformPhysicalPointToIndex(point, index);
if (canTransform)
{
index[0] -= patchSize[0] / 2;
index[1] -= patchSize[1] / 2;
typename TImage::RegionType inPatchRegion(index, patchSize);
if (inPtr->GetLargestPossibleRegion().IsInside(inPatchRegion))
{
// Fill patch
PropagateRequestedRegion<TImage>(inPtr, inPatchRegion);
typename itk::ImageRegionConstIterator<TImage> inIt (inPtr, inPatchRegion);
for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
{
dist.Update(inIt.Get());
}
return true;
}
}
return false;
}
} // namespace tf
} // namespace otb
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef MODULES_REMOTE_OTBTF_INCLUDE_OTBTENSORFLOWSAMPLINGUTILS_H_
#define MODULES_REMOTE_OTBTF_INCLUDE_OTBTENSORFLOWSAMPLINGUTILS_H_
#include "otbTensorflowCommon.h"
#include "vnl/vnl_vector.h"
namespace otb
{
namespace tf
{
template<class TImage>
class Distribution
{
public:
typedef typename TImage::PixelType ValueType;
typedef vnl_vector<float> CountsType;
Distribution(unsigned int nClasses){
m_NbOfClasses = nClasses;
m_Dist = CountsType(nClasses, 0);
}
Distribution(unsigned int nClasses, float fillValue){
m_NbOfClasses = nClasses;
m_Dist = CountsType(nClasses, fillValue);
}
Distribution(){
m_NbOfClasses = 2;
m_Dist = CountsType(m_NbOfClasses, 0);
}
Distribution(const Distribution & other){
m_Dist = other.Get();
m_NbOfClasses = m_Dist.size();
}
~Distribution(){}
void Update(const typename TImage::PixelType & pixel)
{
m_Dist[pixel]++;
}
void Update(const Distribution & other)
{
const CountsType otherDist = other.Get();
for (unsigned int c = 0 ; c < m_NbOfClasses ; c++)
m_Dist[c] += otherDist[c];
}
CountsType Get() const
{
return m_Dist;
}
CountsType GetNormalized() const
{
const float invNorm = 1.0 / std::sqrt(dot_product(m_Dist, m_Dist));
const CountsType normalizedDist = invNorm * m_Dist;
return normalizedDist;
}
float Cosinus(const Distribution & other) const
{
return dot_product(other.GetNormalized(), GetNormalized());
}
std::string ToString()
{
std::stringstream ss;
ss << "\n";
for (unsigned int c = 0 ; c < m_NbOfClasses ; c++)
ss << "\tClass #" << c << " : " << m_Dist[c] << "\n";
return ss.str();
}
private:
unsigned int m_NbOfClasses;
CountsType m_Dist;
};
// Update the distribution of the patch located at the specified location
template<class TImage, class TDistribution>
bool UpdateDistributionFromPatch(const typename TImage::Pointer inPtr,
typename TImage::PointType point, typename TImage::SizeType patchSize,
TDistribution & dist);
} // namesapce tf
} // namespace otb
#include "otbTensorflowSamplingUtils.cxx"
#endif /* MODULES_REMOTE_OTBTF_INCLUDE_OTBTENSORFLOWSAMPLINGUTILS_H_ */
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment