# Forked from HYCAR-Hydro / airGR
# Source project has a limited visibility.
# segmentation.py 11.21 KiB
import sys
from Common.otb_numpy_proc import to_otb_pipeline
import os.path
import otbApplication as otb
import multiprocessing as mp
from math import sqrt, floor
from itertools import product, chain
from dataclasses import dataclass
import psutil
from skimage.measure import regionprops
from tqdm import tqdm
import numpy as np

# GLOBAL
# Memory footprint, in bytes, assumed for one GRM graph node (one node per pixel).
# lsgrm divides its byte budget by this value to estimate how many pixels fit in a tile.
node_size = 700 # size of a graph node (=pixel) in GRM, in bytes

@dataclass
class LSGRMParams:
    """Segmentation parameters shared by ``lsgrm`` and ``lsgrm_process_tile``."""
    # Merging threshold, passed to the OTB 'threshold' parameter.
    threshold: float
    # Colour weight ('criterion.bs.cw' in LSGRM apps, 'cw' in GenericRegionMerging).
    color_weight: float
    # Spatial weight ('criterion.bs.sw' in LSGRM apps, 'sw' in GenericRegionMerging).
    spatial_weight: float
    # Number of first iterations per tile ('tiling.user.nfirstiter').
    n_first_iter: int = 8
    # Tile margin in pixels ('tiling.user.margin'); lsgrm also subtracts 2 * margin
    # from the memory-derived tile size to get the nominal tile size.
    margin: int = 100

def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None):
    """Run the OTB ``SingleTileGRMGraph`` step for one tile.

    ``lsgrm`` runs this once per tile in a ``multiprocessing.Pool`` worker.

    :param input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
        None, otherwise used as the ``in`` parameter of OTB ``ExtractROI``.
    :param params: segmentation parameters (threshold, weights, tiling options).
    :param tile_width: nominal tile width in pixels (``tiling.user.sizex``).
    :param tile_height: nominal tile height in pixels (``tiling.user.sizey``).
    :param tile_idx: ``(x, y)`` index of the tile to process.
    :param out_graph: path prefix for the graph files (OTB ``out`` parameter).
    :param roi: optional vector file; the input is cropped to it with ExtractROI.
    :return: list of the four graph file paths (node, nodeMargin, edge,
        edgeMargin) expected for this tile.
    :raises FileNotFoundError: if ``roi`` is given but does not exist.
    """
    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        # Without this branch in_img would be unbound and the failure would show
        # up as an UnboundLocalError deep inside a pool worker.
        raise FileNotFoundError('ROI provided but cannot find file: {}'.format(roi))

    seg_app = otb.Registry.CreateApplication('SingleTileGRMGraph')
    seg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    seg_app.SetParameterFloat('threshold', params.threshold)
    seg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
    seg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
    seg_app.SetParameterString('tiling', 'user')
    seg_app.SetParameterInt('tiling.user.sizex', tile_width)
    seg_app.SetParameterInt('tiling.user.sizey', tile_height)
    seg_app.SetParameterInt('tiling.user.nfirstiter', params.n_first_iter)
    seg_app.SetParameterInt('tiling.user.margin', params.margin)
    seg_app.SetParameterInt('xtileidx', tile_idx[0])
    seg_app.SetParameterInt('ytileidx', tile_idx[1])
    seg_app.SetParameterString('out', out_graph)
    seg_app.ExecuteAndWriteOutput()

    # File names put the y index first, then x; presumably this mirrors the
    # naming used by SingleTileGRMGraph (TODO confirm against the OTB module).
    suffix = '_{}_{}.bin'.format(tile_idx[1], tile_idx[0])
    return [out_graph + '_' + kind + suffix
            for kind in ('node', 'nodeMargin', 'edge', 'edgeMargin')]


def _lsgrm_tile_grid(W, H, memory, n_proc, margin):
    """Size the LSGRM tile grid for a W x H image and a memory budget in bytes.

    Returns ``(nominal_tw, nominal_th, n_tiles_x, n_tiles_y)``: the tile size
    without margins and the number of tiles along each axis.
    """
    T = (memory / n_proc) / node_size  # approx. pixels per tile with margins
    # Tiles keep the image aspect ratio, with width * height ~= T.
    tile_width_wmarg = floor(sqrt(T * W / H))
    tile_height_wmarg = floor(sqrt(T * H / W))
    nominal_tw = tile_width_wmarg - 2 * margin
    nominal_th = tile_height_wmarg - 2 * margin
    if nominal_tw <= 0 or nominal_th <= 0:
        # The margins alone exceed the per-tile budget, so no valid tiling exists
        # (a zero size would also divide by zero below): use a single tile.
        return nominal_tw, nominal_th, 1, 1
    return (nominal_tw, nominal_th,
            max(1, floor(W / nominal_tw)), max(1, floor(H / nominal_th)))


def _write_lsgrm_output(seg_app, out_seg, vectorize):
    """Write the label image produced by ``seg_app`` to ``out_seg``.

    With ``vectorize`` the labels are polygonized by OTB ``SimpleVectorization``
    and a QGIS style file is written next to the output; otherwise the labels
    are written as a uint32 raster.
    """
    if vectorize:
        seg_app.Execute()
        vec_app = otb.Registry.CreateApplication('SimpleVectorization')
        vec_app.SetParameterInputImage('in', seg_app.GetParameterOutputImage('out'))
        vec_app.SetParameterString('out', out_seg)
        vec_app.ExecuteAndWriteOutput()
        # Swap only the extension: str.replace would also rewrite any matching
        # text elsewhere in the path.
        write_qgis_seg_style(os.path.splitext(out_seg)[0] + '.qml')
    else:
        seg_app.SetParameterString('out', out_seg)
        seg_app.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
        seg_app.ExecuteAndWriteOutput()


def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False):
    """Segment an image with Large-Scale Generic Region Merging (LSGRM).

    The image is split into tiles sized from the memory budget. Each tile's graph
    is built in parallel by ``lsgrm_process_tile``, then the graphs are merged by
    OTB ``AssembleGRMGraphs``. When a single tile is enough, the classical OTB
    ``GenericRegionMerging`` application is used instead.

    :param input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
        None, otherwise used as the ``in`` parameter of OTB ``ExtractROI``.
    :param params: segmentation parameters (threshold, weights, tiling options).
    :param out_seg: output path. ``.tif``/``.tiff`` writes a uint32 label raster;
        ``.shp``, ``.gpkg`` and ``.gml`` write polygons plus a ``.qml`` QGIS
        style with the same base name. Extensions are matched case-insensitively.
    :param n_proc: number of worker processes (default: half the CPUs, at least 1).
    :param memory: memory budget in MB (default: 75% of the available memory).
    :param roi: optional vector file; the input is cropped to it with ExtractROI.
    :param remove_graph: delete the intermediate tile graph files when done.
    :param force_parallel: cap the memory budget at the whole-image graph size
        so that the tiling spreads the work over the ``n_proc`` workers.
    :return: ``out_seg``.
    :raises ValueError: if the output extension is not supported.
    :raises FileNotFoundError: if ``roi`` is given but does not exist.
    """
    # Check output file type before doing any work.
    out_root, ext = os.path.splitext(out_seg)
    ext_lc = ext.lower()
    if ext_lc in ('.tif', '.tiff'):
        vectorize = False
    elif ext_lc in ('.shp', '.gpkg', '.gml'):
        vectorize = True
    else:
        raise ValueError('Output type {} not recognized/supported.'.format(ext))

    # Default number of processes (half the CPUs) and memory (3/4 of available).
    if n_proc is None:
        # round(0.5) == 0, so a single-CPU machine would otherwise get no worker.
        n_proc = max(1, round(mp.cpu_count() / 2))
    if memory is None:
        memory = round(psutil.virtual_memory().available * 0.75)
    else:
        memory *= 1e6  # MB -> bytes

    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        # Raise rather than sys.exit(): this is a library function.
        raise FileNotFoundError('ROI provided but cannot find file: {}'.format(roi))

    W, H = in_img.GetImageSize('out')

    # Cap the budget at the whole-image graph size so the tiles are spread
    # over the n_proc workers.
    if force_parallel and memory > (W * H * node_size) and W > params.margin and H > params.margin:
        memory = W * H * node_size

    nominal_tw, nominal_th, n_tiles_x, n_tiles_y = _lsgrm_tile_grid(W, H, memory, n_proc, params.margin)

    if n_tiles_x == 1 and n_tiles_y == 1:
        # Fallback to classical GRM on the whole image.
        grm = otb.Registry.CreateApplication('GenericRegionMerging')
        grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
        grm.SetParameterFloat('threshold', params.threshold)
        grm.SetParameterFloat('cw', params.color_weight)
        grm.SetParameterFloat('sw', params.spatial_weight)
        _write_lsgrm_output(grm, out_seg, vectorize)
        return out_seg

    graph = out_root  # prefix shared by all tile graph files
    arg_list = [(input_image, params, nominal_tw, nominal_th, tile_idx, graph, roi)
                for tile_idx in product(range(n_tiles_x), range(n_tiles_y))]

    with mp.Pool(n_proc) as p:
        graph_files = p.starmap(lsgrm_process_tile, arg_list)

    agg_app = otb.Registry.CreateApplication('AssembleGRMGraphs')
    agg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    agg_app.SetParameterString('graph', graph)
    agg_app.SetParameterFloat('threshold', params.threshold)
    agg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
    agg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
    agg_app.SetParameterString('tiling', 'user')
    agg_app.SetParameterInt('tiling.user.sizex', nominal_tw)
    agg_app.SetParameterInt('tiling.user.sizey', nominal_th)
    _write_lsgrm_output(agg_app, out_seg, vectorize)

    if remove_graph:
        for f in chain.from_iterable(graph_files):
            os.remove(f)

    return out_seg

def write_qgis_seg_style(out_file, line_color='255,255,0,255', line_width=0.46):
    """Write a QGIS layer style (.qml) that draws polygons as outlines.

    The style is a single-symbol fill whose only layer is a ``SimpleLine``.

    :param out_file: path of the .qml file to create (overwritten if present).
    :param line_color: outline colour string, e.g. ``'255,255,0,255'``.
    :param line_width: outline width, in millimetres (``line_width_unit`` MM).
    :return: ``out_file``.
    """
    lines = [
        "<!DOCTYPE qgis PUBLIC 'http://mrcc.com/qgis.dtd' 'SYSTEM'>",
        "<qgis styleCategories=\"Symbology\" version=\"3.14.1-Pi\">",
        "  <renderer-v2 forceraster=\"0\" type=\"singleSymbol\" symbollevels=\"0\" enableorderby=\"0\">",
        "    <symbols>",
        "      <symbol alpha=\"1\" clip_to_extent=\"1\" name=\"0\" type=\"fill\" force_rhr=\"0\">",
        "        <layer class=\"SimpleLine\" locked=\"0\" enabled=\"1\" pass=\"0\">",
        "          <prop k=\"capstyle\" v=\"square\"/>",
        "          <prop k=\"customdash\" v=\"5;2\"/>",
        "          <prop k=\"customdash_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <prop k=\"customdash_unit\" v=\"MM\"/>",
        "          <prop k=\"draw_inside_polygon\" v=\"0\"/>",
        "          <prop k=\"joinstyle\" v=\"bevel\"/>",
        "          <prop k=\"line_color\" v=\"{}\"/>".format(line_color),
        "          <prop k=\"line_style\" v=\"solid\"/>",
        "          <prop k=\"line_width\" v=\"{}\"/>".format(line_width),
        "          <prop k=\"line_width_unit\" v=\"MM\"/>",
        "          <prop k=\"offset\" v=\"0\"/>",
        "          <prop k=\"offset_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <prop k=\"offset_unit\" v=\"MM\"/>",
        "          <prop k=\"ring_filter\" v=\"0\"/>",
        "          <prop k=\"use_custom_dash\" v=\"0\"/>",
        "          <prop k=\"width_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <data_defined_properties>",
        "            <Option type=\"Map\">",
        "              <Option name=\"name\" value=\"\" type=\"QString\"/>",
        "              <Option name=\"properties\"/>",
        "              <Option name=\"type\" value=\"collection\" type=\"QString\"/>",
        "            </Option>",
        "          </data_defined_properties>",
        "        </layer>",
        "      </symbol>",
        "    </symbols>",
        "    <rotation/>",
        "    <sizescale/>",
        "  </renderer-v2>",
        "  <blendMode>0</blendMode>",
        "  <featureBlendMode>0</featureBlendMode>",
        "  <layerGeometryType>2</layerGeometryType>",
        "</qgis>",
    ]
    with open(out_file, 'w', encoding='utf-8') as f:
        # writelines() adds no separators, which would put the whole document on
        # one line; join explicitly so each element keeps its own line.
        f.write('\n'.join(lines) + '\n')
    return out_file

def vectorize_tile(obj, region, to_keep, out):
    """Polygonize selected labels inside one region of an OTB pipeline output.

    :param obj: OTB application/pipeline whose ``out`` image holds the labels.
    :param region: ``[x, y, width, height]`` of the region to export, in pixels.
    :param to_keep: labels to vectorize; every other value in the region is set to 0.
    :param out: path of the vector file written by ``SimpleVectorization``.
    :return: ``out``.
    """
    requested = otb.itkRegion()
    requested['index'][0] = region[0]
    requested['index'][1] = region[1]
    requested['size'][0] = region[2]
    requested['size'][1] = region[3]
    obj.PropagateRequestedRegion('out', requested)

    exported = obj.ExportImage('out')
    # Multiplying by the boolean mask zeroes, in place, every label not listed.
    keep_mask = np.isin(exported['array'], to_keep)
    exported['array'] *= keep_mask

    vectorizer = otb.Registry.CreateApplication('SimpleVectorization')
    vectorizer.ImportImage('in', exported)
    vectorizer.SetParameterString('out', out)
    vectorizer.ExecuteAndWriteOutput()
    return out

def tiled_vectorization(input_segm, nominal_tile_size, output_template):
    """Vectorize a label image tile by tile, writing one vector file per non-empty tile.

    Each object (label != 0) is assigned to the tile of the grid that contains
    the top-left corner of its bounding box. Each tile's region is then the
    union of its objects' bounding boxes, so no object is cut at a tile border.
    Objects belonging to other tiles are zeroed out by ``vectorize_tile``.

    :param input_segm: label image, anything accepted by ``to_otb_pipeline``.
    :param nominal_tile_size: ``(width, height)`` of the tile grid, in pixels.
    :param output_template: format string with one ``{}`` field that receives
        the tile index.
    :return: list of written vector file paths, in increasing tile index.
    """
    in_seg = to_otb_pipeline(input_segm)
    full = in_seg.GetImageAsNumpyArray('out')
    rp = regionprops(np.squeeze(full.astype(np.uint32)))

    W, H = in_seg.GetImageSize('out')
    tx, ty = int(W / nominal_tile_size[0]) + 1, int(H / nominal_tile_size[1]) + 1
    n_tiles = tx * ty

    # Labels assigned to each tile, and the [xmin, ymin, xmax, ymax] (max
    # exclusive) extent of their bounding boxes.
    obj_to_tile = [[] for _ in range(n_tiles)]
    tiles = [[np.inf, np.inf, 0, 0] for _ in range(n_tiles)]

    for o in rp:
        if o.label == 0:
            continue
        # regionprops bbox for a 2D image: (min_row, min_col, max_row, max_col).
        min_row, min_col = o.bbox[0], o.bbox[1]
        max_row, max_col = o.bbox[2], o.bbox[3]
        ix = int(min_col / nominal_tile_size[0])
        iy = int(min_row / nominal_tile_size[1])
        idx = ix * ty + iy  # column-major tile numbering
        obj_to_tile[idx].append(o.label)
        box = tiles[idx]
        box[0] = min(min_col, box[0])
        box[1] = min(min_row, box[1])
        box[2] = max(max_col, box[2])
        box[3] = max(max_row, box[3])

    out_files = []
    for i, (labels, box) in enumerate(zip(obj_to_tile, tiles)):
        if not labels:
            # Empty tile: nothing to vectorize (its box still holds the inf sentinels).
            continue
        region = [box[0], box[1], box[2] - box[0], box[3] - box[1]]  # [x, y, w, h]
        out_files.append(output_template.format(i))
        # A fresh pipeline per tile, since vectorize_tile propagates a
        # tile-specific requested region on the object it receives.
        vectorize_tile(to_otb_pipeline(input_segm), region, labels, out_files[-1])

    return out_files

def get_bounding_boxes(input_segm):
    """Compute the bounding box of every value in a label image, row by row.

    The image is read through a one-row requested region moved down the image,
    so only a single row is exported at a time.

    :param input_segm: label image, anything accepted by ``to_otb_pipeline``.
    :return: dict mapping each label (as ``np.uint32``, background included) to
        ``[xmin, ymin, xmax, ymax]`` with inclusive pixel coordinates.
    """
    seg = to_otb_pipeline(input_segm)
    width, height = seg.GetImageSize('out')

    row_region = otb.itkRegion()
    row_region['index'][0] = 0
    row_region['size'][0], row_region['size'][1] = width, 1

    boxes = {}
    for y in tqdm(range(height)):
        row_region['index'][1] = y
        seg.PropagateRequestedRegion('out', row_region)
        values = np.squeeze(seg.ExportImage('out')['array'])
        # First occurrence of each label, and first occurrence in the reversed
        # row (i.e. last occurrence); both unique() calls return the same sorted labels.
        labels, first = np.unique(values, return_index=True)
        _, first_from_end = np.unique(values[::-1], return_index=True)
        last = len(values) - first_from_end - 1
        for label, x_start, x_end in zip(labels.astype(np.uint32), first, last):
            box = boxes.get(label)
            if box is None:
                boxes[label] = [x_start, y, x_end, y]
            else:
                box[0] = min(box[0], x_start)
                box[2] = max(box[2], x_end)
                box[3] = y  # rows are visited top to bottom
    return boxes