diff --git a/OBIA/segmentation.py b/OBIA/segmentation.py index 752455b590d0d49f33c021f2390efcd5b9cda611..098496179a3a9d45e8efa0ea7edb2091e36216f7 100644 --- a/OBIA/segmentation.py +++ b/OBIA/segmentation.py @@ -1,13 +1,16 @@ import sys + +import rasterio +from osgeo import gdal from Common.otb_numpy_proc import to_otb_pipeline import os.path import otbApplication as otb import multiprocessing as mp -from math import sqrt, floor +from math import sqrt, floor, ceil from itertools import product, chain from dataclasses import dataclass import psutil -from skimage.measure import regionprops +from skimage.measure import regionprops, label from tqdm import tqdm import numpy as np @@ -22,6 +25,77 @@ class LSGRMParams: n_first_iter: int = 8 margin: int = 100 +def grm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_img, roi): + + if roi is None: + in_img = to_otb_pipeline(input_image) + elif os.path.exists(roi): + in_img = otb.Registry.CreateApplication('ExtractROI') + in_img.SetParameterString('in', input_image) + in_img.SetParameterString('mode', 'fit') + in_img.SetParameterString('mode.fit.vect', roi) + in_img.Execute() + else: + print('ROI provided but cannot find file.') + sys.exit(-1) + + W, H = in_img.GetImageSize('out') + + startx = 0 if tile_idx[0] == 0 else tile_idx[0] * tile_width - params.margin + starty = 0 if tile_idx[1] == 0 else tile_idx[1] * tile_height - params.margin + endx = W if tile_width * (tile_idx[0] + 1) >= W else tile_width * (tile_idx[0] + 1) + params.margin + endy = H if tile_height * (tile_idx[1] + 1) >= H else tile_height * (tile_idx[1] + 1) + params.margin + + tile = otb.Registry.CreateApplication('ExtractROI') + tile.SetParameterInputImage('in', in_img.GetParameterOutputImage('out')) + tile.SetParameterInt('startx', startx) + tile.SetParameterInt('starty', starty) + tile.SetParameterInt('sizex', endx - startx) + tile.SetParameterInt('sizey', endy - starty) + tile.Execute() + + out_fn = '{}_{}_{}.tif'.format(os.path.splitext(out_img)[0], tile_idx[1], tile_idx[0]) + + seg_app = otb.Registry.CreateApplication('GenericRegionMerging') + seg_app.SetParameterInputImage('in', tile.GetParameterOutputImage('out')) + seg_app.SetParameterFloat('threshold', params.threshold) + seg_app.SetParameterFloat('cw', params.color_weight) + seg_app.SetParameterFloat('sw', params.spatial_weight) + seg_app.SetParameterInt('niter', params.n_first_iter) + seg_app.Execute() + + tie_lines = (0 if starty == 0 else params.margin, + 0 if startx == 0 else params.margin, + 0, 0) + + tie_lines = (tie_lines[0], tie_lines[1], + H if endy == H else tile_height + tie_lines[0], + W if endx == W else tile_width + tie_lines[1]) + + seg = seg_app.ExportImage('out') + op = regionprops(seg['array'][:,:,0].astype(np.int32)) + to_del = [] + for o in op: + if not (tie_lines[0] <= o.bbox[0] < tie_lines[2] + and tie_lines[1] <= o.bbox[1] < tie_lines[3]): + # This one to check if a potentially deleted object trespass the overlap area, + # in which case it is kept. + if not (2*params.margin <= o.bbox[2] < tile_height and + 2*params.margin <= o.bbox[3] < tile_width): + to_del.append(o.label) + seg['array'][np.isin(seg['array'], to_del)] = 0 + seg['array'], nlab = label(seg['array'], background=0, connectivity=1, return_num=True) + seg['array'] = np.ascontiguousarray(seg['array']).astype(float) + + out_img = otb.Registry.CreateApplication('ExtractROI') + out_img.ImportImage('in', seg) + out_img.SetParameterString('out', out_fn) + out_img.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32) + out_img.ExecuteAndWriteOutput() + + return out_fn, nlab + + def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None): if roi is None: in_img = to_otb_pipeline(input_image) @@ -52,6 +126,113 @@ def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_heigh out_graph + '_edge_{}_{}.bin'.format(tile_idx[1], tile_idx[0]), out_graph + '_edgeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0])] +def get_ls_seg_parameter(input_image, roi=None, margin=0, n_proc=None, memory=None, force_parallel=False): + # Define default number of threads (half) and memory amount (3/4 of available) + if n_proc is None: + n_proc = round(mp.cpu_count() / 2) + if memory is None: + memory = round(psutil.virtual_memory().available * 0.75) + else: + memory *= 1e6 + + if roi is None: + in_img = to_otb_pipeline(input_image) + elif os.path.exists(roi): + in_img = otb.Registry.CreateApplication('ExtractROI') + in_img.SetParameterString('in', input_image) + in_img.SetParameterString('mode', 'fit') + in_img.SetParameterString('mode.fit.vect', roi) + in_img.Execute() + else: + print('ROI provided but cannot find file.') + sys.exit(-1) + + # Get image size + W, H = in_img.GetImageSize('out') + + # adapt memory amount to force fitting the number of cores + if force_parallel and memory > (W * H * node_size) and W > margin and H > margin: + memory = W * H * node_size + + # Compute optimal tile size + T = (memory / n_proc) / node_size # approx. pixels per tile with margins + tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W)) + nominal_tw, nominal_th = tile_width_wmarg - 2 * margin, tile_height_wmarg - 2 * margin + n_tiles_x, n_tiles_y = max(1, ceil(W / nominal_tw)), max(1, ceil(H / nominal_th)) + + return in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th + +def lsgrm_light(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, + roi=None, force_parallel=False, remove_tiles_if_useless=True): + # Check output file type + ext = os.path.splitext(out_seg)[-1] + if ext == '.vrt': + mode = 'vrt' + elif ext in ['.tif']: + mode = 'raster' + elif ext in ['.shp', '.gml', '.gpkg']: + mode = 'vector' + else: + raise ValueError('Output type {} not supported.'.format(ext)) + + in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th = get_ls_seg_parameter(input_image, roi, params.margin, n_proc, + memory, force_parallel) + + print('[INFO] Using a layout of {} x {} tiles.'.format(n_tiles_x,n_tiles_y)) + print('[INFO] Nominal tile size w/margin: {} x {} pixels'.format(nominal_tw+2*params.margin, + nominal_th+2*params.margin)) + + if (n_tiles_x == 1 and n_tiles_y == 1): + # Fallback to classical GRM + print('[INFO] Fallback to one-tile GRM.') + grm = otb.Registry.CreateApplication('GenericRegionMerging') + grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out')) + grm.SetParameterFloat('threshold', params.threshold) + grm.SetParameterFloat('cw', params.color_weight) + grm.SetParameterFloat('sw', params.spatial_weight) + grm.SetParameterString('out', out_seg) + grm.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32) + grm.ExecuteAndWriteOutput() + else: + tile_index_list = product(range(n_tiles_x), range(n_tiles_y)) + arg_list = [(input_image, params, nominal_tw, nominal_th, x, out_seg, roi) for x in tile_index_list] + with mp.Pool(n_proc) as p: + out_tiles = p.starmap(grm_process_tile, arg_list) + cumul_label = 0 + for f,l in out_tiles: + if cumul_label > 0: + with rasterio.open(f, 'r+') as tile: + arr = tile.read(1) + arr[arr>0] += cumul_label + tile.write(arr, indexes=1) + tile.nodata = 0 + cumul_label += l + + print('[INFO] Output mode {}.'.format(mode)) + if mode == 'vrt': + vrtopt = gdal.BuildVRTOptions(separate=False, srcNodata=0, VRTNodata=0) + gdal.BuildVRT(out_seg, [x[0] for x in out_tiles], options=vrtopt) + elif mode in ['raster', 'vector']: + mos = otb.Registry.CreateApplication('Mosaic') + mos.SetParameterStringList('il', [x[0] for x in out_tiles]) + mos.SetParameterInt('nodata', 0) + if mode == 'raster': + mos.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32) + mos.SetParameterString('out', out_seg) + mos.ExecuteAndWriteOutput() + elif mode == 'vector': + mos.Execute() + vec = otb.Registry.CreateApplication('SimpleVectorization') + vec.SetParameterInputImage('in', mos.GetParameterOutputImage('out')) + vec.SetParameterString('out', out_seg) + vec.ExecuteAndWriteOutput() + if remove_tiles_if_useless: + [os.remove(x[0]) for x in out_tiles] + + return out_seg + + + def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False): # Check output file type diff --git a/docker/dockerfile b/docker/dockerfile index cb1850284e5c9d25bfb3dec09b9e28c38a34a7de..2804f617735641e33c8e2ccd50045d7556b4b937 100644 --- a/docker/dockerfile +++ b/docker/dockerfile @@ -100,6 +100,7 @@ ENV LD_LIBRARY_PATH="/opt/otb/lib:$LD_LIBRARY_PATH" ENV OTB_APPLICATION_PATH="/opt/otb/lib/otb/applications" ENV PROJ_LIB="$DEPS_INSTALL_PREFIX/share/proj" ENV GDAL_DATA="$DEPS_INSTALL_PREFIX/share/gdal" +ENV OTB_LOGGER_LEVEL="CRITICAL" RUN useradd -s /bin/bash -m ubuntu USER ubuntu