Commit c506f3e2 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: new suboptimal version of lsgrm.

No related merge requests found
Showing with 156 additions and 2 deletions
+156 -2
import sys
import rasterio
from osgeo import gdal
from Common.otb_numpy_proc import to_otb_pipeline
import os.path
import otbApplication as otb
import multiprocessing as mp
from math import sqrt, floor
from math import sqrt, floor, ceil
from itertools import product, chain
from dataclasses import dataclass
import psutil
from skimage.measure import regionprops
from skimage.measure import regionprops, label
from tqdm import tqdm
import numpy as np
......@@ -22,6 +25,73 @@ class LSGRMParams:
n_first_iter: int = 8
margin: int = 100
def grm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_img, roi):
if roi is None:
in_img = to_otb_pipeline(input_image)
elif os.path.exists(roi):
in_img = otb.Registry.CreateApplication('ExtractROI')
in_img.SetParameterString('in', input_image)
in_img.SetParameterString('mode', 'fit')
in_img.SetParameterString('mode.fit.vect', roi)
in_img.Execute()
else:
print('ROI provided but cannot find file.')
sys.exit(-1)
W, H = in_img.GetImageSize('out')
startx = 0 if tile_idx[0] == 0 else tile_idx[0] * tile_width - params.margin
starty = 0 if tile_idx[1] == 0 else tile_idx[1] * tile_height - params.margin
endx = W if tile_width * (tile_idx[0] + 1) >= W else tile_width * (tile_idx[0] + 1) + params.margin
endy = H if tile_height * (tile_idx[1] + 1) >= H else tile_height * (tile_idx[1] + 1) + params.margin
tile = otb.Registry.CreateApplication('ExtractROI')
tile.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
tile.SetParameterInt('startx', startx)
tile.SetParameterInt('starty', starty)
tile.SetParameterInt('sizex', endx - startx)
tile.SetParameterInt('sizey', endy - starty)
tile.Execute()
out_fn = '{}_{}_{}.tif'.format(os.path.splitext(out_img)[0], tile_idx[1], tile_idx[0])
seg_app = otb.Registry.CreateApplication('GenericRegionMerging')
seg_app.SetParameterInputImage('in', tile.GetParameterOutputImage('out'))
seg_app.SetParameterFloat('threshold', params.threshold)
seg_app.SetParameterFloat('cw', params.color_weight)
seg_app.SetParameterFloat('sw', params.spatial_weight)
seg_app.SetParameterInt('niter', params.n_first_iter)
seg_app.Execute()
tie_lines = (0 if starty == 0 else params.margin,
0 if startx == 0 else params.margin,
0, 0)
tie_lines = (tie_lines[0], tie_lines[1],
H if endy == H else tile_height + tie_lines[0],
W if endx == W else tile_width + tie_lines[1])
seg = seg_app.ExportImage('out')
op = regionprops(seg['array'][:,:,0].astype(np.int32))
to_del = []
for o in op:
if not ((tie_lines[0] <= o.bbox[0] < tie_lines[2])
and tie_lines[1] <= o.bbox[1] < tie_lines[3]):
to_del.append(o.label)
seg['array'][np.isin(seg['array'], to_del)] = 0
seg['array'], nlab = label(seg['array'], background=0, connectivity=1, return_num=True)
seg['array'] = np.ascontiguousarray(seg['array']).astype(float)
out_img = otb.Registry.CreateApplication('ExtractROI')
out_img.ImportImage('in', seg)
out_img.SetParameterString('out', out_fn)
out_img.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
out_img.ExecuteAndWriteOutput()
return out_fn, nlab
def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None):
if roi is None:
in_img = to_otb_pipeline(input_image)
......@@ -52,6 +122,90 @@ def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_heigh
out_graph + '_edge_{}_{}.bin'.format(tile_idx[1], tile_idx[0]),
out_graph + '_edgeMargin_{}_{}.bin'.format(tile_idx[1], tile_idx[0])]
def get_ls_seg_parameter(input_image, roi=None, margin=0, n_proc=None, memory=None, force_parallel=False):
# Define default number of threads (half) and memory amount (3/4 of available)
if n_proc is None:
n_proc = round(mp.cpu_count() / 2)
if memory is None:
memory = round(psutil.virtual_memory().available * 0.75)
else:
memory *= 1e6
if roi is None:
in_img = to_otb_pipeline(input_image)
elif os.path.exists(roi):
in_img = otb.Registry.CreateApplication('ExtractROI')
in_img.SetParameterString('in', input_image)
in_img.SetParameterString('mode', 'fit')
in_img.SetParameterString('mode.fit.vect', roi)
in_img.Execute()
else:
print('ROI provided but cannot find file.')
sys.exit(-1)
# Get image size
W, H = in_img.GetImageSize('out')
# adapt memory amount to force fitting the number of cores
if force_parallel and memory > (W * H * node_size) and W > margin and H > margin:
memory = W * H * node_size
# Compute optimal tile size
T = (memory / n_proc) / node_size # approx. pixels per tile with margins
tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W))
nominal_tw, nominal_th = tile_width_wmarg - 2 * margin, tile_height_wmarg - 2 * margin
n_tiles_x, n_tiles_y = max(1, ceil(W / nominal_tw)), max(1, ceil(H / nominal_th))
return in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th
def lsgrm_light(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, force_parallel=False):
# Check output file type
ext = os.path.splitext(out_seg)[-1]
if ext in ['.tif']:
mode = 'raster'
elif ext in ['.shp', '.gpkg', 'gml']:
mode = 'vector'
elif ext in ['.vrt']:
mode = 'vrt'
else:
raise ValueError('Output type {} not recognized/supported.'.format(ext))
in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th = get_ls_seg_parameter(input_image, roi, params.margin, n_proc,
memory, force_parallel)
if (n_tiles_x == 1 and n_tiles_y == 1):
# Fallback to classical GRM
grm = otb.Registry.CreateApplication('GenericRegionMerging')
grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
grm.SetParameterFloat('threshold', params.threshold)
grm.SetParameterFloat('cw', params.color_weight)
grm.SetParameterFloat('sw', params.spatial_weight)
grm.SetParameterString('out', out_seg)
grm.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
grm.ExecuteAndWriteOutput()
else:
tile_index_list = product(range(n_tiles_x), range(n_tiles_y))
arg_list = [(input_image, params, nominal_tw, nominal_th, x, out_seg, roi) for x in tile_index_list]
with mp.Pool(n_proc) as p:
out_tiles = p.starmap(grm_process_tile, arg_list)
cumul_label = 0
for f,l in out_tiles:
if cumul_label > 0:
with rasterio.open(f, 'r+') as tile:
arr = tile.read(1)
arr[arr>0] += cumul_label
tile.write(arr, indexes=1)
tile.nodata = 0
cumul_label += l
if mode == 'vrt':
vrtopt = gdal.BuildVRTOptions(separate=False, srcNodata=0, VRTNodata=0)
gdal.BuildVRT(out_seg, [x[0] for x in out_tiles], options=vrtopt)
return out_seg
def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False):
# Check output file type
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment