Commit 8ff2dec2 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: Large scale segmentation parallel orchestrator.

No related merge requests found
Showing with 108 additions and 0 deletions
+108 -0
import sys
sys.path.append("../Common")
from Common.otb_numpy_proc import to_otb_pipeline
import os.path
import otbApplication as otb
import multiprocessing as mp
from math import sqrt, floor
from itertools import product, chain
from dataclasses import dataclass
import psutil
# GLOBAL
node_size = 700 # size of a graph node (=pixel) in GRM
@dataclass
class LSGRMParams:
threshold: float
color_weight: float
spatial_weight: float
n_first_iter: int = 8
margin : int = 100
def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None):
if roi is None:
in_img = to_otb_pipeline(input_image)
seg_app = otb.Registry.CreateApplication('SingleTileGRMGraph')
seg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
seg_app.SetParameterFloat('threshold', params.threshold)
seg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
seg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
seg_app.SetParameterString('tiling', 'user')
seg_app.SetParameterInt('tiling.user.sizex', tile_width)
seg_app.SetParameterInt('tiling.user.sizey', tile_height)
seg_app.SetParameterInt('tiling.user.nfirstiter', params.n_first_iter)
seg_app.SetParameterInt('tiling.user.margin', params.margin)
seg_app.SetParameterInt('xtileidx', tile_idx[0])
seg_app.SetParameterInt('ytileidx', tile_idx[1])
seg_app.SetParameterString('out', out_graph)
seg_app.Execute()
return [out_graph + '_node_{}_{}.bin'.format(tile_idx[1], tile_idx[0]),
out_graph + '_nodeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0]),
out_graph + '_edge_{}_{}.bin'.format(tile_idx[1], tile_idx[0]),
out_graph + '_edgeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0])]
def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False):
# Define default number of threads and memory amount
if n_proc is None:
n_proc = mp.cpu_count()
if memory is None:
memory = psutil.virtual_memory().available
else:
memory *= 1e6
if roi is None:
in_img = to_otb_pipeline(input_image)
# Get image size
W, H = in_img.GetImageSize('out')
# adapt memory amount to force fitting the number of cores
if force_parallel and memory > (W * H * node_size) and W > params.margin and H > params.margin:
memory = W * H * node_size
# Compute optimal tile size
T = (memory/n_proc)/node_size # approx. pixels per tile with margins
tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W))
nominal_tw, nominal_th = tile_width_wmarg-2*params.margin, tile_height_wmarg-2*params.margin
n_tiles_x, n_tiles_y = max(1,floor(W/nominal_tw)), max(1,floor(H/nominal_th))
if (n_tiles_x == 1 and n_tiles_y == 1):
# Fallback to classical GRM
grm = otb.Registry.CreateApplication('GenericRegionMerging')
grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
grm.SetParameterFloat('threshold', params.threshold)
grm.SetParameterFloat('cw', params.color_weight)
grm.SetParameterFloat('sw', params.spatial_weight)
grm.SetParameterString('out', out_seg)
grm.ExecuteAndWriteOutput()
else:
tile_index_list = product(range(n_tiles_x), range(n_tiles_y))
graph = os.path.splitext(out_seg)[0]
arg_list = [(input_image, params, nominal_tw, nominal_th, x, graph, roi) for x in tile_index_list]
with mp.Pool(n_proc) as p:
graph_files = p.starmap(lsgrm_process_tile, arg_list)
agg_app = otb.Registry.CreateApplication('AssembleGRMGraphs')
agg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
agg_app.SetParameterString('graph', graph)
agg_app.SetParameterString('out', out_seg)
agg_app.SetParameterFloat('threshold', params.threshold)
agg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
agg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
agg_app.SetParameterString('tiling', 'user')
agg_app.SetParameterInt('tiling.user.sizex', nominal_tw)
agg_app.SetParameterInt('tiling.user.sizey', nominal_th)
agg_app.ExecuteAndWriteOutput()
if remove_graph:
for f in chain(*graph_files):
os.remove(f)
return out_seg
main.py 0 → 100644
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment