Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
Lozac'h Loic
otbtf
Commits
58e508bd
Commit
58e508bd
authored
Sep 12, 2018
by
remi cresson
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
DOC: add copyright header + a bit of refactoring
parent
ad741859
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
61 deletions
+45
-61
python/create_model_ienco-m3_patchbased.py
python/create_model_ienco-m3_patchbased.py
+24
-44
python/create_model_maggiori17_fullyconv.py
python/create_model_maggiori17_fullyconv.py
+21
-17
No files found.
python/create_model_ienco-m3_patchbased.py
View file @
58e508bd
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson, Dino Ienco (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
import
sys
import
os
import
numpy
as
np
...
...
@@ -14,35 +32,6 @@ from sklearn.utils import shuffle
from
sklearn.metrics
import
confusion_matrix
from
tricks
import
*
def
export_model
(
sess
,
export_dir
,
x_cnn_placeholder
,
x_rnn_placeholder
,
is_training_placeholder
,
testPrediction
):
""" export a SavedModel
"""
# Update the export dir
model_dir
=
export_dir
+
"/saved_model/"
if
os
.
path
.
exists
(
model_dir
):
shutil
.
rmtree
(
model_dir
)
print
(
"Export model in "
+
model_dir
)
# Add a builder (for LoadSavedModel)
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
model_dir
)
signature_def_map
=
{
"model"
:
tf
.
saved_model
.
signature_def_utils
.
predict_signature_def
(
inputs
=
{
"x_cnn"
:
x_cnn_placeholder
,
"x_rnn"
:
x_rnn_placeholder
,
"is_training"
:
is_training_placeholder
},
outputs
=
{
"prediction"
:
testPrediction
})
}
builder
.
add_meta_graph_and_variables
(
sess
,[
tf
.
saved_model
.
tag_constants
.
TRAINING
],
signature_def_map
)
builder
.
add_meta_graph
([
tf
.
saved_model
.
tag_constants
.
SERVING
])
builder
.
save
()
def
checkTest
(
ts_data
,
vhsr_data
,
batchsz
,
label_test
):
tot_pred
=
[]
# gt_test = []
...
...
@@ -61,10 +50,6 @@ def checkTest(ts_data, vhsr_data, batchsz, label_test):
dropout
:
0.0
,
x_cnn
:
batch_cnn_x
})
del
batch_rnn_x
del
batch_cnn_x
del
batch_y
for
el
in
pred_temp
:
tot_pred
.
append
(
el
)
...
...
@@ -241,8 +226,8 @@ n_channels = 4
nclasses
=
8
# check number of arguments
if
len
(
sys
.
argv
)
!=
7
:
print
(
"Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>"
)
if
len
(
sys
.
argv
)
!=
8
:
print
(
"Usage : <ts_train> <vhs_train> <label_train> <ts_valid> <vhs_valid> <label_valid>
<export_dir>
"
)
sys
.
exit
(
1
)
ts_train
=
read_samples
(
sys
.
argv
[
1
])
...
...
@@ -257,6 +242,8 @@ label_test = read_samples(sys.argv[6])
label_test
=
np
.
int32
(
label_test
)
print_histo
(
label_test
,
"label_test"
)
export_dir
=
read_samples
(
sys
.
argv
[
7
])
x_rnn
=
tf
.
placeholder
(
tf
.
float32
,[
None
,
1
,
1
,
n_dims
*
n_timestamps
],
name
=
"x_rnn"
)
x_cnn
=
tf
.
placeholder
(
tf
.
float32
,[
None
,
patch_window
,
patch_window
,
n_channels
],
name
=
"x_cnn"
)
y
=
tf
.
placeholder
(
tf
.
int32
,[
None
,
1
,
1
,
1
],
name
=
"y"
)
...
...
@@ -324,19 +311,12 @@ for e in range(hm_epochs):
lossi
+=
loss
accS
+=
acc
del
batch_rnn_x
del
batch_cnn_x
del
batch_y
print
"Epoch:"
,
e
,
"Train loss:"
,
lossi
/
iterations
,
"| accuracy:"
,
accS
/
iterations
c_loss
=
lossi
/
iterations
if
c_loss
<
best_loss
:
save_path
=
saver
.
save
(
sess
,
"models/model"
)
print
(
"Model saved in path: %s"
%
save_path
)
best_loss
=
c_loss
export_model
(
sess
,
"/tmp/m3_export"
,
x_cnn
,
x_rnn
,
is_training_ph
,
testPrediction
)
CreateSavedModel
(
sess
,
[
"x_cnn:0"
,
"x_rnn:0"
,
"is_training:0"
],
[
"prediction:0"
],
export_dir
)
test_acc
=
checkTest
(
ts_test
,
vhsr_test
,
1024
,
label_test
)
\ No newline at end of file
python/create_model_maggiori17_fullyconv.py
View file @
58e508bd
# -*- coding: utf-8 -*-
#==========================================================================
#
# Copyright Remi Cresson (IRSTEA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#==========================================================================*/
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -89,12 +107,11 @@ def main(unused_argv):
# check number of arguments
if
len
(
sys
.
argv
)
!=
4
:
print
(
"Usage : <patches> <labels> <
output_model
_dir>"
)
print
(
"Usage : <patches> <labels> <
export
_dir>"
)
sys
.
exit
(
1
)
# Export dir
log_dir
=
sys
.
argv
[
3
]
+
'/model_checkpoints/'
export_dir
=
sys
.
argv
[
3
]
+
'/model_export/'
export_dir
=
sys
.
argv
[
3
]
print
(
"loading dataset"
)
...
...
@@ -272,13 +289,9 @@ def main(unused_argv):
if
step
%
10
==
0
:
# Print status to stdout.
print
(
'Step %d: loss = %.2f (%.3f sec)'
%
(
step
,
loss_value
,
duration
))
#print('Step %d: (%.3f sec)' % (step, duration))
# Save a checkpoint and evaluate the model periodically.
if
(
curr_epoch
+
1
)
%
1
==
0
:
checkpoint_file
=
os
.
path
.
join
(
log_dir
,
'model.ckpt'
)
saver
.
save
(
sess
,
checkpoint_file
,
global_step
=
step
)
# Evaluate against the training set.
print
(
'Training Data Eval:'
)
do_eval2
(
sess
,
...
...
@@ -301,16 +314,7 @@ def main(unused_argv):
batch_size
)
# Let's export a SavedModel
shutil
.
rmtree
(
export_dir
)
builder
=
tf
.
saved_model
.
builder
.
SavedModelBuilder
(
export_dir
)
signature_def_map
=
{
"model"
:
tf
.
saved_model
.
signature_def_utils
.
predict_signature_def
(
inputs
=
{
"x1"
:
xs_placeholder
},
outputs
=
{
"prediction"
:
testPrediction
})
}
builder
.
add_meta_graph_and_variables
(
sess
,
[
tf
.
saved_model
.
tag_constants
.
TRAINING
],
signature_def_map
)
builder
.
add_meta_graph
([
tf
.
saved_model
.
tag_constants
.
SERVING
])
builder
.
save
()
CreateSavedModel
(
sess
,
[
"x1:0"
],
[
"prediction:0"
],
export_dir
)
quit
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment