# -*- coding: utf-8 -*- """ Copyright (c) 2020-2022 INRAE Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ """Classes for Sentinel images access""" import datetime import itertools import json import logging import multiprocessing from abc import ABC, abstractmethod import numpy as np import rtree from scipy import spatial import otbApplication from decloud.core import system from decloud.preprocessing import constants from decloud.core import raster # --------------------------------------------------- Constants -------------------------------------------------------- KEY_S1_BANDS_10M = "s1_10m" KEY_S2_BANDS_10M = "s2_10m" KEY_S2_BANDS_20M = "s2_20m" KEY_S2_CLOUDMASK_10M = "s2_cld_10m" KEY_DEM_20M = "dem_20m" DTYPE = {KEY_S1_BANDS_10M: np.uint16, KEY_S2_BANDS_10M: np.int16, KEY_S2_BANDS_20M: np.int16, KEY_S2_CLOUDMASK_10M: np.uint8, KEY_DEM_20M: np.int16} NB_CHANNELS = {KEY_S1_BANDS_10M: 2, KEY_S2_BANDS_10M: 4, KEY_S2_BANDS_20M: 6, KEY_S2_CLOUDMASK_10M: 1, KEY_DEM_20M: 1} # ---------------------------------------------------- Helpers --------------------------------------------------------- def s1_filename_to_md(filename): """ This function converts the S1 filename into a small dict of metadata :param filename: S1 raster filename :return: dict, a small dict with useful metadata """ basename = filename[filename.rfind("/"):] metadata = dict() splits = basename.split("_") if len(splits) != 7: raise Exception("{} not a S1 image (wrong number of splits between \"_\" in filename)".format(filename)) if len(splits[5]) < 15: raise Exception("{} not a S1 archive (wrong date format)".format(filename)) date_str = splits[5][:15] metadata["tile"] = splits[1] if date_str[9:15] == "xxxxxx": date_str = date_str[0:8] + "t054500" # We should find a way to use the real acquisition time here metadata["date"] = datetime.datetime.strptime(date_str, '%Y%m%dt%H%M%S') metadata["orbit"] = splits[3] metadata["pol"] = splits[2] metadata["filename"] = filename return metadata def get_s1images_in_directory(pth, ref_patchsize, patchsize_10m): """ Returns a list of S1Images instantiated from the designed path :param pth: the path containing the S1 products :param ref_patchsize: reference patch size :param patchsize_10m: size of the 10m-patch :return: a list of S1Image instances """ files = system.get_files(pth, "dB.tif") return [create_s1_image(vvvh_gtiff=fn, ref_patchsize=ref_patchsize, patchsize_10m=patchsize_10m) for fn in files] def create_s1_image(vvvh_gtiff, ref_patchsize, patchsize_10m): """ Instantiate a S1Image from the given GeoTiff. The input image must be 2 channel (vv/vh) processed using sentinel1_prepare.py :param vvvh_gtiff: input geotiff image :param ref_patchsize: reference patch size :param patchsize_10m: size of 10m-patch :return: an S1Image instance """ metadata = s1_filename_to_md(vvvh_gtiff) pth = system.dirname(vvvh_gtiff) # Compute stats edge_stats_fn = system.pathify(pth) + system.new_bname(metadata["filename"], constants.SUFFIX_STATS_S1) compute_patches_stats(image=metadata["filename"], output_stats=edge_stats_fn, expr="im1b1==0&&im1b2==0", patchsize=ref_patchsize) return S1Image(acq_date=metadata["date"], edge_stats_fn=edge_stats_fn, vvvh_fn=metadata["filename"], ascending=metadata["orbit"].lower() == "asc", patchsize_10m=patchsize_10m) def get_s2images_in_directory(pth, ref_patchsize, patchsize_10m, with_cld_mask, with_20m_bands): """ Returns a list of S2Images instantiated from the designed path :param pth: the path containing the prepared Sentinel-2 products :param ref_patchsize: reference patch size :param patchsize_10m: size of the 10m-patch :param with_cld_mask: True/False. If True, the S2Image are instantiated with cloud mask support :param with_20m_bands: True/False. If True, the S2Image are instantiated with 20m spacing bands support :return: a list of S2Image instances """ return [create_s2_image_from_dir(s2_product, ref_patchsize=ref_patchsize, patchsize_10m=patchsize_10m, with_cld_mask=with_cld_mask, with_20m_bands=with_20m_bands) for s2_product in system.get_directories(pth)] def s2_filename_to_md(filename): """ This function converts the S2 filename into a small dict of metadata :param filename: :return: dict """ basename = system.basename(filename) metadata = dict() splits = basename.split("_") if len(splits) < 4: raise Exception("{} might not be a S2 product".format(filename)) metadata["tile"] = splits[3] datestr = splits[1] metadata["date"] = datetime.datetime.strptime(datestr[:-1], '%Y%m%d-%H%M%S-%f') return metadata def compute_patches_stats(image, output_stats, patchsize, expr=""): """ Run the "SquarePatchesSelection" OTB application over the input image, only if the output file does not exist :param image: input image(s). Either a string, or a string list. For string list, the "expr" parameter must be set :param output_stats: output image :param expr: BandMath expression (optional for 1 input image, mandatory for multiple input image) :param patchsize: the patch size """ logging.debug("Computing stats for %s. Result will be stored in %s.", image, output_stats) if system.is_complete(output_stats): logging.debug("File %s already exists. Skipping.", output_stats) else: app = otbApplication.Registry.CreateApplication("SquarePatchesSelection") if expr: thresh = otbApplication.Registry.CreateApplication("BandMath") if isinstance(image, str): thresh.SetParameterStringList("il", [image]) else: thresh.SetParameterStringList("il", image) thresh.SetParameterString("exp", expr) thresh.Execute() app.SetParameterInputImage("in", thresh.GetParameterOutputImage("out")) else: if not isinstance(image, str): raise Exception("\"image\" must be of type str, if no expr is provided!") app.SetParameterString("in", image) app.SetParameterInt("patchsize", patchsize) app.SetParameterString("out", "{}?&gdal:co:COMPRESS=DEFLATE".format(output_stats)) app.SetParameterOutputImagePixelType("out", otbApplication.ImagePixelType_uint16) app.ExecuteAndWriteOutput() system.declare_complete(output_stats) def create_s2_image_from_dir(s2_product_dir, ref_patchsize, patchsize_10m, with_cld_mask, with_20m_bands): """ Create a S2Image instance from one S2 product :param s2_product_dir: directory containing: SENTINEL2A_20170209-103304-620_L2A_T31TEJ_D_V1-4_FRE_10m.tif SENTINEL2A_20170209-103304-620_L2A_T31TEJ_D_V1-4_FRE_20m.tif SENTINEL2A_20170209-103304-620_L2A_T31TEJ_D_V1-4_CLM_R1.tif SENTINEL2A_20170209-103304-620_L2A_T31TEJ_D_V1-4_EDG_R1.tif :param ref_patchsize: patch size to compute statistics :param patchsize_10m: patch size used at the 10m spacing resolution :param with_cld_mask: True/False. If True, the S2Image are instantiated with cloud mask support :param with_20m_bands: True/False. If True, the S2Image are instantiated with 20m spacing bands support :return: an S2Image instance """ logging.debug("Processing %s", s2_product_dir) files = system.get_files(s2_product_dir, ext=".tif") edg_mask, cld_mask, b10m_imgs, b20m_imgs = None, None, None, None for file in files: if "EDG_R1.tif" in file: edg_mask = file if "CLM_R1.tif" in file: cld_mask = file if "FRE_10m.tif" in file: b10m_imgs = file if "FRE_20m.tif" in file: b20m_imgs = file # Check that files exists def _check(title, filename): if filename is None: raise Exception("File for {} does not exist in product {}".format(title, s2_product_dir)) _check("edge mask", edg_mask) _check("cloud mask", cld_mask) _check("10m bands stack", b10m_imgs) _check("20m bands stack", b20m_imgs) # Print infos logging.debug("Cloud mask:\t%s", cld_mask) logging.debug("Edge mask:\t%s", edg_mask) logging.debug("Channels:") logging.debug("\t10m bands: %s", b10m_imgs) logging.debug("\t20m bands: %s", b20m_imgs) # Compute stats clouds_stats_fn = system.pathify(s2_product_dir) + system.new_bname(cld_mask, constants.SUFFIX_STATS_S2) edge_stats_fn = system.pathify(s2_product_dir) + system.new_bname(edg_mask, constants.SUFFIX_STATS_S2) compute_patches_stats(image=cld_mask, output_stats=clouds_stats_fn, expr="im1b1>0", patchsize=ref_patchsize) compute_patches_stats(image=edg_mask, output_stats=edge_stats_fn, patchsize=ref_patchsize) # Return a s2 image class metadata = s2_filename_to_md(system.pathify(s2_product_dir)) return S2Image(acq_date=metadata["date"], edge_stats_fn=edge_stats_fn, bands_10m_fn=b10m_imgs, bands_20m_fn=b20m_imgs if with_20m_bands is True else None, cld_mask_fn=cld_mask if with_cld_mask is True else None, clouds_stats_fn=clouds_stats_fn, patchsize_10m=patchsize_10m) # --------------------------------------------- Patch reader class ----------------------------------------------------- class PatchReader: """ A patch reader """ def __init__(self, filename, psz, dtype, gdal_cachemax="32"): """ Initializer :param filename: The image filename :param psz: The patch size """ raster.set_gdal_cachemax(gdal_cachemax) # Set GDAL DS self.gdal_ds = raster.gdal_open(filename) # Set GDAL GeoTransform self.ulx, self.resolution_x, _, self.uly, _, self.resolution_y = self.gdal_ds.GetGeoTransform() # Set patches sizes self.patch_size = psz # dtype self.dtype = dtype def get(self, patch_location): """ Read a patch as numpy array :return A numpy array """ # Read array myarray = self.gdal_ds.ReadAsArray(patch_location[0] * self.patch_size, patch_location[1] * self.patch_size, self.patch_size, self.patch_size) # Re-order bands (when there is > 1 band) if len(myarray.shape) == 3: axes = (1, 2, 0) myarray = np.transpose(myarray, axes=axes) else: myarray = np.expand_dims(myarray, axis=2) return myarray.astype(self.dtype) def get_geographic_info(self, patch_location): """ Get the geographic info of a patch :param patch_location: tuple :return the coordinates of the bounding box (Upper left and Lower Right), in lat/lon 4326 coordinate system """ # Getting the Upper Left info of the patch patch_ulx = self.ulx + patch_location[0] * self.resolution_x * self.patch_size patch_uly = self.uly + patch_location[1] * self.resolution_y * self.patch_size # Deduce the Lower Right patch_lrx = patch_ulx + self.patch_size * self.resolution_x patch_lry = patch_uly + self.patch_size * self.resolution_y # Convert to 4326 patch_ul_lon, patch_ul_lat = raster.convert_to_4326((patch_ulx, patch_uly), self.gdal_ds) patch_lr_lon, patch_lr_lat = raster.convert_to_4326((patch_lrx, patch_lry), self.gdal_ds) return patch_ul_lon, patch_ul_lat, patch_lr_lon, patch_lr_lat # (lon, lat) is the standard for GeoJSON # ---------------------------------------------- Image base classes ---------------------------------------------------- class AbstractImage(ABC): """ Abstract class for images """ @abstractmethod def __init__(self): self.patch_sources = dict() def get_patch(self, key, patch_location): """ Returns one patch from the selected patch_source at the given location :param key: patch source key :param patch_location: the patch location :return: a numpy array """ if key not in self.patch_sources: raise Exception("Key {} not in patches sources. Available sources keys: {}".format(key, self.patch_sources)) return self.patch_sources[key].get(patch_location=patch_location) @abstractmethod def get(self, patch_location): """ Returns all existing data :param patch_location: the patch location :return: a numpy array """ class SentinelImage(AbstractImage): """ Abstract class for Sentinel images """ @abstractmethod def __init__(self, acq_date, edge_stats_fn, patchsize_10m): """ :param acq_date: date of the Sentinel image (datetime.datetime) :param edge_stats_fn: filename of the edges stats image :param patchsize_10m: size of the 10m-patches """ super().__init__() self.acq_date = acq_date self.edge_stats_fn = edge_stats_fn self.edge_stats = raster.read_as_np(edge_stats_fn) self.patchsize_10m = patchsize_10m self.timestamp = self.acq_date.replace(tzinfo=datetime.timezone.utc).timestamp() def get_timestamp(self): """ Returns the timestamps (in seconds) :return: int """ return self.timestamp def get(self, patch_location): return {"timestamp": np.asarray(self.get_timestamp())} # ------------------------------------------------- DEM image class ---------------------------------------------------- class SRTMDEMImage(AbstractImage): """ DEM image class. Handles a single raster access. """ def __init__(self, raster_20m_filename, patchsize_20m): """ :param raster_20m_filename: filename of the 20m-spacing raster :param patchsize_20m: patch size (the actual patch is 20m resolution) """ super().__init__() self.raster_20m_filename = raster_20m_filename self.patchsize_20m = patchsize_20m # Patches sources self.patch_sources[KEY_DEM_20M] = PatchReader(filename=self.raster_20m_filename, psz=self.patchsize_20m, dtype=DTYPE[KEY_DEM_20M]) def get(self, patch_location): return {constants.DEM_KEY: self.patch_sources[KEY_DEM_20M].get(patch_location=patch_location)} # ------------------------------------------- Sentinel images classes -------------------------------------------------- class S2Image(SentinelImage): """ Sentinel-2 image class. Keeps Sentinel-2 product metadata, provide an access to image patches. """ def __init__(self, acq_date, edge_stats_fn, bands_10m_fn, clouds_stats_fn, patchsize_10m, bands_20m_fn=None, cld_mask_fn=None): super().__init__(acq_date=acq_date, edge_stats_fn=edge_stats_fn, patchsize_10m=patchsize_10m) self.bands_10m_fn = bands_10m_fn self.bands_20m_fn = bands_20m_fn self.cld_mask_fn = cld_mask_fn self.clouds_stats_fn = clouds_stats_fn self.clouds_stats = raster.read_as_np(clouds_stats_fn) # Prepare patches sources self.patch_sources[KEY_S2_BANDS_10M] = PatchReader(filename=self.bands_10m_fn, psz=self.patchsize_10m, dtype=DTYPE[KEY_S2_BANDS_10M]) if self.bands_20m_fn is not None: self.patch_sources[KEY_S2_BANDS_20M] = PatchReader(filename=self.bands_20m_fn, psz=int(self.patchsize_10m / 2), dtype=DTYPE[KEY_S2_BANDS_20M]) if self.cld_mask_fn is not None: self.patch_sources[KEY_S2_CLOUDMASK_10M] = PatchReader(filename=self.cld_mask_fn, psz=self.patchsize_10m, dtype=DTYPE[KEY_S2_CLOUDMASK_10M]) def get(self, patch_location): ret = {"s2_timestamp": np.asarray(self.get_timestamp())} ret.update({"s2": self.patch_sources[KEY_S2_BANDS_10M].get(patch_location=patch_location)}) if self.bands_20m_fn is not None: ret.update({"s2_20m": self.patch_sources[KEY_S2_BANDS_20M].get(patch_location=patch_location)}) if self.cld_mask_fn is not None: ret.update({"s2_cld10m": self.patch_sources[KEY_S2_CLOUDMASK_10M].get(patch_location=patch_location)}) return ret class S1Image(SentinelImage): """ Sentinel-1 image class. Keeps Sentinel-1 product metadata, provide an access to image patches. """ def __init__(self, acq_date, edge_stats_fn, vvvh_fn, ascending, patchsize_10m): super().__init__(acq_date=acq_date, edge_stats_fn=edge_stats_fn, patchsize_10m=patchsize_10m) self.vvvh_fn = vvvh_fn self.ascending = ascending # Prepare patches sources self.patch_sources[KEY_S1_BANDS_10M] = PatchReader(filename=self.vvvh_fn, psz=self.patchsize_10m, dtype=DTYPE[KEY_S1_BANDS_10M]) def get(self, patch_location): ret = {"s1_timestamp": np.asarray(self.get_timestamp())} ret.update({"s1_ascending": np.asarray(self.ascending)}) ret.update({"s1": self.patch_sources[KEY_S1_BANDS_10M].get(patch_location=patch_location)}) return ret # ---------------------------------------------- Tile Handler class ---------------------------------------------------- class TileHandler: """ TilesHandler performs every I/O operations, build indexation structures in one S2 tile """ @staticmethod def new_bbox(timeframe_low, timeframe_hi, cld_cov_min, cld_cov_max, validity, closest_s1_gap_min, closest_s1_gap_max): """ Return a bounding box in the domain (Time, Cloud coverage, Validity, Closest s1 temporal gap) """ return (timeframe_low, cld_cov_min, validity, closest_s1_gap_min, timeframe_hi, cld_cov_max, validity, closest_s1_gap_max) def for_each_pos(self, apply_fn): """ Iterate over every (pos_x, pos_y) positions and runs "apply_fn(pos)" with pos = (pos_x, pos_y) :param apply_fn: The function to call for each pos. Must have a single argument, "pos" a tuple (pos_x, pos_y) :return: nothing """ for pos_x in range(self.grid_size_x): for pos_y in range(self.grid_size_y): pos = (pos_x, pos_y) apply_fn(pos) def find_s2(self, pos, timeframe_low, timeframe_hi, cld_cov_min, cld_cov_max, validity, closest_s1_gap_max): """ Return all candidates that intersect the bounding box in the domain (Time, Cloud coverage, Validity, Closest s1 temporal gap) """ if closest_s1_gap_max is None: closest_s1_gap_max = self.max_distance bbox_search = self.new_bbox(timeframe_low, timeframe_hi, cld_cov_min, cld_cov_max, validity, closest_s1_gap_min=0, closest_s1_gap_max=closest_s1_gap_max) return self.s2_trees[pos].intersection(bbox_search) def __init__(self, s1_dir, s2_dir, patchsize_10m, tile, dem_20m=None, with_s2_cldmsk=False, with_20m_bands=False): """ TileHandler delivers patches for one given tile. Patches are delivered through the read_tuple() function, in the form of a dict with the following structure: {"s2_someKey1": np.array([..]), "s1_someKey1": np.array([..]), ..., "dem": np.array([..])} :param s1_dir: The directory where the tiled/uint16 S1 images are stored (these images have been processed with `preprocessing/sentinel1_prepare.py`) :param s2_dir: The directory where the tiled/int16 S2 images are stored (these images have been processed with `preprocessing/sentinel2_prepare.py`) :param patchsize_10m: The patch size for the 10m resolution. Must be a multiple of 64. :param tile: the name of the tile, e.g. 'T31TCJ' :param dem_20m: optional raster for the 20m spacing DEM. If value is None, it is not delivered. :param with_s2_cldmsk: True or False. True: the cloud mask patches are delivered :param with_20m_bands: True or False. True: the 20m-spacing bands patches are delivered """ self.tile = tile # This is the size of a patch of the 10m bands of Sentinel-2. self.patchsize_10m = patchsize_10m # List S1 images if s1_dir is not None: self.s1_images = get_s1images_in_directory(pth=s1_dir, ref_patchsize=constants.PATCHSIZE_REF, patchsize_10m=self.patchsize_10m) self.s1_images.sort(key=lambda x: x.acq_date) logging.info("Found %i S1 images in %s", len(self.s1_images), s1_dir) # List S2 images self.s2_images = get_s2images_in_directory(pth=s2_dir, ref_patchsize=constants.PATCHSIZE_REF, patchsize_10m=self.patchsize_10m, with_cld_mask=with_s2_cldmsk, with_20m_bands=with_20m_bands) self.s2_images.sort(key=lambda x: x.acq_date) logging.info("Found %i S2 images in %s", len(self.s2_images), s2_dir) # Get grid size gdal_ds = raster.gdal_open(self.s2_images[0].edge_stats_fn) self.grid_size_x = int(gdal_ds.RasterXSize * constants.PATCHSIZE_REF / self.patchsize_10m) self.grid_size_y = int(gdal_ds.RasterYSize * constants.PATCHSIZE_REF / self.patchsize_10m) # Index images # Create one grid of cloud coverage, and one grid of nodatas. # The grid has cells of size (patch_size) def _index(sx_images, read_fn, process_fn): """ This function is used to build a numpy array of shape (N, grid_size_x, grid_size_y) that store for each patch some useful stuff. :param sx_images: a list of SentinelImage objects (either S1Image objects or S2Image objects) :param read_fn: the function used to retrieve the raster file that is read as a numpy array :param process_fn: the function used to process the value retrieved from the sub numpy array :return: A numpy array of shape (N, grid_size_x, grid_size_y) """ output = np.zeros((len(sx_images), self.grid_size_x, self.grid_size_y)) for sx_image_idx, sx_image in enumerate(sx_images): gdal_ds = raster.gdal_open(read_fn(sx_image)) image_as_np = gdal_ds.ReadAsArray() def compute_value(pos, full_arr=image_as_np, image_idx=sx_image_idx): sub_arr = raster.get_sub_arr(full_arr, patch_location=pos, patch_size=self.patchsize_10m, ref_patch_size=constants.PATCHSIZE_REF) value = process_fn(sub_arr) pos_x, pos_y = pos output[image_idx, pos_x, pos_y] = value self.for_each_pos(compute_value) return output def _reject_no_data(np_arr): """ Returns the map of patches validity """ return np.amax(np_arr) == 0 def _get_edge_stats_fn(sx_image): """ Returns the edge statistics raster file name """ return sx_image.edge_stats_fn def _average_cloud_coverage_values(np_arr): """ Returns the map of average cloud percentage inside patches """ return np.mean(100.0 * np_arr / (constants.PATCHSIZE_REF * constants.PATCHSIZE_REF)) def _get_clouds_stats_fn(s2_image): """ Returns the clouds statistics raster file name """ return s2_image.clouds_stats_fn def _print_np_stats(np_arr, title="some"): """ Print some statistics of the input numpy array """ msg = "{} stats: Shape={}, Min={:.2f}, Max={:.2f}, Mean={:.2f}, Standard deviation={:.2f}".format( title, np_arr.shape, np.amin(np_arr), np.amax(np_arr), np.mean(np_arr), np.std(np_arr)) logging.info(msg) if s1_dir is not None: logging.info("Computing S1 patches statistics") self.s1_images_validity = _index(self.s1_images, read_fn=_get_edge_stats_fn, process_fn=_reject_no_data) _print_np_stats(self.s1_images_validity, "Validity") logging.info("Computing S2 patches statistics") self.s2_images_validity = _index(self.s2_images, read_fn=_get_edge_stats_fn, process_fn=_reject_no_data) _print_np_stats(self.s2_images_validity, "Validity") self.s2_images_cloud_coverage = _index(self.s2_images, read_fn=_get_clouds_stats_fn, process_fn=_average_cloud_coverage_values) _print_np_stats(self.s2_images_cloud_coverage, "Cloud coverage") # Build a dict() of the closest s1_image for each pos, and for each s2_image # The dict structure is like: self.closest_s1[pos][s2_idx] class Closest: """ Simple class to store index/distance to compute the closest image """ def __init__(self, index, distance): self.index = index self.distance = distance def update(self, other): """ Update the closest one """ if other.distance < self.distance: self.distance = other.distance self.index = other.index self.closest_s1 = dict() self.max_distance = 10 * 12 * 31 * 24 * 3600 # Maximum distance to search if s1_dir is not None: # Build KDTree to index s1 images timestamps logging.info("Build KDTrees") s1_timestamps_kdtrees = dict() s1_timestamps_indices = dict() def build_kdtree(pos): """ Build a KDTree at the specified location (pos_x, pos_y) This function modifies: s1_timestamps_kdtrees s1_timestamps_indices """ timestamps = [] timestamps_index = [] pos_x, pos_y = pos for s1_index, s1_image in enumerate(self.s1_images): validity = self.s1_images_validity[s1_index, pos_x, pos_y] if validity: timestamps.append(s1_image.get_timestamp()) timestamps_index.append(s1_index) s1_timestamps_kdtrees[pos] = spatial.KDTree(list(zip(np.asarray(timestamps).ravel()))) s1_timestamps_indices[pos] = timestamps_index self.for_each_pos(build_kdtree) def find_closest_s1_image(pos): """ Find the closest s1 image for each s2 images, at the specified location (pos_x, pos_y) This function modifies: self.closest_s1 """ closest_s1 = {} for s2_idx, s2_image in enumerate(self.s2_images): timestamp_query = np.array([s2_image.get_timestamp()]) _value, timestamp_idx = s1_timestamps_kdtrees[pos].query(timestamp_query) s1_idx = s1_timestamps_indices[pos][timestamp_idx] if _value < self.max_distance: closest = Closest(index=s1_idx, distance=_value) if s2_idx not in closest_s1: closest_s1[s2_idx] = closest else: closest_s1[s2_idx].update(closest) self.closest_s1[pos] = closest_s1 self.for_each_pos(find_closest_s1_image) else: # When S1 images aren't used # Closest S1 dict is composed of virtual S1 images def set_virtual_closest_s1_image(pos): """ Set a virtual S1 image very close to each S2 image """ self.closest_s1[pos] = {s2_idx: Closest(index=-1, distance=0) for s2_idx, _ in enumerate(self.s2_images)} self.for_each_pos(set_virtual_closest_s1_image) # Build RTrees (Cloud_coverage, Date, Validity, Closest S1 (timestamp)) logging.info("Build RTrees (Cloud_coverage, Date, Validity, Closest S1)") properties = rtree.index.Property() properties.dimension = 4 self.s2_trees = dict() def build_rtree(pos): """ Build a RTree for the specified location pos=(pos_x, pos_y) This function modifies: self.s2_trees """ closest_s1 = self.closest_s1[pos] self.s2_trees[pos] = rtree.index.Index(properties=properties) pos_x, pos_y = pos for s2_image_idx, s2_image in enumerate(self.s2_images): timestamp = s2_image.get_timestamp() cld_cov_value = self.s2_images_cloud_coverage[s2_image_idx, pos_x, pos_y] validity_value = self.s2_images_validity[s2_image_idx, pos_x, pos_y] closest_s1_gap = closest_s1[s2_image_idx].distance if s2_image_idx in closest_s1 else self.max_distance bbox = self.new_bbox(timeframe_low=timestamp, timeframe_hi=timestamp, cld_cov_min=cld_cov_value, cld_cov_max=cld_cov_value, validity=validity_value, closest_s1_gap_min=closest_s1_gap, closest_s1_gap_max=self.max_distance) self.s2_trees[pos].insert(s2_image_idx, bbox) self.for_each_pos(build_rtree) # Reading lock self.read_lock = multiprocessing.Lock() # Setup DEM self.dem_image = None if dem_20m is None else SRTMDEMImage(raster_20m_filename=dem_20m, patchsize_20m=int(self.patchsize_10m / 2)) logging.info("Done") def tuple_search(self, acquisitions_layout, roi=None): """ The function performs a search of every tuples of patches that fulfil the acquisition layout. :param acquisitions_layout: the acquisition layout that specify how the scenes are acquired :param roi: the roi (geotiff file name) describes for each reference patch is the patch has to be used or not. Basically the roi.tif results in the rasterization of one vector layer over the raster grid formed by the reference patch size over the Sentinel-2 image (e.g. 640m spacing is the PATCHSIZE_REF is 64 pixels) :return: the tuples, stored in a dict() """ roi_np = None if roi is None else raster.read_as_np(roi) # A function that returns True is the (new_elem, y) pos is inside the ROI def is_inside_roi(patch_location): if roi_np is not None: sub_np_arr = raster.get_sub_arr(roi_np, patch_size=self.patchsize_10m, patch_location=patch_location, ref_patch_size=constants.PATCHSIZE_REF) if np.amin(sub_np_arr) == 0: # if there is at least one cell with "0": it is not entirely inside return False return True # Summarize acquisition layout logging.info("Tile %s, seeking the following acquisition layout:", self.tile) acquisitions_layout.summarize() # To fetch timestamp origin in acquisitions layout: # acquisition_ref_key = acquisitions_layout.get_ref_name() # Begin filtering acquisition_candidates_grid = dict() def collect(pos): """ Collect the samples that match with the search criterion at the given pos=(pos_x, pos_y) This function modifies: acquisition_candidates_grid """ if is_inside_roi(pos): acquisition_candidates_grid[pos] = [] def _filter(acquisition_name, ref_timestamp): """ Function that filter from the available s2 images, given the acquisition and the timestamp """ s2_acquisition = acquisitions_layout.get_s2_acquisition(acquisition_name) # Cloud coverage cld_cov_max = s2_acquisition.max_cloud_percent if (isinstance(s2_acquisition.min_cloud_percent, str) and s2_acquisition.min_cloud_percent.startswith('random')): cld_cov_min = min(eval(s2_acquisition.min_cloud_percent), cld_cov_max) elif isinstance(s2_acquisition.min_cloud_percent, (int, float)): cld_cov_min = s2_acquisition.min_cloud_percent else: raise Exception('Wrong format for min cloud percent, must be a number or random.[whatever]') # Timestamp window timeframe_begin, timeframe_end = acquisitions_layout.get_timestamp_range(acquisition_name) # If the S1S2 gap isn't defined, it is because we don't need one S1 image. # So we just put None, so that the find_s2() knows that the timestamp delta is not a filtering # criterion. closest_s1_gap_max = acquisitions_layout.get_s1s2_max_timestamp_delta(acquisition_name) # Call to RTree query result = self.find_s2(pos=pos, validity=1, cld_cov_min=cld_cov_min, cld_cov_max=cld_cov_max, timeframe_low=ref_timestamp + timeframe_begin, timeframe_hi=ref_timestamp + timeframe_end, closest_s1_gap_max=closest_s1_gap_max) return list(result) for idx, s2_image in enumerate(self.s2_images): acquisition_candidates = dict() ref_timestamp = s2_image.get_timestamp() for acquisition_name in acquisitions_layout: # Here we check that the images fulfill the constraints ret = _filter(acquisition_name=acquisition_name, ref_timestamp=ref_timestamp) if idx in ret and \ not acquisitions_layout.is_siblings(acquisition_candidates.keys(), acquisition_name) \ and idx in acquisition_candidates.values(): ret.remove(idx) # if acquisition_name == acquisition_ref_key: # assert (len(ret) <= 1) if len(ret) == 0: break acquisition_candidates[acquisition_name] = ret # We check that we have candidates for each key. If not, we skip. if len(acquisition_candidates) == len(acquisitions_layout): acquisition_candidates_grid[pos].append(acquisition_candidates) self.for_each_pos(collect) # Add s1/s2 key candidates_grid = dict() for pos, candidates in acquisition_candidates_grid.items(): for candidate in candidates: new_candidate = dict() for key, values in candidate.items(): # key: "s_t-1", "s_t", "s_t+1", ... # values : 45, 48, 49, ... new_val = [] for value in values: new_entry = {"s2": value} if acquisitions_layout.has_s1_acquisition(key): if value in self.closest_s1[pos]: closest_s1_idx = self.closest_s1[pos][value].index new_entry.update({"s1": closest_s1_idx}) new_val.append(new_entry) new_candidate[key] = new_val if pos not in candidates_grid: candidates_grid[pos] = [new_candidate] else: candidates_grid[pos].append(new_candidate) # Should be: # candidates_grid[(0, 0)] = [{"t-1": [{"s2": 45}], # "t": [{"s1:" 12, "s2": 47}], # "t+1": [{"s2": 48}, {"s2": 49}]}, # ...] # Generate every possible combinations from candidates tuples_grid = dict() # convert acquisition_candidates_grid --> tuples_grid # The structure of tuples_grid should be: # # tuples_grid[(0,0)] = [{"t-1": {"s2": 5}, # "t": {"s1": 3, "s2": 7}, # "t+1": {"s2": 11}}, # ... # {"t-1": {"s2": 345}, # "t": {"s1": 453, "s2": 344}, # "t+1": {"s2": 346}}] # tuples_grid[(0,1)] = [...] # ... # tuples_grid[(n,n)] = [...] index = acquisitions_layout.keys() for pos, candidates in candidates_grid.items(): tuples_grid[pos] = [dict(zip(index, list(x))) for candidate in candidates for x in list(itertools.product(*list(candidate.values())))] nb_samples = sum(len(lst) for lst in tuples_grid.values()) logging.info("Tile %s, found %s samples satisfying the acquisition layout", self.tile, nb_samples) return tuples_grid def read_tuple(self, tuple_pos, tuple_indices): """ Read a tuple of Sentinel images patches. :param tuple_pos: the tuple position in the tile_handler grid. The position is a tuple (pos_x, pos_y) e.g. tuple_pos = (23, 41) :param tuple_indices: the tuple indices. e.g. tuple_indices = {"t-1": {"s2": 345}, "t": {"s1": 453, "s2": 344}, "t+1": {"s2": 346}} :return: new_sample, a dict of the acquisition layout keys prefixed with "s1" or "s2" depending on the sensor, with the numpy arrays. If the DEM is used, the {"dem": <numpy_array@0x..>} item updates the returned dict. e.g. new_sample = {"s2_t-1": <numpy_array@0x..>, "s1_t": <numpy_array@0x..>, "s2_t": <numpy_array@0x..>, "s2_t+1": <numpy_array@0x..>} """ with self.read_lock: new_sample = dict() # fill the sample with s1/s2 keys for key, values in tuple_indices.items(): for sx_key, sx_idx in values.items(): if sx_key == "s1": src = self.s1_images[sx_idx] elif sx_key == "s2": src = self.s2_images[sx_idx] # Add the geographic info new_sample['geoinfo'] = \ src.patch_sources[KEY_S2_BANDS_10M].get_geographic_info(patch_location=tuple_pos) else: raise Exception("Unknown key {}!".format(sx_key)) src_dict = src.get(patch_location=tuple_pos) for src_key, src_np_arr in src_dict.items(): # the final key is composed in concatenating key, "_", src_key new_sample[src_key + "_" + key] = src_np_arr # update the sample with the DEM if self.dem_image is not None: new_sample.update(self.dem_image.get(patch_location=tuple_pos)) return new_sample # ---------------------------------------------- Tiles loader class ---------------------------------------------------- class TilesLoader(dict): """ A class that instantiate some TileHandler objects from a json file Keys: - "S1_ROOT_DIR": str (Optional) - "S2_ROOT_DIR": str - "DEM_ROOT_DIR": str (Optional) - "TILES": list Example of a .json file: { "S1_ROOT_DIR": "/data/decloud/S1_PREPARE", "S2_ROOT_DIR": "/data/decloud/S2_PREPARE", "DEM_ROOT_DIR": "/data/decloud/DEM_PREPARE", "TILES": ["T31TCK", "T31TDJ", "T31TEJ"] } """ def __init__(self, the_json, patchsize_10m, with_20m_bands=False): """ :param the_json: The .json file :param patchsize_10m: Patches size (64, 128, 256, ...) must be a multiple of 64 :param with_20m_bands: True or False. True: the 20m-spacing bands patches are delivered """ super().__init__() logging.info("Loading tiles from %s", the_json) with open(the_json) as json_file: data = json.load(json_file) def get_pth(key): """ Retrieve the path :param key: path key :return: the path value """ if key in data: value = data[key] assert isinstance(value, str) return system.pathify(value) return None # Paths self.s1_tiles_root_dir = get_pth("S1_ROOT_DIR") self.s2_tiles_root_dir = get_pth("S2_ROOT_DIR") self.dem_tiles_root_dir = get_pth("DEM_ROOT_DIR") if self.s2_tiles_root_dir is None: raise Exception("S2_ROOT_DIR key not found in {}".format(the_json)) # Tiles list self.tiles_list = data["TILES"] if self.tiles_list is None: raise Exception("TILES key not found in {}".format(the_json)) if not isinstance(self.tiles_list, list): raise Exception("TILES value must be a list of strings!") # Instantiate tile handlers for tile in self.tiles_list: def _get_tile_pth(root_dir, current_tile=tile): """ Returns the directory for the current tile """ if root_dir is not None: return root_dir + current_tile return None s1_dir = _get_tile_pth(self.s1_tiles_root_dir) s2_dir = _get_tile_pth(self.s2_tiles_root_dir) dem_tif = _get_tile_pth(self.dem_tiles_root_dir) if dem_tif is not None: dem_tif += ".tif" logging.info("Creating TileHandler for \"%s\"", tile) tile_handler = TileHandler(s1_dir=s1_dir, s2_dir=s2_dir, dem_20m=dem_tif, patchsize_10m=patchsize_10m, with_20m_bands=with_20m_bands, tile=tile) self.update({tile: tile_handler})