Commit 42a39b17 authored by Cresson Remi's avatar Cresson Remi

REFAC: read patches using GDAL, DOC: functions

parent b9a7b692
# -*- coding: utf-8 -*-
# Copyright Remi Cresson (IRSTEA)
# Copyright 2018-2019 Remi Cresson (IRSTEA)
# Copyright 2020 Remi Cresson (INRAE)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,16 +26,18 @@ from tricks import CheckpointToSavedModel
# Parser
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", help="checkpoint file prefix", required=True)
parser.add_argument("--inputs", help="input placeholder names", required=True, nargs='+')
parser.add_argument("--outputs", help="output placeholder names", required=True, nargs='+')
parser.add_argument("--model", help="output SavedModel", required=True)
parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True)
parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+')
parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, nargs='+')
parser.add_argument("--model", help="Output directory for SavedModel", required=True)
parser.add_argument('--clear_devices', dest='clear_devices', action='store_true')
params = parser.parse_args()
if __name__ == "__main__":
CheckpointToSavedModel(params.ckpt, params.inputs, params.outputs, params.model, params.clear_devices)
# -*- coding: utf-8 -*-
# Copyright Remi Cresson (IRSTEA)
# Copyright 2018-2019 Remi Cresson (IRSTEA)
# Copyright 2020 Remi Cresson (INRAE)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,146 +17,76 @@
# limitations under the License.
import sys
import os
import gdal
import numpy as np
import math
import time
import otbApplication
import tensorflow as tf
import shutil
from deprecated import deprecated
def flatten_nparray(np_arr):
""" Returns a 1D numpy array retulting from the flatten of the input
return np_arr.reshape((len(np_arr)))
def print_histo(np_arr, title=""):
""" Prints the histogram of the input numpy array
np_flat = flatten_nparray(np_arr)
np_hist = np.bincount(np_flat)
np_vals = np.unique(np_flat)
if (len(title) > 0):
print(title + ":")
print("Values : "+str(np_vals))
print("Count : "+str(np_hist))
def read_samples(fn, single=False):
return ReadImageAsNumpyArray(filename=fn, asPatches=not single)
def print_tensor_live(name, tensor):
""" Print the shape of a tensor during a session run
return tf.Print(tensor, [tf.shape(tensor)], name + " shape")
def print_tensor_info(name, tensor):
""" Print the shape of a tensor
name : the tensor's name (as we want it to be displayed)
tensor : the tensor
def ReadImageAsNumpyArray(filename, asPatches=False):
Read an image as numpy array.
@param filename File name of patches image
@param asPatches True if the image must be read as patches
@return 4D numpy array [batch, h, w, c]
print(name+" : "+str(tensor.shape)+" (dtype="+str(tensor.dtype)+")")
def read_samples(fn, single=False):
""" Read an image of patches and return a 4D numpy array
TODO: Add an optional argument for the y-patchsize
fn: file name
single: a boolean telling if there is only 1 image in the batch.
In this case, the image can be rectangular (not squared)
# Open a GDAL dataset
ds = gdal.Open(filename)
if (ds is None):
raise Exception("Unable to open file {}".format(filename))
# Raster infos
nBands = ds.RasterCount
szx = ds.RasterXSize
szy = ds.RasterYSize
# Get input image size
imageInfo = otbApplication.Registry.CreateApplication('ReadImageInfo')
imageInfo.SetParameterString('in', fn)
size_x = imageInfo.GetParameterInt('sizex')
size_y = imageInfo.GetParameterInt('sizey')
nbands = imageInfo.GetParameterInt('numberbands')
# Raster array
myarray = ds.ReadAsArray()
print("Loading image "+str(fn)+" ("+str(size_x)+" x "+str(size_y)+" x "+str(nbands)+")")
# Prepare the PixelValue application
imageReader = otbApplication.Registry.CreateApplication('ExtractROI')
imageReader.SetParameterString('in', fn)
imageReader.SetParameterInt('sizex', size_x)
imageReader.SetParameterInt('sizey', size_y)
outimg=imageReader.GetVectorImageAsNumpyArray('out', 'float')
# quick stats
print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )
# reshape
if (single):
return np.copy(outimg.reshape((1, size_y, size_x, nbands)))
# Re-order bands (when there is > 1 band)
if (len(myarray.shape) == 3):
axes = (1, 2, 0)
myarray = np.transpose(myarray, axes=axes)
n_samples = int(size_y / size_x)
outimg = outimg.reshape((n_samples, size_x, size_x, nbands))
print("Returned numpy array shape: "+str(outimg.shape))
return np.copy(outimg)
def getBatch(X, Y, i, batch_size):
start_id = i*batch_size
end_id = min( (i+1) * batch_size, X.shape[0])
batch_x = X[start_id:end_id]
batch_y = Y[start_id:end_id]
return batch_x, batch_y
if (asPatches):
n = int(szy / szx)
return myarray.reshape((n, szx, szx, nBands))
return myarray.reshape((1, szy, szx, nBands))
def CreateSavedModel(sess, inputs, outputs, directory):
Create a SavedModel
sess: the session
inputs: the list of input names
outputs: the list of output names
directory: the output path for the SavedModel
@param sess TF session
@param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
@param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
@param directory Path for the generated SavedModel
print("Create a SavedModel in " + directory)
# Get graph
graph = tf.get_default_graph()
# Get inputs
input_dict = { i : graph.get_tensor_by_name(i) for i in inputs }
output_dict = { o : graph.get_tensor_by_name(o) for o in outputs }
# Build the SavedModel
builder = tf.saved_model.builder.SavedModelBuilder(directory)
signature_def_map= {
"model": tf.saved_model.signature_def_utils.predict_signature_def(
inputs_names = { i : graph.get_tensor_by_name(i) for i in inputs }
outputs_names = { o : graph.get_tensor_by_name(o) for o in outputs }
tf.saved_model.simple_save(sess, directory, inputs=inputs_names, outputs=outputs_names)
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
Read a Checkpoint and build a SavedModel
ckpt_path: path to the checkpoint file (without the ".meta" extension)
inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
savedmodel_path: path to the SavedModel
@param ckpt_path Path to the checkpoint file (without the ".meta" extension)
@param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
@param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
@param savedmodel_path Path for the generated SavedModel
@param clear_devices Clear TF devices positionning (True/False)
with tf.Session() as sess:
# Restore variables from disk.
# Restore variables from disk
model_saver = tf.train.import_meta_graph(ckpt_path+".meta", clear_devices=clear_devices)
model_saver.restore(sess, ckpt_path)
# Create a SavedModel
#CreateSavedModel(sess, inputs, outputs, savedmodel_path)
graph = tf.get_default_graph()
inputs={ i : graph.get_tensor_by_name(i) for i in inputs },
outputs={ o : graph.get_tensor_by_name(o) for o in outputs })
CreateSavedModel(sess, inputs=inputs, outputs=outputs, directory=savedmodel_path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment