Commit 5a1d2322 authored by Cresson Remi

Merge branch 'develop'

parents 1e8bf930 d2195d91
@@ -30,8 +30,7 @@
#include "otbTensorflowSource.h"
// Streaming
#include "otbImageRegionSquareTileSplitter.h"
#include "itkStreamingImageFilter.h"
#include "otbTensorflowStreamerFilter.h"
namespace otb
{
@@ -58,7 +57,7 @@ public:
/** Typedef for streaming */
typedef otb::ImageRegionSquareTileSplitter<FloatVectorImageType::ImageDimension> TileSplitterType;
typedef itk::StreamingImageFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
typedef otb::TensorflowStreamerFilter<FloatVectorImageType, FloatVectorImageType> StreamingFilterType;
/** Typedefs for images */
typedef FloatVectorImageType::SizeType SizeType;
@@ -198,9 +197,12 @@ public:
AddParameter(ParameterType_Bool, "optim.disabletiling", "Disable tiling");
MandatoryOff ("optim.disabletiling");
SetParameterDescription ("optim.disabletiling", "Tiling prevents processing too large an image subset at once, but it can sometimes be useful to disable it");
AddParameter(ParameterType_Int, "optim.tilesize", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesize", 1);
SetDefaultParameterInt ("optim.tilesize", 16);
AddParameter(ParameterType_Int, "optim.tilesizex", "Tile width used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesizex", 1);
SetDefaultParameterInt ("optim.tilesizex", 16);
AddParameter(ParameterType_Int, "optim.tilesizey", "Tile height used to stream the filter output");
SetMinimumParameterIntValue ("optim.tilesizey", 1);
SetDefaultParameterInt ("optim.tilesizey", 16);
// Output image
AddParameter(ParameterType_OutputImage, "out", "output image");
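For context, the new tiling parameters can be driven from Python through the otbApplication bindings. A minimal sketch, assuming the application is named "TensorflowModelServe" and assuming the "source1.il" and "model.dir" keys (only the "optim.*" and "out" keys appear in this diff):

import otbApplication

# Assumed application name; only "optim.*" and "out" are confirmed by this diff.
app = otbApplication.Registry.CreateApplication("TensorflowModelServe")
app.SetParameterStringList("source1.il", ["input.tif"])     # hypothetical input key
app.SetParameterString("model.dir", "/path/to/savedmodel")  # hypothetical model key
app.SetParameterInt("optim.tilesizex", 128)                 # new: tile width
app.SetParameterInt("optim.tilesizey", 64)                  # new: tile height
app.SetParameterString("out", "output.tif")
app.ExecuteAndWriteOutput()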
@@ -292,22 +294,28 @@ public:
if (GetParameterInt("optim.disabletiling") != 1)
{
// Get the tile size
const unsigned int tileSize = GetParameterInt("optim.tilesize");
otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
SizeType tileSize;
tileSize[0] = GetParameterInt("optim.tilesizex");
tileSize[1] = GetParameterInt("optim.tilesizey");
// Check that the tile size is aligned to the field of expression
for (unsigned int i = 0 ; i < FloatVectorImageType::ImageDimension ; i++)
if (tileSize[i] % foe[i] != 0)
{
SizeType::SizeValueType newSize = 1 + std::floor(tileSize[i] / foe[i]);
newSize *= foe[i];
// Update the TensorFlow filter output information to get the output image size
m_TFFilter->UpdateOutputInformation();
otbAppLogWARNING("Aligning the tiling to the output expression field "
<< "for better performances (dim " << i << "). New value set to " << newSize)
// Splitting using square tiles
TileSplitterType::Pointer splitter = TileSplitterType::New();
splitter->SetTileSizeAlignment(tileSize);
unsigned int nbDesiredTiles = itk::Math::Ceil<unsigned int>(
double(m_TFFilter->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels() ) / (tileSize * tileSize) );
tileSize[i] = newSize;
}
otbAppLogINFO("Force tiling with squared tiles of " << tileSize)
// Use an itk::StreamingImageFilter to force the computation tile by tile
// Force the computation tile by tile
m_StreamFilter = StreamingFilterType::New();
m_StreamFilter->SetRegionSplitter(splitter);
m_StreamFilter->SetNumberOfStreamDivisions(nbDesiredTiles);
m_StreamFilter->SetOutputGridSize(tileSize);
m_StreamFilter->SetInput(m_TFFilter->GetOutput());
SetParameterOutputImage("out", m_StreamFilter->GetOutput());
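The alignment step above rounds each tile dimension up to the next multiple of the output expression field (foe). A minimal sketch of the same arithmetic, in Python with illustrative names:

def align_up(tile_size, foe):
    # Mirrors: newSize = (1 + floor(tileSize[i] / foe[i])) * foe[i]
    if tile_size % foe == 0:
        return tile_size
    return (tile_size // foe + 1) * foe

assert align_up(16, 1) == 16  # already aligned: unchanged
assert align_up(16, 5) == 20  # rounded up to the next multiple of 5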
......
@@ -438,6 +438,7 @@ public:
}
// Setup the validation filter
const bool do_validation = HasUserValue("validation.mode");
if (GetParameterInt("validation.mode")==1) // class
{
otbAppLogINFO("Set validation mode to classification validation");
@@ -467,6 +468,8 @@ public:
AddProcess(m_TrainModelFilter, "Training epoch #" + std::to_string(epoch));
m_TrainModelFilter->Update();
if (do_validation)
{
// Validate the model
if (epoch % GetParameterInt("validation.step") == 0)
{
@@ -511,6 +514,7 @@ public:
PrintClassificationMetrics(m_ValidateModelFilter->GetConfusionMatrix(i), m_ValidateModelFilter->GetMapOfClasses(i));
}
} // Step is OK to perform validation
} // Do the validation against the validation data
} // Next epoch
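Schematically, the training loop now gates validation on both the user-supplied validation mode and the validation step; a Python restatement with illustrative names, not the application code itself:

for epoch in range(1, n_epochs + 1):                    # "Next epoch"
    train_one_epoch()                                   # m_TrainModelFilter->Update()
    if do_validation and epoch % validation_step == 0:  # "Step is OK"
        validate_model()                                # confusion matrices, metrics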
......
@@ -93,6 +93,7 @@ private:
ShareParameter("optim", "tfmodel.optim", "Processing time optimization", "This group of parameters allows optimization of processing time");
// Train shared parameters
ShareParameter("ram", "train.ram", "Available RAM (Mb)", "Available RAM (Mb)");
ShareParameter("vd", "train.io.vd", "Vector data for training", "Input vector data for training");
ShareParameter("valid", "train.io.valid", "Vector data for validation", "Input vector data for validation");
ShareParameter("out", "train.io.out", "Output classification model", "Output classification model");
......
@@ -368,9 +368,6 @@ TensorflowMultisourceModelFilter<TInputImage, TOutputImage>
RegionType outputAlignedReqRegion(outputReqRegion);
EnlargeToAlignedRegion(outputAlignedReqRegion);
// Add a progress reporter
itk::ProgressReporter progress(this, 0, outputReqRegion.GetNumberOfPixels());
const unsigned int nInputs = this->GetNumberOfInputs();
// Create input tensors list
......
@@ -55,6 +55,12 @@ TensorflowMultisourceModelTrain<TInputImage>
TensorListType outputs;
this->RunSession(inputs, outputs);
// Display outputs tensors
for (auto& o: outputs)
{
tf::PrintTensorInfos(o);
}
}
......
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_h
#define otbTensorflowStreamerFilter_h
// Image2image
#include "itkImageToImageFilter.h"
namespace otb
{
/**
* \class TensorflowStreamerFilter
* \brief This filter generates an output image with an internal
* explicit streaming mechanism.
*
* \ingroup OTBTensorflow
*/
template <class TInputImage, class TOutputImage>
class ITK_EXPORT TensorflowStreamerFilter :
public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
/** Standard class typedefs. */
typedef TensorflowStreamerFilter Self;
typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for creation through the object factory. */
itkNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(TensorflowStreamerFilter, itk::ImageToImageFilter);
/** Images typedefs */
typedef typename Superclass::InputImageType ImageType;
typedef typename ImageType::IndexType IndexType;
typedef typename ImageType::IndexValueType IndexValueType;
typedef typename ImageType::SizeType SizeType;
typedef typename Superclass::InputImageRegionType RegionType;
typedef TOutputImage OutputImageType;
itkSetMacro(OutputGridSize, SizeType);
itkGetMacro(OutputGridSize, SizeType);
protected:
TensorflowStreamerFilter();
virtual ~TensorflowStreamerFilter() {};
virtual void UpdateOutputData(itk::DataObject *output){(void) output; this->GenerateData();}
virtual void GenerateData();
private:
TensorflowStreamerFilter(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
SizeType m_OutputGridSize; // Output grid size
}; // end class
} // end namespace otb
#include "otbTensorflowStreamerFilter.hxx"
#endif
/*=========================================================================
Copyright (c) Remi Cresson (IRSTEA). All rights reserved.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTensorflowStreamerFilter_txx
#define otbTensorflowStreamerFilter_txx
#include "otbTensorflowStreamerFilter.h"
#include "itkImageAlgorithm.h"
namespace otb
{
template <class TInputImage, class TOutputImage>
TensorflowStreamerFilter<TInputImage, TOutputImage>
::TensorflowStreamerFilter()
{
m_OutputGridSize.Fill(1);
}
/**
* Compute the output image
*/
template <class TInputImage, class TOutputImage>
void
TensorflowStreamerFilter<TInputImage, TOutputImage>
::GenerateData()
{
// Output pointer and requested region
OutputImageType * outputPtr = this->GetOutput();
const RegionType outputReqRegion = outputPtr->GetRequestedRegion();
outputPtr->SetBufferedRegion(outputReqRegion);
outputPtr->Allocate();
// Compute the aligned region
RegionType region;
for(unsigned int dim = 0; dim<OutputImageType::ImageDimension; ++dim)
{
// Get corners
IndexValueType lower = outputReqRegion.GetIndex(dim);
IndexValueType upper = lower + outputReqRegion.GetSize(dim);
// Compute deltas between corners and the grid
const IndexValueType deltaLo = lower % m_OutputGridSize[dim];
const IndexValueType deltaUp = upper % m_OutputGridSize[dim];
// Move corners to aligned positions
lower -= deltaLo;
if (deltaUp > 0)
{
upper += m_OutputGridSize[dim] - deltaUp;
}
// Update region
region.SetIndex(dim, lower);
region.SetSize(dim, upper - lower);
}
// Compute the number of subregions to process
const unsigned int nbTilesX = region.GetSize(0) / m_OutputGridSize[0];
const unsigned int nbTilesY = region.GetSize(1) / m_OutputGridSize[1];
// Progress
itk::ProgressReporter progress(this, 0, nbTilesX*nbTilesY);
// For each tile, propagate the input region and recopy the output
ImageType * inputImage = static_cast<ImageType * >( Superclass::ProcessObject::GetInput(0) );
unsigned int tx, ty;
RegionType subRegion;
subRegion.SetSize(m_OutputGridSize);
for (ty = 0; ty < nbTilesY; ty++)
{
subRegion.SetIndex(1, ty*m_OutputGridSize[1] + region.GetIndex(1));
for (tx = 0; tx < nbTilesX; tx++)
{
// Update the input subregion
subRegion.SetIndex(0, tx*m_OutputGridSize[0] + region.GetIndex(0));
// The actual region to copy
RegionType cpyRegion(subRegion);
cpyRegion.Crop(outputReqRegion);
// Propagate region
inputImage->SetRequestedRegion(cpyRegion);
inputImage->PropagateRequestedRegion();
inputImage->UpdateOutputData();
// Copy the subregion to output
itk::ImageAlgorithm::Copy( inputImage, outputPtr, cpyRegion, cpyRegion );
progress.CompletedPixel();
}
}
}
} // end namespace otb
#endif
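The corner-snapping in GenerateData() expands the requested region outward so both bounds fall on the output grid. A sketch of the per-dimension computation (assuming non-negative indices, since C++ and Python define % differently for negative operands):

def snap_to_grid(index, size, grid):
    # Expand [index, index + size) outward to grid-aligned bounds
    lower = index - (index % grid)
    upper = index + size
    if upper % grid > 0:
        upper += grid - (upper % grid)
    return lower, upper - lower  # aligned index and size

# A region starting at 3 with size 10 on a grid of 8 spans two grid cells
assert snap_to_grid(3, 10, 8) == (0, 16)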
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from tricks import *
# Logging
tf.logging.set_verbosity(tf.logging.INFO)
# Parser
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", help="checkpoint file prefix", required=True)
parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+')
parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+')
parser.add_argument("--model", help="output SavedModel", required=True)
params = parser.parse_args()
if __name__ == "__main__":
CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model)
quit()
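A hypothetical invocation, reusing tensor names that appear elsewhere in this commit (the script file name is an assumption):

# python ckpt2savedmodel.py --ckpt models/model \
#        --inputs "x_cnn:0" "x_rnn:0" "is_training:0" \
#        --outputs "prediction:0" --model /tmp/savedmodel
CheckpointToSavedModel("models/model", ["x_cnn:0", "x_rnn:0", "is_training:0"],
                       ["prediction:0"], "/tmp/savedmodel")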
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson, Dino Ienco (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import numpy as np
@@ -14,35 +32,6 @@ from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix
from tricks import *
def export_model(sess,
export_dir,
x_cnn_placeholder,
x_rnn_placeholder,
is_training_placeholder,
testPrediction):
""" export a SavedModel
"""
# Update the export dir
model_dir = export_dir + "/saved_model/"
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
print("Export model in " + model_dir)
# Add a builder (for LoadSavedModel)
builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs = {"x_cnn" : x_cnn_placeholder,
"x_rnn" : x_rnn_placeholder,
"is_training" : is_training_placeholder},
outputs = {"prediction" : testPrediction})
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
def checkTest(ts_data, vhsr_data, batchsz, label_test):
tot_pred = []
# gt_test = []
@@ -61,10 +50,6 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
dropout:0.0,
x_cnn:batch_cnn_x})
del batch_rnn_x
del batch_cnn_x
del batch_y
for el in pred_temp:
tot_pred.append( el )
@@ -241,8 +226,8 @@ n_channels = 4
nclasses = 8
# check number of arguments
if len(sys.argv) != 7:
print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>")
if len(sys.argv) != 8:
print("Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid> <export_dir>")
sys.exit(1)
ts_train = read_samples(sys.argv[1])
@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
label_test = np.int32(label_test)
print_histo(label_test, "label_test")
export_dir = sys.argv[7]
x_rnn = tf.placeholder(tf.float32,[None, 1, 1, n_dims*n_timestamps],name="x_rnn")
x_cnn = tf.placeholder(tf.float32,[None, patch_window, patch_window, n_channels],name="x_cnn")
y = tf.placeholder(tf.int32,[None, 1, 1, 1],name="y")
@@ -324,19 +311,12 @@ for e in range(hm_epochs):
lossi+=loss
accS+=acc
del batch_rnn_x
del batch_cnn_x
del batch_y
print "Epoch:",e,"Train loss:",lossi/iterations,"| accuracy:",accS/iterations
c_loss = lossi/iterations
if c_loss < best_loss:
save_path = saver.save(sess, "models/model")
print("Model saved in path: %s" % save_path)
best_loss = c_loss
export_model(sess, "/tmp/m3_export", x_cnn, x_rnn, is_training_ph, testPrediction)
CreateSavedModel(sess, ["x_cnn:0","x_rnn:0","is_training:0"], ["prediction:0"], export_dir)
test_acc = checkTest(ts_test, vhsr_test, 1024, label_test)
\ No newline at end of file
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -89,12 +107,11 @@ def main(unused_argv):
# check number of arguments
if len(sys.argv) != 4:
print("Usage : <patches> <labels> <output_model_dir>")
print("Usage : <patches> <labels> <export_dir>")
sys.exit(1)
# Export dir
log_dir = sys.argv[3] + '/model_checkpoints/'
export_dir = sys.argv[3] + '/model_export/'
export_dir = sys.argv[3]
print("loading dataset")
@@ -272,13 +289,9 @@ def main(unused_argv):
if step % 10 == 0:
# Print status to stdout.
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
#print('Step %d: (%.3f sec)' % (step, duration))
# Save a checkpoint and evaluate the model periodically.
if (curr_epoch + 1) % 1 == 0:
checkpoint_file = os.path.join(log_dir, 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=step)
# Evaluate against the training set.
print('Training Data Eval:')
do_eval2(sess,
@@ -301,16 +314,7 @@ def main(unused_argv):
batch_size)
# Let's export a SavedModel
shutil.rmtree(export_dir)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs= {"x1": xs_placeholder},
outputs= {"prediction": testPrediction})
}
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
CreateSavedModel(sess, ["x1:0"], ["prediction:0"], export_dir)
quit()
......
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import numpy as np
import math
import time
import otbApplication
import tensorflow as tf
import shutil
def flatten_nparray(np_arr):
""" Returns a 1D numpy array retulting from the flatten of the input
@@ -35,10 +55,13 @@ def print_tensor_info(name, tensor):
print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
def read_samples(fn):
def read_samples(fn, single=False):
""" Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
Args:
fn: file name
single: a boolean indicating whether the batch contains a single image.
In this case, the image can be rectangular (not square)
"""
# Get input image size
@@ -63,6 +86,9 @@ def read_samples(fn):
print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
# reshape
if (single):
return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
n_samples = int(size_y / size_x)
outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
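The reshape assumes square patches stacked vertically in a single-column image, so size_y must be a multiple of size_x. For example (file names hypothetical):

patches = read_samples("patches_16x1600.tif")            # 4 bands -> (100, 16, 16, 4)
scene = read_samples("scene_1000x600.tif", single=True)  # whole image -> (1, 600, 1000, 4)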
@@ -76,3 +102,55 @@ def getBatch(X, Y, i, batch_size):
batch_y = Y[start_id:end_id]
return batch_x, batch_y
def CreateSavedModel(sess, inputs, outputs, directory):
"""
Create a SavedModel
Args:
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
"""
print("Create a SavedModel in " + directory)
# Get graph
graph = tf.get_default_graph()
# Get inputs
input_dict = { i : graph.get_tensor_by_name(i) for i in inputs }
output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
# Build the SavedModel
builder = tf.saved_model.builder.SavedModelBuilder(directory)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
input_dict,
output_dict)
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()
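A minimal self-contained usage sketch of CreateSavedModel (TF1-style graph; tensor names, shapes, and the target directory are illustrative only):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 16, 16, 4], name="x1")
w = tf.Variable(1.0, name="scale")
prediction = tf.identity(w * tf.reduce_mean(x, axis=3), name="prediction")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # SavedModelBuilder requires that the target directory does not already exist
    CreateSavedModel(sess, ["x1:0"], ["prediction:0"], "/tmp/my_savedmodel")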
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path):
"""
Read a Checkpoint and build a SavedModel
Args:
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
"""
tf.reset_default_graph()
with tf.Session() as sess:
# Restore variables from disk.