An error occurred while loading the file. Please try again.
-
Cresson Remi authored7fdd1d9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
# ==========================================================================
#
# Copyright 2018-2019 Remi Cresson (IRSTEA)
# Copyright 2020 Remi Cresson (INRAE)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==========================================================================*/
import gdal
import numpy as np
import tensorflow.compat.v1 as tf
from deprecated import deprecated
# Executed at import time: switch TensorFlow to TF1-style (graph) behavior.
# The helpers below rely on the v1 graph/session APIs (Session,
# get_default_graph, import_meta_graph, simple_save).
tf.disable_v2_behavior()
def read_image_as_np(filename, as_patches=False):
    """
    Read an image as numpy array.

    Pixels are returned channels-last. When as_patches is True, the raster is
    interpreted as a vertical stack of square patches whose side is the raster
    width, i.e. the raster height must be a multiple of its width.

    @param filename File name of the (patches) image
    @param as_patches True if the image must be read as patches
    @return 4D numpy array [batch, h, w, c]; batch is 1 when as_patches is
            False, otherwise the number of stacked patches (height // width)
    @raise Exception if GDAL cannot open the file
    @raise ValueError if as_patches is True and the height is not a multiple
           of the width
    """
    # Open a GDAL dataset
    ds = gdal.Open(filename)
    if ds is None:
        raise Exception("Unable to open file {}".format(filename))

    # Raster infos
    n_bands = ds.RasterCount
    szx = ds.RasterXSize
    szy = ds.RasterYSize

    # Raster array
    myarray = ds.ReadAsArray()

    # Multi-band arrays come back band-first (c, h, w): move bands last so the
    # final reshape yields channels-last tensors. Single-band arrays are 2D.
    if myarray.ndim == 3:
        myarray = np.transpose(myarray, axes=(1, 2, 0))

    if as_patches:
        # Integer arithmetic: avoid going through float division, and fail
        # with an explicit message instead of an opaque reshape error.
        if szy % szx != 0:
            raise ValueError(
                "Image {} has height {} which is not a multiple of its width {}: "
                "it cannot be split into square patches".format(filename, szy, szx))
        n = szy // szx
        return myarray.reshape((n, szx, szx, n_bands))
    return myarray.reshape((1, szy, szx, n_bands))
def create_savedmodel(sess, inputs, outputs, directory):
    """
    Export the current default graph of a session as a SavedModel.

    Each input/output tensor is fetched from the default graph by name, and
    the signature maps given to simple_save are keyed by those same names.

    @param sess TF session holding the variables to export
    @param inputs List of input tensor names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
    @param outputs List of output tensor names (e.g. ["prediction:0", "features:0"])
    @param directory Path for the generated SavedModel
    """
    print("Create a SavedModel in " + directory)
    default_graph = tf.compat.v1.get_default_graph()

    def _tensors_by_name(tensor_names):
        # {"name:0": <Tensor name:0>, ...} resolved in the default graph
        return {name: default_graph.get_tensor_by_name(name) for name in tensor_names}

    tf.compat.v1.saved_model.simple_save(
        sess,
        directory,
        inputs=_tensors_by_name(inputs),
        outputs=_tensors_by_name(outputs),
    )
def ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
    """
    Convert a TF checkpoint into a SavedModel.

    The default graph is reset, the meta graph "<ckpt_path>.meta" is imported,
    the checkpoint variables are restored into a new session, and that session
    is exported with create_savedmodel().

    @param ckpt_path Path to the checkpoint file (without the ".meta" extension)
    @param inputs List of input tensor names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
    @param outputs List of output tensor names (e.g. ["prediction:0", "features:0"])
    @param savedmodel_path Path for the generated SavedModel
    @param clear_devices Clear TF devices positioning (True/False)
    """
    tf.compat.v1.reset_default_graph()
    with tf.compat.v1.Session() as session:
        meta_graph_file = ckpt_path + ".meta"
        restorer = tf.compat.v1.train.import_meta_graph(meta_graph_file, clear_devices=clear_devices)
        # Load the checkpointed variable values into the freshly imported graph
        restorer.restore(session, ckpt_path)
        create_savedmodel(session, inputs=inputs, outputs=outputs, directory=savedmodel_path)
@deprecated(reason="Use read_image_as_np(filename, as_patches=True) instead")
def read_samples(filename):
    """
    Read a patches image.

    Deprecated: thin alias of read_image_as_np(filename, as_patches=True).

    @param filename raster file name
    @return 4D numpy array [n_patches, patch_size, patch_size, n_bands]
    """
    return read_image_as_np(filename, as_patches=True)
@deprecated(reason="Use create_savedmodel() instead")
def CreateSavedModel(sess, inputs, outputs, directory):
    """
    Create a SavedModel.

    Deprecated: thin alias of create_savedmodel().

    @param sess TF session
    @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
    @param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
    @param directory Path for the generated SavedModel
    """
    create_savedmodel(sess, inputs, outputs, directory)
@deprecated(reason="Use ckpt_to_savedmodel() instead")
def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
    """
    Read a Checkpoint and build a SavedModel.

    Deprecated: thin alias of ckpt_to_savedmodel().

    @param ckpt_path Path to the checkpoint file (without the ".meta" extension)
    @param inputs List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
    @param outputs List of outputs names (e.g. ["prediction:0", "features:0"])
    @param savedmodel_path Path for the generated SavedModel
    @param clear_devices Clear TF devices positioning (True/False)
    """
    ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices)