tricks.py 5.22 KiB
# -*- coding: utf-8 -*-
#==========================================================================
#
#   Copyright Remi Cresson (IRSTEA)
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#          http://www.apache.org/licenses/LICENSE-2.0.txt
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#==========================================================================
import sys
import os
import numpy as np
import math
import time
import otbApplication
import tensorflow as tf
import shutil

def flatten_nparray(np_arr):
  """ Returns a 1D numpy array resulting from the flattening of the input

  Args:
    np_arr: a numpy array of any shape

  Returns:
    A 1D array (a view when numpy can provide one, a copy otherwise)
    holding every element of np_arr in C order
  """
  # reshape(-1) lets numpy infer the length from the total number of
  # elements. Reshaping to len(np_arr) only worked when all dimensions but
  # the first had size 1.
  return np_arr.reshape(-1)

def print_histo(np_arr, title=""):
  """ Prints the histogram of the input numpy array

  Each distinct value of the array is printed together with its number of
  occurrences. Both printed lists are aligned: the i-th count belongs to
  the i-th value.

  Args:
    np_arr: a numpy array of any shape
    title: optional title, printed (followed by ":") before the histogram
  """
  # np.unique(return_counts=True) pairs every distinct value with its count.
  # np.bincount returned one bin per integer in [0, max], so its counts did
  # not line up with the distinct values, and it rejected negative or
  # non-integer data.
  np_vals, np_hist = np.unique(np.ravel(np_arr), return_counts=True)
  if title:
    print(title + ":")
  print("Values : "+str(np_vals))
  print("Count  : "+str(np_hist))
  
def print_tensor_live(name, tensor):
  """ Wraps a tensor so that its dynamic shape is printed when it is evaluated

  Args:
    name: label used as the message prefix (followed by " shape")
    tensor: the tensor whose shape should be logged

  Returns:
    The tensor returned by tf.Print, to be used in place of `tensor`
  """
  shape_op = tf.shape(tensor)
  message = name + " shape"
  return tf.Print(tensor, [shape_op], message)

def print_tensor_info(name, tensor):
  """ Displays the shape and dtype of a tensor

  Args:
    name: label under which the tensor is displayed
    tensor: any object exposing `shape` and `dtype` attributes
  """
  prefix = name + " : "
  details = "{} (dtype={})".format(str(tensor.shape), str(tensor.dtype))
  print(prefix + details)
  
def read_samples(fn, single=False):
  """ Read an image of patches and return a 4D numpy array

  Unless `single` is True, the image is expected to be a vertical stack of
  square patches whose side equals the image width (size_x). It is split
  into size_y // size_x samples.
  TODO: Add an optional argument for the y-patchsize

  Args:
    fn: file name
    single: a boolean telling if there is only 1 image in the batch.
            In this case, the image can be rectangular (not square)

  Returns:
    A copy of the pixels (read with the 'float' pixel type) as a numpy
    array of shape (1, size_y, size_x, nbands) when `single` is True,
    otherwise (size_y // size_x, size_x, size_x, nbands)

  Raises:
    ValueError: if `single` is False and the image height is not a
                multiple of its width
  """

  # Get input image size
  imageInfo = otbApplication.Registry.CreateApplication('ReadImageInfo')
  imageInfo.SetParameterString('in', fn)
  imageInfo.Execute()
  size_x = imageInfo.GetParameterInt('sizex')
  size_y = imageInfo.GetParameterInt('sizey')
  nbands = imageInfo.GetParameterInt('numberbands')

  print("Loading image "+str(fn)+" ("+str(size_x)+" x "+str(size_y)+" x "+str(nbands)+")")

  # Read the whole image extent with the ExtractROI application
  imageReader = otbApplication.Registry.CreateApplication('ExtractROI')
  imageReader.SetParameterString('in', fn)
  imageReader.SetParameterInt('sizex', size_x)
  imageReader.SetParameterInt('sizey', size_y)
  imageReader.Execute()
  outimg = imageReader.GetVectorImageAsNumpyArray('out', 'float')

  # quick stats
  print("Quick stats: min="+str(np.amin(outimg))+", max="+str(np.amax(outimg)) )

  # A single image keeps its (possibly rectangular) extent
  if single:
    return np.copy(outimg.reshape((1, size_y, size_x, nbands)))

  # Fail with an explicit message instead of an opaque reshape error when
  # the image cannot be cut into whole square patches
  if size_y % size_x != 0:
    raise ValueError("Image height (" + str(size_y) + ") of " + str(fn) +
                     " is not a multiple of its width (" + str(size_x) +
                     "): cannot split it into square patches")

  n_samples = size_y // size_x
  outimg = outimg.reshape((n_samples, size_x, size_x, nbands))

  print("Returned numpy array shape: "+str(outimg.shape))
  return np.copy(outimg)
  
def getBatch(X, Y, i, batch_size):
  """ Returns the i-th batch of a dataset

  Args:
    X: numpy array of samples, indexed along its first axis
    Y: labels aligned with X along the first axis
    i: zero-based batch index
    batch_size: number of samples per batch

  Returns:
    A (batch_x, batch_y) tuple of slices of X and Y. The last batch is
    smaller than batch_size when X.shape[0] is not a multiple of it.
  """
  start_id = i * batch_size
  # Clamp the end to the number of samples so the last batch is partial
  end_id = min((i + 1) * batch_size, X.shape[0])
  return X[start_id:end_id], Y[start_id:end_id]

def CreateSavedModel(sess, inputs, outputs, directory):
  """
  Export the default graph of a session as a SavedModel

  A single signature named "model" maps every input/output tensor name to
  the corresponding tensor of the default graph. Graph and variables are
  stored under the TRAINING tag; the graph is added again under the
  SERVING tag.

  Args:
    sess: the session holding the variable values
    inputs: the list of input tensor names
    outputs: the list of output tensor names
    directory: the output path for the SavedModel
  """
  print("Create a SavedModel in " + directory)

  graph = tf.get_default_graph()

  # Resolve tensor names into tensors (inputs first, then outputs)
  input_dict = {}
  for tensor_name in inputs:
    input_dict[tensor_name] = graph.get_tensor_by_name(tensor_name)
  output_dict = {}
  for tensor_name in outputs:
    output_dict[tensor_name] = graph.get_tensor_by_name(tensor_name)

  builder = tf.saved_model.builder.SavedModelBuilder(directory)
  signature = tf.saved_model.signature_def_utils.predict_signature_def(input_dict, output_dict)
  signature_def_map = {"model": signature}

  tags = tf.saved_model.tag_constants
  builder.add_meta_graph_and_variables(sess, [tags.TRAINING], signature_def_map)
  # NOTE(review): the SERVING meta graph is added without a signature_def_map;
  # confirm that consumers loading the serving tag do not rely on signatures.
  builder.add_meta_graph([tags.SERVING])
  builder.save()

def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
  """
  Read a Checkpoint and build a SavedModel

  The default graph is reset, rebuilt from the checkpoint's ".meta" file and
  its variables restored, then everything is exported with
  tf.saved_model.simple_save.

  Args:
    ckpt_path: path to the checkpoint file (without the ".meta" extension)
    inputs: input list of placeholders names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
    outputs: output list of tensor outputs names (e.g. ["prediction:0", "features:0"])
    savedmodel_path: path to the SavedModel
    clear_devices: forwarded to tf.train.import_meta_graph (strips device
                   placements from the imported graph when True)
  """
  tf.reset_default_graph()
  with tf.Session() as sess:
    # Rebuild the graph from the .meta file, then load the variable values
    meta_file = ckpt_path + ".meta"
    restorer = tf.train.import_meta_graph(meta_file, clear_devices=clear_devices)
    restorer.restore(sess, ckpt_path)

    # Map each requested tensor name to its tensor (inputs first, then outputs)
    g = tf.get_default_graph()
    named_inputs = {}
    for tensor_name in inputs:
      named_inputs[tensor_name] = g.get_tensor_by_name(tensor_name)
    named_outputs = {}
    for tensor_name in outputs:
      named_outputs[tensor_name] = g.get_tensor_by_name(tensor_name)

    tf.saved_model.simple_save(sess, savedmodel_path,
                               inputs=named_inputs, outputs=named_outputs)