Commit 4f71f0bd authored by Cresson Remi

Merge branch 'checkpoints_callbacks_fixes' into 'master'

Checkpoints callbacks fixes

See merge request !6
parents b0447af7 d604263c
Pipeline #36558 passed with stages in 12 minutes and 57 seconds
......@@ -59,7 +59,7 @@ Build the docker image:
flake8:
extends: .static_analysis_base
script:
- sudo apt update && sudo apt install -y flake8 && python -m flake8 -ignore=E402 --max-line-length=120 $PWD/decloud
- sudo apt update && sudo apt install -y flake8 && python -m flake8 --ignore=E402 --max-line-length=120 $PWD/decloud
pylint:
extends: .static_analysis_base
......
......@@ -26,6 +26,7 @@ This script summarizes the number of samples that we can get from an AcquisitionsLayout
suited for single optical image reconstruction from a dated SAR/optical pair, for different
parameters of the AcquisitionsLayout
"""
import os
import argparse
import logging
from decloud.acquisitions.sensing_layout import AcquisitionsLayout, S1Acquisition, S2Acquisition
......@@ -139,8 +140,8 @@ for max_s1s2_gap_hours in params.maxgaps1s2_list:
np_counts[pos[0], pos[1]] += nb_samples_in_patch
# Export
out_fn = "count_gap{}_range{}-{}_{}.tif".format(max_s1s2_gap_hours, int_radius, ext_radius, tile_name)
out_fn = system.pathify(params.out_dir) + out_fn
out_fn = f"count_gap{max_s1s2_gap_hours}_range{int_radius}-{ext_radius}_{tile_name}.tif"
out_fn = os.path.join(params.out_dir, out_fn)
logging.info("Saving %s", out_fn)
raster.save_numpy_array_as_raster(ref_fn=ref_fn, np_arr=np_counts, out_fn=out_fn, scale=scale)
......
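Most hunks in this merge request replace the home-grown `system.pathify(dir) + name` concatenation with `os.path.join`, which only inserts a separator when the directory does not already end with one. A minimal sketch of the equivalence, with illustrative paths (not taken from the repository):

```python
import os

out_dir = "/data/out"                          # illustrative output directory
out_fn = "count_gap24_range0-128_T31TCJ.tif"   # illustrative file name

# Old pattern: pathify() appended a trailing "/" before plain concatenation.
pathified = (out_dir if out_dir.endswith("/") else out_dir + "/") + out_fn

# New pattern: os.path.join adds the separator only when it is missing.
joined = os.path.join(out_dir, out_fn)

assert pathified == joined == "/data/out/count_gap24_range0-128_T31TCJ.tif"
```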
......@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
"""
Analyze the S1 and S2 orbits
"""
import os
import argparse
import numpy as np
import logging
......@@ -80,5 +81,5 @@ for tile_name, tile_handler in th.items():
# Export with pyotb
out = np.add(initialized_raster, histo_array) # this is a pyotb object
out_fn = system.pathify(params.out_dir) + "{}_s1s2gap_hist.tif".format(tile_name)
out_fn = os.path.join(params.out_dir, f"{tile_name}_s1s2gap_hist.tif")
out.write(out_fn)
......@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
"""
Compute the number of S1 and S2 images used for each patch.
"""
import os
import argparse
import logging
import numpy as np
......@@ -59,7 +60,7 @@ scale = float(params.patch_size) / float(constants.PATCHSIZE_REF)
for al_bname, al in als:
for tile_name, tile_handler in th.items():
# Output files prefix
out_prefix = system.pathify(params.out_dir) + tile_name + "_" + al_bname
out_prefix = os.path.join(params.out_dir, tile_name + "_" + al_bname)
# Reference raster grid
ref_fn = tile_handler.s2_images[0].clouds_stats_fn
......@@ -94,6 +95,6 @@ for al_bname, al in als:
# Export
for key in ["s1", "s2"]:
out_fn = "{}_{}_freq.tif".format(out_prefix, key)
out_fn = f"{out_prefix}_{key}_freq.tif"
logging.info("Saving %s", out_fn)
raster.save_numpy_array_as_raster(ref_fn=ref_fn, np_arr=np_counts[key], out_fn=out_fn, scale=scale)
......@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
"""
Compute cloud coverage and pixel validity from an input set of tiles
"""
import os
import argparse
import logging
import numpy as np
......@@ -46,7 +47,7 @@ def compute_stats(tile_name, tile_handler):
:param tile_handler: Tile handler instance
"""
ref_fn = tile_handler.s2_images[0].clouds_stats_fn
out_prefix = system.pathify(params.out_dir) + tile_name
out_prefix = os.path.join(params.out_dir, tile_name)
# Statistics
cloud_cov = np.sum(np.multiply(tile_handler.s2_images_validity, tile_handler.s2_images_cloud_coverage), axis=0)
......
......@@ -21,6 +21,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Dataset classes"""
import os
import abc
import collections.abc
import json
......@@ -436,7 +437,6 @@ class RoisLoader(dict):
assert root_dir_key in data
self.rois_root_dir = data[root_dir_key]
assert isinstance(self.rois_root_dir, str)
self.rois_root_dir = system.pathify(self.rois_root_dir)
def get_list(key):
"""
......@@ -464,7 +464,7 @@ class RoisLoader(dict):
"""
tiles = {}
for tile in tiles_list:
roi_file = "{}{}_{}.tif".format(self.rois_root_dir, tile, suffix)
roi_file = os.path.join(self.rois_root_dir, f"{tile}_{suffix}.tif")
assert system.file_exists(roi_file)
tiles.update({tile: roi_file})
self.update({"roi_{}".format(suffix): tiles})
......@@ -21,9 +21,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Classes for Tensorflow summaries"""
import os
import tensorflow as tf
from tensorflow import keras
from decloud.core import system
from decloud.preprocessing import constants
......@@ -71,7 +71,7 @@ class PreviewsCallback(keras.callbacks.Callback):
predicted = self.model.predict(self.test_data)
# Log the images summary.
file_writer = tf.summary.create_file_writer(system.pathify(self.logdir) + 'previews')
file_writer = tf.summary.create_file_writer(os.path.join(self.logdir, 'previews'))
with file_writer.as_default():
for key in self.target_keys:
tf.summary.image("predicted: " + key, get_preview_fn(key)(predicted[key]), step=epoch)
......
......@@ -84,16 +84,7 @@ def get_files(directory, ext=None):
def new_bname(filename, suffix):
""" return a new basename (without path, without extension, + suffix) """
filename = filename[filename.rfind("/"):]
filename = filename[:filename.rfind(".")]
return filename + "_" + suffix
def pathify(pth):
""" Adds posix separator if needed """
if not pth.endswith("/"):
pth += "/"
return pth
return pathlib.Path(filename).stem + "_" + suffix
def mkdir(pth):
......
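`new_bname` is also simplified: `pathlib.Path(filename).stem` strips both the directory and the extension in one call, where the old slicing kept the leading separator returned by `rfind("/")`. A minimal sketch of the new behaviour (the sample path is illustrative):

```python
import pathlib

def new_bname(filename, suffix):
    """Return the basename without path or extension, plus "_" + suffix."""
    return pathlib.Path(filename).stem + "_" + suffix

# Illustrative input: the directory and the ".tif" extension are both dropped.
assert new_bname("/data/s2/SENTINEL2B_20200318-105852-048_T31TCJ.tif", "stats") == \
    "SENTINEL2B_20200318-105852-048_T31TCJ_stats"
```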
......@@ -21,6 +21,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Classes for Sentinel images access"""
import os
import datetime
import itertools
import json
......@@ -70,9 +71,9 @@ def s1_filename_to_md(filename):
metadata = dict()
splits = basename.split("_")
if len(splits) != 7:
raise Exception("{} not a S1 image (wrong number of splits between \"_\" in filename)".format(filename))
raise Exception(f"{filename} not a S1 image (wrong number of splits between \"_\" in filename)")
if len(splits[5]) < 15:
raise Exception("{} not a S1 archive (wrong date format)".format(filename))
raise Exception(f"{filename} not a S1 archive (wrong date format)")
date_str = splits[5][:15]
metadata["tile"] = splits[1]
if date_str[9:15] == "xxxxxx":
......@@ -110,7 +111,7 @@ def create_s1_image(vvvh_gtiff, ref_patchsize, patchsize_10m):
pth = system.dirname(vvvh_gtiff)
# Compute stats
edge_stats_fn = system.pathify(pth) + system.new_bname(metadata["filename"], constants.SUFFIX_STATS_S1)
edge_stats_fn = os.path.join(pth, system.new_bname(metadata["filename"], constants.SUFFIX_STATS_S1))
compute_patches_stats(image=metadata["filename"], output_stats=edge_stats_fn, expr="im1b1==0&&im1b2==0",
patchsize=ref_patchsize)
......@@ -145,7 +146,7 @@ def s2_filename_to_md(filename):
metadata = dict()
splits = basename.split("_")
if len(splits) < 4:
raise Exception("{} might not be a S2 product".format(filename))
raise Exception(f"{filename} might not be a S2 product")
metadata["tile"] = splits[3]
datestr = splits[1]
metadata["date"] = datetime.datetime.strptime(datestr[:-1], '%Y%m%d-%H%M%S-%f')
......@@ -180,7 +181,7 @@ def compute_patches_stats(image, output_stats, patchsize, expr=""):
raise Exception("\"image\" must be of type str, if no expr is provided!")
app.SetParameterString("in", image)
app.SetParameterInt("patchsize", patchsize)
app.SetParameterString("out", "{}?&gdal:co:COMPRESS=DEFLATE".format(output_stats))
app.SetParameterString("out", f"{output_stats}?&gdal:co:COMPRESS=DEFLATE")
app.SetParameterOutputImagePixelType("out", otbApplication.ImagePixelType_uint16)
app.ExecuteAndWriteOutput()
system.declare_complete(output_stats)
......@@ -216,7 +217,7 @@ def create_s2_image_from_dir(s2_product_dir, ref_patchsize, patchsize_10m, with_
# Check that files exists
def _check(title, filename):
if filename is None:
raise Exception("File for {} does not exist in product {}".format(title, s2_product_dir))
raise Exception(f"File for {title} does not exist in product {s2_product_dir}")
_check("edge mask", edg_mask)
_check("cloud mask", cld_mask)
......@@ -231,13 +232,13 @@ def create_s2_image_from_dir(s2_product_dir, ref_patchsize, patchsize_10m, with_
logging.debug("\t20m bands: %s", b20m_imgs)
# Compute stats
clouds_stats_fn = system.pathify(s2_product_dir) + system.new_bname(cld_mask, constants.SUFFIX_STATS_S2)
edge_stats_fn = system.pathify(s2_product_dir) + system.new_bname(edg_mask, constants.SUFFIX_STATS_S2)
clouds_stats_fn = os.path.join(s2_product_dir, system.new_bname(cld_mask, constants.SUFFIX_STATS_S2))
edge_stats_fn = os.path.join(s2_product_dir, system.new_bname(edg_mask, constants.SUFFIX_STATS_S2))
compute_patches_stats(image=cld_mask, output_stats=clouds_stats_fn, expr="im1b1>0", patchsize=ref_patchsize)
compute_patches_stats(image=edg_mask, output_stats=edge_stats_fn, patchsize=ref_patchsize)
# Return a s2 image class
metadata = s2_filename_to_md(system.pathify(s2_product_dir))
metadata = s2_filename_to_md(s2_product_dir)
return S2Image(acq_date=metadata["date"],
edge_stats_fn=edge_stats_fn,
bands_10m_fn=b10m_imgs,
......@@ -333,7 +334,7 @@ class AbstractImage(ABC):
:return: a numpy array
"""
if key not in self.patch_sources:
raise Exception("Key {} not in patches sources. Available sources keys: {}".format(key, self.patch_sources))
raise Exception(f"Key {key} not in patches sources. Available sources keys: {self.patch_sources}")
return self.patch_sources[key].get(patch_location=patch_location)
@abstractmethod
......@@ -917,7 +918,7 @@ class TileHandler:
new_sample['geoinfo'] = \
src.patch_sources[KEY_S2_BANDS_10M].get_geographic_info(patch_location=tuple_pos)
else:
raise Exception("Unknown key {}!".format(sx_key))
raise Exception(f"Unknown key {sx_key}!")
src_dict = src.get(patch_location=tuple_pos)
for src_key, src_np_arr in src_dict.items():
# the final key is composed in concatenating key, "_", src_key
......@@ -971,7 +972,7 @@ class TilesLoader(dict):
if key in data:
value = data[key]
assert isinstance(value, str)
return system.pathify(value)
return value
return None
# Paths
......@@ -980,12 +981,12 @@ class TilesLoader(dict):
self.dem_tiles_root_dir = get_pth("DEM_ROOT_DIR")
if self.s2_tiles_root_dir is None:
raise Exception("S2_ROOT_DIR key not found in {}".format(the_json))
raise Exception(f"S2_ROOT_DIR key not found in {the_json}")
# Tiles list
self.tiles_list = data["TILES"]
if self.tiles_list is None:
raise Exception("TILES key not found in {}".format(the_json))
raise Exception(f"TILES key not found in {the_json}")
if not isinstance(self.tiles_list, list):
raise Exception("TILES value must be a list of strings!")
......@@ -995,7 +996,7 @@ class TilesLoader(dict):
def _get_tile_pth(root_dir, current_tile=tile):
""" Returns the directory for the current tile """
if root_dir is not None:
return root_dir + current_tile
return os.path.join(root_dir, current_tile)
return None
s1_dir = _get_tile_pth(self.s1_tiles_root_dir)
......
......@@ -25,7 +25,6 @@ import os
import shutil
import tensorflow as tf
from tensorflow import keras
from decloud.core import system
from decloud.models.utils import _is_chief
# Callbacks being called at the end of each epoch during training
......@@ -45,7 +44,7 @@ class ArchiveCheckpoint(keras.callbacks.Callback):
self.backup_dir = backup_dir
self.strategy = strategy
def on_epoch_end(self, epoch, logs=None):
def on_epoch_begin(self, epoch, logs=None):
"""
At the beginning of each epoch, we save the directory of BackupAndRestore to a different name for archiving
"""
......@@ -92,7 +91,7 @@ class AdditionalValidationSets(keras.callbacks.Callback):
for metric, result in zip(self.model.metrics_names, results):
if self.logdir:
writer = tf.summary.create_file_writer(system.pathify(self.logdir) + 'validation_{}'.format(i + 1))
writer = tf.summary.create_file_writer(os.path.join(self.logdir, 'validation_{}'.format(i + 1)))
with writer.as_default():
tf.summary.scalar('epoch_' + metric, result, step=epoch) # tensorboard adds an 'epoch_' prefix
else:
......
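The `ArchiveCheckpoint` hook moves from `on_epoch_end` to `on_epoch_begin`, presumably so the archived copy reflects the checkpoint that `BackupAndRestore` finished writing for the previous epoch. A minimal sketch of the pattern, assuming a plain `shutil.copytree` archive (the project's real callback, shown in the hunk above, also receives the distribution strategy and is not reproduced here):

```python
import os
import shutil
from tensorflow import keras

class ArchiveCheckpointSketch(keras.callbacks.Callback):
    """Hypothetical re-implementation for illustration, not the project's exact code."""

    def __init__(self, backup_dir, archive_root):
        super().__init__()
        self.backup_dir = backup_dir      # directory written by keras.callbacks.BackupAndRestore
        self.archive_root = archive_root  # persistent location that survives end of training

    def on_epoch_begin(self, epoch, logs=None):
        # At the start of epoch N, the backup of epoch N-1 is already on disk,
        # so copying it here does not race with BackupAndRestore's own writes.
        if os.path.isdir(self.backup_dir):
            shutil.copytree(self.backup_dir,
                            os.path.join(self.archive_root, f"epoch_{epoch}"),
                            dirs_exist_ok=True)
```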
......@@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Create some TFRecords from a decloud.dataset"""
import os
import argparse
import sys
import logging
......
......@@ -44,12 +44,12 @@ class TFRecords:
if system.is_dir(path) or not os.path.exists(path):
self.dirpath = path
system.mkdir(self.dirpath)
self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath))
self.tfrecords_pattern_path = f"{self.dirpath}/*.records"
else:
self.dirpath = system.dirname(path)
self.tfrecords_pattern_path = path
self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath))
self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath))
self.output_types_file = f"{self.dirpath}/output_types.json"
self.output_shape_file = f"{self.dirpath}/output_shape.json"
self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
......@@ -99,10 +99,10 @@ class TFRecords:
else:
nb_sample = dataset.size - i * n_samples_per_shard
filepath = "{}{}.records".format(system.pathify(self.dirpath), i)
filepath = os.path.join(self.dirpath, f"{i}.records")
# Geographic info of all samples of the record
geojson_path = "{}{}.geojson".format(system.pathify(self.dirpath), i)
geojson_path = os.path.join(self.dirpath, f"{i}.geojson")
geojson_dic = {"type": "FeatureCollection",
"name": "{}_geoinfo".format(i),
"features": []}
......
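In `TFRecords`, the glob pattern keeps a literal forward slash (it is a pattern, not a single file), while the per-shard `.records` and `.geojson` files are built with `os.path.join`. A short sketch of how such a pattern is typically consumed, assuming the shards are read back with `tf.data` (the project's actual reader is not shown in this diff):

```python
import os
import tensorflow as tf

dirpath = "/data/tfrecords"                   # illustrative shard directory
pattern = f"{dirpath}/*.records"              # glob pattern matching every shard
shard_0 = os.path.join(dirpath, "0.records")  # an individual shard path

# Typical consumption of the pattern with tf.data (assumption, for illustration only).
filenames = tf.data.Dataset.list_files(pattern, shuffle=True)
dataset = tf.data.TFRecordDataset(filenames)
```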
......@@ -152,7 +152,7 @@ def main(args):
# adding the info to the SavedModel path
out_savedmodel = None if params.out_savedmodel is None else \
system.pathify(params.out_savedmodel) + expe_name + date_tag
os.path.join(params.out_savedmodel, expe_name + date_tag)
# Scaling batch size and learning rate accordingly to number of workers
batch_size_train = params.batch_size_train * n_workers
......@@ -203,17 +203,16 @@ def main(args):
if params.strategy == 'singlecpu':
logging.warning('Checkpoints can not be saved while using singlecpu option. Discarding checkpoints')
else:
# Create a backup
backup_dir = system.pathify(params.ckpt_dir) + params.model
callbacks.append(keras.callbacks.experimental.BackupAndRestore(backup_dir=backup_dir))
# Save the checkpoint to a persistent location
backup_dir = os.path.join(params.ckpt_dir, params.model)
# Backup (deleted once the model is trained the specified number of epochs)
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))
# Persistent save (still here after the model is trained)
callbacks.append(ArchiveCheckpoint(backup_dir, strategy))
# Define the Keras TensorBoard callback.
logdir = None
if params.logdir:
logdir = system.pathify(params.logdir) + "{}_{}".format(date_tag, expe_name)
logdir = os.path.join(params.logdir, f"{date_tag}_{expe_name}")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
profile_batch=params.profiling)
callbacks.append(tensorboard_callback)
......
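The training script now builds its checkpoint callbacks from the stable `keras.callbacks.BackupAndRestore` (no longer under the `experimental` namespace in recent TensorFlow releases) and pairs it with `ArchiveCheckpoint` for a persistent copy. A condensed sketch of the callback assembly, with placeholder values standing in for the script's `params`:

```python
import os
from tensorflow import keras

ckpt_dir, model_name = "/tmp/ckpts", "crga_os2_unet"                  # placeholders for params.ckpt_dir / params.model
logdir = os.path.join("/tmp/logs", "20220101-120000_crga_os2_unet")   # placeholder for params.logdir + date/expe tags

callbacks = []
backup_dir = os.path.join(ckpt_dir, model_name)

# Backup: removed by Keras once training reaches the requested number of epochs.
callbacks.append(keras.callbacks.BackupAndRestore(backup_dir=backup_dir))

# Persistent save: ArchiveCheckpoint (shown earlier) keeps a copy after training.
# callbacks.append(ArchiveCheckpoint(backup_dir, strategy))

# TensorBoard logs under a per-run directory.
callbacks.append(keras.callbacks.TensorBoard(log_dir=logdir))
```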
......@@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Pre-process one Sentinel-1 image"""
import os
import logging
import argparse
import otbApplication
......@@ -79,7 +80,7 @@ def main(args):
constants.PATCHSIZE_REF)
# Calibration + concatenation + tiling/compression
out_fn = system.pathify(params.out_s1_dir) + out_fn
out_fn = os.path.join(params.out_s1_dir, out_fn)
if system.is_complete(out_fn):
logging.info("File %s already exists. Skipping.", system.remove_ext_filename(out_fn))
else:
......
......@@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Pre-process one Sentinel-2 image"""
import os
import logging
import argparse
import otbApplication
......@@ -43,7 +44,7 @@ def fconc(il, suffix, tilesize, out_tile_dir, pixel_type=otbApplication.ImagePix
logging.info("Concatenate + Tile + Compress GeoTiff for files: %s", "".join(il))
out_fn = system.basename(il[0])
out_fn = out_fn[:out_fn.rfind("_")]
out_fn = out_tile_dir + out_fn + "_" + suffix
out_fn = os.path.join(out_tile_dir, out_fn + "_" + suffix)
out_fn += ".tif?&gdal:co:COMPRESS=DEFLATE"
out_fn += "&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue={}".format(4 * tilesize)
out_fn += "&gdal:co:TILED=YES&gdal:co:BLOCKXSIZE={ts}&gdal:co:BLOCKYSIZE={ts}".format(ts=tilesize)
......@@ -130,7 +131,7 @@ def main(args):
logging.info("Product name is %s", product_name)
# Out directory
out_tile_dir = system.pathify(params.out_s2_dir) + tile_name + "/" + product_name + "/"
out_tile_dir = os.path.join(params.out_s2_dir, tile_name, product_name)
logging.info("Create or use the following output directory: %s", out_tile_dir)
system.mkdir(out_tile_dir)
......
......@@ -7,7 +7,7 @@ import filecmp
import gdal
import otbApplication as otb
from abc import ABC
from decloud.core.system import get_env_var, pathify, basename
from decloud.core.system import get_env_var, basename
class DecloudTest(ABC, unittest.TestCase):
......@@ -15,7 +15,10 @@ class DecloudTest(ABC, unittest.TestCase):
DECLOUD_DATA_DIR = get_env_var("DECLOUD_DATA_DIR")
def get_path(self, path):
return pathify(self.DECLOUD_DATA_DIR) + path
pth = os.path.join(self.DECLOUD_DATA_DIR, path)
if not os.path.exists(pth):
raise FileNotFoundError(f"Directory {pth} not found!")
return pth
def compare_images(self, image, reference, mae_threshold=0.01):
......
......@@ -8,6 +8,7 @@ from .decloud_unittest import DecloudTest
SAVEDMODEL_FILENAME = "saved_model.pb"
def is_savedmodel_written(args_list):
out_savedmodel = "/tmp/savedmodel"
base_args = ["--logdir", "/tmp/logdir",
......@@ -25,8 +26,9 @@ def is_savedmodel_written(args_list):
OS2_TFREC_PTH = "baseline/TFRecord/CRGA"
OS2_ALL_BANDS_TFREC_PTH = "/baseline/TFRecord/CRGA_all_bands"
MERANER_ALL_BANDS_TFREC_PTH = "/baseline/TFRecord/CRGA_all_bands"
OS2_ALL_BANDS_TFREC_PTH = "baseline/TFRecord/CRGA_all_bands"
MERANER_ALL_BANDS_TFREC_PTH = "baseline/TFRecord/CRGA_all_bands"
ERRMSG = f"File {SAVEDMODEL_FILENAME} not found !"
class TrainFromTFRecordsTest(DecloudTest):
......@@ -34,42 +36,42 @@ class TrainFromTFRecordsTest(DecloudTest):
def test_trainFromTFRecords_os1_unet(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
"--model", "crga_os1_unet"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_os2_david(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
"--model", "crga_os2_david"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_os2_unet(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_TFREC_PTH),
"--model", "crga_os2_unet"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_os1_unet_all_bands(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
"--model", "crga_os1_unet_all_bands"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_os2_david_all_bands(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
"--model", "crga_os2_david_all_bands"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_os2_unet_all_bands(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(OS2_ALL_BANDS_TFREC_PTH),
"--model", "crga_os2_unet_all_bands"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_meraner_unet(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(MERANER_ALL_BANDS_TFREC_PTH),
"--model", "meraner_unet"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
def test_trainFromTFRecords_meraner_unet_all_bands(self):
self.assertTrue(is_savedmodel_written(["--training_record", self.get_path(MERANER_ALL_BANDS_TFREC_PTH),
"--model", "meraner_unet_all_bands"]),
"File {} not found !".format(SAVEDMODEL_FILENAME))
ERRMSG)
if __name__ == '__main__':
......