Commit 5a1d2322 authored by Cresson Remi

Merge branch 'develop'

parents 1e8bf930 d2195d91
@@ -30,8 +30,7 @@
#include "otbTensorflowSource.h"
// Streaming
-#include "otbImageRegionSquareTileSplitter.h"
-#include "itkStreamingImageFilter.h"
+#include "otbTensorflowStreamerFilter.h"
namespace otb
{
@@ -58,7 +57,7 @@ public:
  /** Typedef for streaming */
  typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
-  typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
+  typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
  /** Typedefs for images */
  typedef FloatVectorImageType::SizeType SizeType;
@@ -198,9 +197,12 @@ public:
  AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
  MandatoryOff ("optim.disabletiling");
  SetParameterDescription ("optim.disabletiling", "Tiling avoids to process a too large subset of image, but sometimes it can be useful to disable it");
-  AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
-  SetMinimumParameterIntValue ("optim.tilesize", 1);
-  SetDefaultParameterInt ("optim.tilesize", 16);
+  AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output");
+  SetMinimumParameterIntValue ("optim.tilesizex", 1);
+  SetDefaultParameterInt ("optim.tilesizex", 16);
+  AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output");
+  SetMinimumParameterIntValue ("optim.tilesizey", 1);
+  SetDefaultParameterInt ("optim.tilesizey", 16);
  // Output image
  AddParameter(ParameterType_OutputImage, "out", "output image");
@@ -292,22 +294,28 @@ public:
  if (GetParameterInt("optim.disabletiling") != 1)
  {
    // Get the tile size
-    const unsigned int tileSize = GetParameterInt("optim.tilesize");
-    otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
-    // Update the TensorFlow filter output information to get the output image size
-    m_TFFilter->UpdateOutputInformation();
-    // Splitting using square tiles
-    TileSplitterType::Pointer splitter = TileSplitterType::New();
-    splitter->SetTileSizeAlignment(tileSize);
-    unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
-        double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) );
-    // Use an itk::StreamingImageFilter to force the computation tile by tile
+    SizeType tileSize;
+    tileSize[0] = GetParameterInt("optim.tilesizex");
+    tileSize[1] = GetParameterInt("optim.tilesizey");
+    // Check that the tile size is aligned to the field of expression
+    for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++)
+      if (tileSize[i] % foe[i] != 0)
+      {
+        SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]);
+        newSize *= foe[i];
+        otbAppLogWARNING("Aligning the tiling to the output expression field "
+            << "for better performances (dim " << i << "). New value set to " << newSize)
+        tileSize[i] = newSize;
+      }
+    otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
+    // Force the computation tile by tile
    m_StreamFilter = StreamingFilterType::New();
-    m_StreamFilter->SetRegionSplitter(splitter);
-    m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
+    m_StreamFilter->SetOutputGridSize(tileSize);
    m_StreamFilter->SetInput(m_TFFilter->GetOutput());
    SetParameterOutputImage("out", m_StreamFilter->GetOutput());
...
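Note on the alignment performed in the hunk above (illustrative numbers, not taken from the commit): if the user requests a tile of 100 pixels along a dimension whose output expression field (foe) is 16, then 100 % 16 != 0, so the new size becomes (1 + floor(100 / 16)) * 16 = 7 * 16 = 112, i.e. the smallest multiple of the expression field that is not smaller than the requested tile size.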
@@ -438,6 +438,7 @@ public:
    }
  // Setup the validation filter
+  const bool do_validation = HasUserValue("validation.mode");
  if (GetParameterInt("validation.mode")==1) // class
  {
    otbAppLogINFO("Set validation mode to classification validation");
@@ -467,50 +468,53 @@ public:
  AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
  m_TrainModelFilter->Update();
-  // Validate the model
-  if (epoch % GetParameterInt("validation.step") == 0)
+  if (do_validation)
+  {
+    // Validate the model
+    if (epoch % GetParameterInt("validation.step") == 0)
    {
      // 1. Evaluate the metrics against the learning data
      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstLearningData.size() ; i++)
      {
        m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstLearningData[i]);
      }
      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstLearningData);
      // As we use the learning data here, it's rational to use the same option as streaming during training
      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("training.usestreaming"));
      // Update
      AddProcess(m_ValidateModelFilter, "Evaluate model (Learning data)");
      m_ValidateModelFilter->Update();
      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
      {
        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
      }
      // 2. Evaluate the metrics against the validation data
      // Here we just change the input sources and references
      for (unsigned int i = 0 ; i < m_InputSourcesForEvaluationAgainstValidationData.size() ; i++)
      {
        m_ValidateModelFilter->SetInput(i, m_InputSourcesForEvaluationAgainstValidationData[i]);
      }
      m_ValidateModelFilter->SetInputReferences(m_InputTargetsForEvaluationAgainstValidationData);
      m_ValidateModelFilter->SetUseStreaming(GetParameterInt("validation.usestreaming"));
      // Update
      AddProcess(m_ValidateModelFilter, "Evaluate model (Validation data)");
      m_ValidateModelFilter->Update();
      for (unsigned int i = 0 ; i < m_TargetTensorsNames.size() ; i++)
      {
        otbAppLogINFO("Metrics for target \"" << m_TargetTensorsNames[i] << "\":");
        PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
      }
    } // Step is OK to perform validation
+  } // Do the validation against the validation data
  } // Next epoch
...
@@ -93,6 +93,7 @@ private:
  ShareParameter("optim", "tfmodel.optim", "Processing time optimization", "This group of parameters allows optimization of processing time");
  // Train shared parameters
+  ShareParameter("ram", "train.ram", "Available RAM (Mb)", "Available RAM (Mb)");
  ShareParameter("vd", "train.io.vd", "Vector data for training", "Input vector data for training");
  ShareParameter("valid", "train.io.valid", "Vector data for validation", "Input vector data for validation");
  ShareParameter("out", "train.io.out", "Output classification model", "Output classification model");
...
@@ -368,9 +368,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
  RegionType outputAlignedReqRegion(outputReqRegion);
  EnlargeToAlignedRegion(outputAlignedReqRegion);
-  // Add a progress reporter
-  itk::ProgressReporter progress(this, 0, outputReqRegion.GetNumberOfPixels());
  const unsigned int nInputs = this->GetNumberOfInputs();
  // Create input tensors list
...
@@ -55,6 +55,12 @@ TensorflowMultisourceModelTrain<TInputImage>
  TensorListType outputs;
  this->RunSession(inputs, outputs);
+  // Display outputs tensors
+  for (auto& o: outputs)
+  {
+    tf::PrintTensorInfos(o);
+  }
}
...
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_h
#define otbTensorflowStreamerFilter_h
// Image2image
#include "itkImageToImageFilter.h"
namespace otb
{
/**
* \class TensorflowStreamerFilter
* \brief This filter generates an output image with an internal
* explicit streaming mechanism.
*
* \ingroup OTBTensorflow
*/
template <class TInputImage, class TOutputImage>
class ITK_EXPORT TensorflowStreamerFilter :
public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
/** Standard class typedefs. */
typedef TensorflowStreamerFilter Self;
typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for creation through the object factory. */
itkNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter);
/** Images typedefs */
typedef typename Superclass::InputImageType ImageType;
typedef typename ImageType::IndexType IndexType;
typedef typename ImageType::IndexValueType IndexValueType;
typedef typename ImageType::SizeType SizeType;
typedef typename Superclass::InputImageRegionType RegionType;
typedef TOutputImage OutputImageType;
itkSetMacro(OutputGridSize, SizeType);
itkGetMacro(OutputGridSize, SizeType);
protected:
TensorflowStreamerFilter();
virtual ~TensorflowStreamerFilter() {};
virtual void UpdateOutputData(itk::DataObject *output){(void) output; this->GenerateData();}
virtual void GenerateData();
private:
TensorflowStreamerFilter(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
SizeType m_OutputGridSize; // Output grid size
}; // end class
} // end namespace otb
#include "otbTensorflowStreamerFilter.hxx"
#endif
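For readers unfamiliar with the class declared above, here is a minimal usage sketch. It is not part of the commit: the image type, the grid values and the "upstream" source are assumptions chosen for illustration (in the application above the upstream source would be the TensorFlow model filter).

// Minimal usage sketch (illustration only, not part of the commit):
// force an upstream OTB/ITK pipeline to be generated tile by tile on a fixed grid.
#include "otbVectorImage.h"
#include "itkImageSource.h"
#include "otbTensorflowStreamerFilter.h"

using ImgType      = otb::VectorImage<float, 2>;
using StreamerType = otb::TensorflowStreamerFilter<ImgType, ImgType>;

void StreamOnGrid(itk::ImageSource<ImgType>* upstream) // "upstream" is a placeholder name
{
  StreamerType::SizeType grid;
  grid[0] = 256; // tile width (pixels), arbitrary value for the example
  grid[1] = 128; // tile height (pixels), arbitrary value for the example

  StreamerType::Pointer streamer = StreamerType::New();
  streamer->SetInput(upstream->GetOutput());
  streamer->SetOutputGridSize(grid); // requested regions will be aligned on this grid
  streamer->Update();                // GenerateData() then copies the output tile by tile
}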
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_txx
#define otbTensorflowStreamerFilter_txx
#include "otbTensorflowStreamerFilter.h"
#include "itkImageAlgorithm.h"
namespace otb
{
template <class TInputImage, class TOutputImage>
TensorflowStreamerFilter<TInputImage, TOutputImage>
::TensorflowStreamerFilter()
{
m_OutputGridSize.Fill(1);
}
/**
* Compute the output image
*/
template <class TInputImage, class TOutputImage>
void
TensorflowStreamerFilter<TInputImage, TOutputImage>
::GenerateData()
{
// Output pointer and requested region
OutputImageType * outputPtr = this->GetOutput();
const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
outputPtr->SetBufferedRegion(outputReqRegion);
outputPtr->Allocate();
// Compute the aligned region
RegionType region;
for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
{
// Get corners
IndexValueType lower = outputReqRegion.GetIndex(dim);
IndexValueType upper = lower + outputReqRegion.GetSize(dim);
// Compute deltas between corners and the grid
const IndexValueType deltaLo = lower % m_OutputGridSize[dim];
const IndexValueType deltaUp = upper % m_OutputGridSize[dim];
// Move corners to aligned positions
lower -= deltaLo;
if (deltaUp > 0)
{
upper += m_OutputGridSize[dim] - deltaUp;
}
// Update region
region.SetIndex(dim, lower);
region.SetSize(dim, upper - lower);
}
// Compute the number of subregions to process
const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0];
const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1];
// Progress
itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY);
// For each tile, propagate the input region and recopy the output
ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) );
unsigned int tx, ty;
RegionType subRegion;
subRegion.SetSize(m_OutputGridSize);
for (ty = 0; ty < nbTilesY; ty++)
{
subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1));
for (tx = 0; tx < nbTilesX; tx++)
{
// Update the input subregion
subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0));
// The actual region to copy
RegionType cpyRegion(subRegion);
cpyRegion.Crop(outputReqRegion);
// Propagate region
inputImage->SetRequestedRegion(cpyRegion);
inputImage->PropagateRequestedRegion();
inputImage->UpdateOutputData();
// Copy the subregion to output
itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion );
progress.CompletedPixel();
}
}
}
} // end namespace otb
#endif
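As a side note (not part of the commit), the per-dimension alignment computed at the top of GenerateData() can be checked with a tiny standalone example using made-up values for the grid size and the requested region:

// Standalone illustration of the grid alignment above (made-up values).
#include <cstdio>

int main()
{
  const long grid = 16;                           // m_OutputGridSize along one dimension
  long lower = 10;                                // requested region index
  long upper = lower + 100;                       // index + size of the requested region
  lower -= lower % grid;                          // 10 % 16 = 10  -> lower becomes 0
  const long deltaUp = upper % grid;              // 110 % 16 = 14
  if (deltaUp > 0) { upper += grid - deltaUp; }   // upper becomes 112
  std::printf("aligned index=%ld size=%ld tiles=%ld\n",
              lower, upper - lower, (upper - lower) / grid); // aligned index=0 size=112 tiles=7
  return 0;
}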
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from tricks import *
# Logging
tf.logging.set_verbosity(tf.logging.INFO)
# Parser
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", help="checkpoint file prefix", required=True)
parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+')
parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+')
parser.add_argument("--model", help="output SavedModel", required=True)
params = parser.parse_args()
if __name__ == "__main__":
    CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model)
    quit()
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson, Dino Ienco (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import numpy as np
@@ -13,35 +31,6 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix
from tricks import *
-def export_model(sess,
-                 export_dir,
-                 x_cnn_placeholder,
-                 x_rnn_placeholder,
-                 is_training_placeholder,
-                 testPrediction):
-    """ export a SavedModel
-    """
-    # Update the export dir
-    model_dir = export_dir + "/saved_model/"
-    if os.path.exists(model_dir):
-        shutil.rmtree(model_dir)
-    print("Export model in " + model_dir)
-    # Add a builder (for LoadSavedModel)
-    builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
-    signature_def_map= {
-        "model": tf.saved_model.signature_def_utils.predict_signature_def(
-            inputs = {"x_cnn" : x_cnn_placeholder,
-                      "x_rnn" : x_rnn_placeholder,
-                      "is_training" : is_training_placeholder},
-            outputs = {"prediction" : testPrediction})
-    }
-    builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
-    builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
-    builder.save()
def checkTest(ts_data, vhsr_data, batchsz, label_test):
    tot_pred = []
@@ -60,11 +49,7 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
            is_training_ph:True,
            dropout:0.0,
            x_cnn:batch_cnn_x})
-        del batch_rnn_x
-        del batch_cnn_x
-        del batch_y
        for el in pred_temp:
            tot_pred.append( el )
@@ -241,8 +226,8 @@ n_channels = 4
nclasses = 8
# check number of arguments
-if len(sys.argv) != 7:
-    print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>")
+if len(sys.argv) != 8:
+    print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid> <export_dir>")
    sys.exit(1)
ts_train = read_samples(sys.argv[1])
@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
label_test = np.int32(label_test)
print_histo(label_test, "label_test")
+export_dir = sys.argv[7]
x_rnn = tf.placeholder(tf.float32,[None, 1, 1, n_dims*n_timestamps],name="x_rnn")
x_cnn = tf.placeholder(tf.float32,[None, patch_window, patch_window, n_channels],name="x_cnn")
y = tf.placeholder(tf.int32,[None, 1, 1, 1],name="y")
@@ -324,19 +311,12 @@ for e in range(hm_epochs):
        lossi+=loss
        accS+=acc
-        del batch_rnn_x
-        del batch_cnn_x
-        del batch_y
    print "Epoch:",e,"Train loss:",lossi/iterations,"| accuracy:",accS/iterations
    c_loss = lossi/iterations
    if c_loss < best_loss:
-        save_path = saver.save(sess, "models/model")
-        print("Model saved in path: %s" % save_path)
        best_loss = c_loss
-        export_model(sess, "/tmp/m3_export", x_cnn, x_rnn, is_training_ph, testPrediction)
+        CreateSavedModel(sess, ["x_cnn:0","x_rnn:0","is_training:0"], ["prediction:0"], export_dir)
test_acc = checkTest(ts_test, vhsr_test, 1024, label_test)
\ No newline at end of file
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
#