diff --git a/VHR/segmentation.py b/VHR/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..56a8202f8d870cb3383d1788a9d8804d308dcf59 --- /dev/null +++ b/VHR/segmentation.py @@ -0,0 +1,108 @@ +import sys +sys.path.append("../Common") + +from Common.otb_numpy_proc import to_otb_pipeline +import os.path +import otbApplication as otb +import multiprocessing as mp +from math import sqrt, floor +from itertools import product, chain +from dataclasses import dataclass +import psutil + +# GLOBAL +node_size = 700 # size of a graph node (=pixel) in GRM + +@dataclass +class LSGRMParams: + threshold: float + color_weight: float + spatial_weight: float + n_first_iter: int = 8 + margin : int = 100 + +def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None): + if roi is None: + in_img = to_otb_pipeline(input_image) + + seg_app = otb.Registry.CreateApplication('SingleTileGRMGraph') + seg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out')) + seg_app.SetParameterFloat('threshold', params.threshold) + seg_app.SetParameterFloat('criterion.bs.cw', params.color_weight) + seg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight) + seg_app.SetParameterString('tiling', 'user') + seg_app.SetParameterInt('tiling.user.sizex', tile_width) + seg_app.SetParameterInt('tiling.user.sizey', tile_height) + seg_app.SetParameterInt('tiling.user.nfirstiter', params.n_first_iter) + seg_app.SetParameterInt('tiling.user.margin', params.margin) + seg_app.SetParameterInt('xtileidx', tile_idx[0]) + seg_app.SetParameterInt('ytileidx', tile_idx[1]) + seg_app.SetParameterString('out', out_graph) + seg_app.Execute() + + return [out_graph + '_node_{}_{}.bin'.format(tile_idx[1], tile_idx[0]), + out_graph + '_nodeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0]), + out_graph + '_edge_{}_{}.bin'.format(tile_idx[1], tile_idx[0]), + out_graph + '_edgeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0])] + + +def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False): + # Define default number of threads and memory amount + if n_proc is None: + n_proc = mp.cpu_count() + if memory is None: + memory = psutil.virtual_memory().available + else: + memory *= 1e6 + + if roi is None: + in_img = to_otb_pipeline(input_image) + + # Get image size + W, H = in_img.GetImageSize('out') + + # adapt memory amount to force fitting the number of cores + if force_parallel and memory > (W * H * node_size) and W > params.margin and H > params.margin: + memory = W * H * node_size + + # Compute optimal tile size + T = (memory/n_proc)/node_size # approx. pixels per tile with margins + tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W)) + nominal_tw, nominal_th = tile_width_wmarg-2*params.margin, tile_height_wmarg-2*params.margin + n_tiles_x, n_tiles_y = max(1,floor(W/nominal_tw)), max(1,floor(H/nominal_th)) + + if (n_tiles_x == 1 and n_tiles_y == 1): + # Fallback to classical GRM + grm = otb.Registry.CreateApplication('GenericRegionMerging') + grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out')) + grm.SetParameterFloat('threshold', params.threshold) + grm.SetParameterFloat('cw', params.color_weight) + grm.SetParameterFloat('sw', params.spatial_weight) + grm.SetParameterString('out', out_seg) + grm.ExecuteAndWriteOutput() + + else: + tile_index_list = product(range(n_tiles_x), range(n_tiles_y)) + graph = os.path.splitext(out_seg)[0] + arg_list = [(input_image, params, nominal_tw, nominal_th, x, graph, roi) for x in tile_index_list] + + with mp.Pool(n_proc) as p: + graph_files = p.starmap(lsgrm_process_tile, arg_list) + + agg_app = otb.Registry.CreateApplication('AssembleGRMGraphs') + agg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out')) + agg_app.SetParameterString('graph', graph) + agg_app.SetParameterString('out', out_seg) + agg_app.SetParameterFloat('threshold', params.threshold) + agg_app.SetParameterFloat('criterion.bs.cw', params.color_weight) + agg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight) + agg_app.SetParameterString('tiling', 'user') + agg_app.SetParameterInt('tiling.user.sizex', nominal_tw) + agg_app.SetParameterInt('tiling.user.sizey', nominal_th) + agg_app.ExecuteAndWriteOutput() + + if remove_graph: + for f in chain(*graph_files): + os.remove(f) + + return out_seg diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391