Commit 42a39b17 authored by Cresson Remi's avatar Cresson Remi

REFAC: read patches using GDAL, DOC: functions

parent b9a7b692
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
# Copyright 2018-2019 Remi Cresson (IRSTEA)
# Copyright 2020 Remi Cresson (INRAE)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,16 +26,18 @@ from tricks import CheckpointToSavedModel
# Parser
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", help="checkpoint file prefix", required=True)
parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+')
parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+')
parser.add_argument("--model", help="output SavedModel", required=True)
parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True)
parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+')
parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, nargs='+')
parser.add_argument("--model", help="Output directory for SavedModel", required=True)
parser.add_argument('--clear_devices', dest='clear_devices', action='store_true')
parser.set_defaults(clear_devices=False)
params = parser.parse_args()
if __name__ == "__main__":
CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model, params.clear_devices)
CheckpointToSavedModel(ckpt_path=params.ckpt,
inputs=params.inputs,
outputs=params.outputs,
savedmodel_path=params.model,
clear_devices=params.clear_devices)
quit()
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
# Copyright 2018-2019 Remi Cresson (IRSTEA)
# Copyright 2020 Remi Cresson (INRAE)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,146 +17,76 @@
# limitations under the License.
#
#==========================================================================*/
import sys
import os
import gdal
import numpy as np
import math
import time
import otbApplication
import tensorflow as tf
import shutil
from deprecated import deprecated
def flatten_nparray(np_arr):
""" Returns a 1D numpy array retulting from the flatten of the input
"""
return np_arr.reshape((len(np_arr)))
def print_histo(np_arr, title=""):
""" Prints the histogram of the input numpy array
"""
np_flat = flatten_nparray(np_arr)
np_hist = np.bincount(np_flat)
np_vals = np.unique(np_flat)
if (len(title) > 0):
print(title + ":")
print("Values : "+str(np_vals))
print("Count : "+str(np_hist))
@deprecated
def read_samples(fn, single=False):
return ReadImageAsNumpyArray(filename=fn, asPatches=not single)
def print_tensor_live(name, tensor):
""" Print the shape of a tensor during a session run
"""
return tf.Print(tensor, [tf.shape(tensor)], name + " shape")
def print_tensor_info(name, tensor):
""" Print the shape of a tensor
Args:
name : the tensor's name (as we want it to be displayed)
tensor : the tensor
def ReadImageAsNumpyArray(filename, asPatches=False):
"""
Read an image as numpy array.
@param filename File name of patches image
@param asPatches True if the image must be read as patches
@return 4D numpy array [batch, h, w, c]
"""
print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
def read_samples(fn, single=False):
""" Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
Args:
fn: file name
single: a boolean telling if there is only 1 image in the batch.
In this case, the image can be rectangular (not squared)
"""
# Open a GDAL dataset
ds = gdal.Open(filename)
if (ds is None):
raise Exception("Unable to open file {}".format(filename))
# Raster infos
nBands = ds.RasterCount
szx = ds.RasterXSize
szy = ds.RasterYSize
# Get input image size
imageInfo = otbApplication.Registry.CreateApplication('ReadImageInfo')
imageInfo.SetParameterString('in', fn)
imageInfo.Execute()
size_x = imageInfo.GetParameterInt('sizex')
size_y = imageInfo.GetParameterInt('sizey')
nbands = imageInfo.GetParameterInt('numberbands')
# Raster array
myarray = ds.ReadAsArray()
print("Loading image "+str(fn)+" ("+str(size_x)+" x "+str(size_y)+" x "+str(nbands)+")")
# Prepare the PixelValue application
imageReader = otbApplication.Registry.CreateApplication('ExtractROI')
imageReader.SetParameterString('in', fn)
imageReader.SetParameterInt('sizex', size_x)
imageReader.SetParameterInt('sizey', size_y)
imageReader.Execute()
outimg=imageReader.GetVectorImageAsNumpyArray('out', 'float')
# quick stats
print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
# reshape
if (single):
return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
# Re-order bands (when there is > 1 band)
if (len(myarray.shape) == 3):
axes = (1, 2, 0)
myarray = np.transpose(myarray, axes=axes)
n_samples = int(size_y / size_x)
outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
print("Returned numpy array shape: "+str(outimg.shape))
return np.copy(outimg)
def getBatch(X, Y, i, batch_size):
start_id = i*batch_size
end_id = min( (i+1) * batch_size, X.shape[0])
batch_x = X[start_id:end_id]
batch_y = Y[start_id:end_id]
return batch_x, batch_y
if (asPatches):
n = int(szy / szx)
return myarray.reshape((n, szx, szx, nBands))
return myarray.reshape((1, szy, szx, nBands))
def CreateSavedModel(sess, inputs, outputs, directory):
"""
Create a SavedModel
Args:
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
@param sess TF session
@param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
@param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
@param directory Path for the generated SavedModel
"""
print("Create a SavedModel in " + directory)
# Get graph
graph = tf.get_default_graph()
# Get inputs
input_dict = { i : graph.get_tensor_by_name(i) for i in inputs }
output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
# Build the SavedModel
builder = tf.saved_model.builder.SavedModelBuilder(directory)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
input_dict,
output_dict)
}
builder.add_meta_graph_and_variables(sess,[tf.saved_model.TRAINING],signature_def_map)
builder.add_meta_graph([tf.saved_model.SERVING])
builder.save()
inputs_names = { i : graph.get_tensor_by_name(i) for i in inputs }
outputs_names = { o : graph.get_tensor_by_name(o) for o in outputs }
tf.saved_model.simple_save(sess, directory, inputs=inputs_names, outputs=outputs_names)
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
"""
Read a Checkpoint and build a SavedModel
Args:
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
@param ckpt_path Path to the checkpoint file (without the ".meta" extension)
@param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
@param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
@param savedmodel_path Path for the generated SavedModel
@param clear_devices Clear TF devices positionning (True/False)
"""
tf.reset_default_graph()
with tf.Session() as sess:
# Restore variables from disk.
# Restore variables from disk
model_saver = tf.train.import_meta_graph(ckpt_path+".meta", clear_devices=clear_devices)
model_saver.restore(sess, ckpt_path)
# Create a SavedModel
#CreateSavedModel(sess, inputs, outputs, savedmodel_path)
graph = tf.get_default_graph()
tf.saved_model.simple_save(sess,
savedmodel_path,
inputs={ i : graph.get_tensor_by_name(i) for i in inputs },
outputs={ o : graph.get_tensor_by_name(o) for o in outputs })
CreateSavedModel(sess, inputs=inputs, outputs=outputs, directory=savedmodel_path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment