# segmentation.py 7.67 KiB
import sys
sys.path.append('../Common')

from Common.otb_numpy_proc import to_otb_pipeline
import os.path
import otbApplication as otb
import multiprocessing as mp
from math import sqrt, floor
from itertools import product, chain
from dataclasses import dataclass
import psutil

# GLOBAL
# Divisor used by lsgrm() to turn the per-process memory budget into an
# approximate number of pixels per tile. Presumably the memory footprint, in
# bytes, of one graph node -- TODO confirm against the GRM implementation.
node_size = 700 # size of a graph node (=pixel) in GRM

@dataclass
class LSGRMParams:
    """Parameters of a (large-scale) GRM segmentation run.

    The fields are forwarded as-is to the OTB applications configured in this
    module; see the comment on each field for the parameter key it feeds.
    """
    # Forwarded to the 'threshold' parameter of the OTB applications.
    threshold: float
    # Forwarded to 'criterion.bs.cw' ('cw' for GenericRegionMerging).
    color_weight: float
    # Forwarded to 'criterion.bs.sw' ('sw' for GenericRegionMerging).
    spatial_weight: float
    # Forwarded to 'tiling.user.nfirstiter' of SingleTileGRMGraph.
    n_first_iter: int = 8
    # Forwarded to 'tiling.user.margin'; lsgrm() also subtracts it twice from
    # each tile dimension to obtain the nominal tile size.
    margin : int = 100

def lsgrm_process_tile(input_image, params: LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None):
    """Run the 'SingleTileGRMGraph' OTB application on one tile.

    Parameters:
        input_image: input handed to ``to_otb_pipeline`` to obtain an OTB
            application exposing an 'out' image.
        params: segmentation parameters (threshold, weights, first iterations,
            margin).
        tile_width, tile_height: user tile size passed to the application
            ('tiling.user.sizex' / 'tiling.user.sizey').
        tile_idx: ``(x, y)`` index of the tile to process.
        out_graph: prefix given to the application's 'out' parameter; also the
            prefix of the returned file names.
        roi: region of interest. Only ``None`` is supported for now.

    Returns:
        A list of four file names (node, nodeMargin, edge, edgeMargin) built
        as ``<out_graph>_<kind>_<y>_<x>.bin``. These are the graph files the
        application is expected to have produced for this tile.

    Raises:
        NotImplementedError: if ``roi`` is not ``None``. The original code left
            the input pipeline undefined in that case and failed later with an
            UnboundLocalError.
    """
    if roi is not None:
        raise NotImplementedError('ROI-restricted segmentation is not supported yet.')

    in_img = to_otb_pipeline(input_image)
    tile_x, tile_y = tile_idx[0], tile_idx[1]

    seg_app = otb.Registry.CreateApplication('SingleTileGRMGraph')
    seg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    seg_app.SetParameterFloat('threshold', params.threshold)
    seg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
    seg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
    seg_app.SetParameterString('tiling', 'user')
    seg_app.SetParameterInt('tiling.user.sizex', tile_width)
    seg_app.SetParameterInt('tiling.user.sizey', tile_height)
    seg_app.SetParameterInt('tiling.user.nfirstiter', params.n_first_iter)
    seg_app.SetParameterInt('tiling.user.margin', params.margin)
    seg_app.SetParameterInt('xtileidx', tile_x)
    seg_app.SetParameterInt('ytileidx', tile_y)
    seg_app.SetParameterString('out', out_graph)
    seg_app.Execute()

    # File names carry the y index first, then the x index.
    suffix = '_{}_{}.bin'.format(tile_y, tile_x)
    return [out_graph + '_' + kind + suffix
            for kind in ('node', 'nodeMargin', 'edge', 'edgeMargin')]


def _write_lsgrm_output(seg_app, out_seg, vectorize):
    """Run *seg_app* and write its 'out' image to *out_seg*, vectorized if requested."""
    if vectorize:
        # Keep the label image in memory and feed it to the vectorization app.
        seg_app.Execute()
        vec_app = otb.Registry.CreateApplication('SimpleVectorization')
        vec_app.SetParameterInputImage('in', seg_app.GetParameterOutputImage('out'))
        vec_app.SetParameterString('out', out_seg)
        vec_app.ExecuteAndWriteOutput()
        # Swap only the final extension; str.replace() would also rewrite any
        # earlier occurrence of the extension inside the path.
        write_qgis_seg_style(os.path.splitext(out_seg)[0] + '.qml')
    else:
        seg_app.SetParameterString('out', out_seg)
        seg_app.ExecuteAndWriteOutput()


def lsgrm(input_image, params: LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False):
    """Segment an image with (large-scale) generic region merging.

    The image is split into tiles sized from the memory budget. Each tile graph
    is built in a separate process with ``lsgrm_process_tile`` and the graphs
    are merged with the 'AssembleGRMGraphs' OTB application. When a single tile
    is enough, the 'GenericRegionMerging' application is used instead.

    Parameters:
        input_image: input handed to ``to_otb_pipeline`` to obtain an OTB
            application exposing an 'out' image.
        params: segmentation parameters.
        out_seg: output file. '.tif' yields a raster; '.shp', '.gpkg' and
            '.gml' yield a vector file plus a '.qml' QGIS style next to it.
        n_proc: number of worker processes. Defaults to half the CPU count,
            and at least 1.
        memory: memory budget, multiplied by 1e6 before use (i.e. expressed in
            megabytes). Defaults to 75% of the available memory reported by
            psutil.
        roi: region of interest. Only ``None`` is supported for now.
        remove_graph: delete the intermediate per-tile graph files afterwards.
        force_parallel: cap the memory budget to ``W * H * node_size`` so the
            tile size is derived from the image rather than from a larger
            budget. Only applied when both image dimensions exceed the margin.

    Returns:
        ``out_seg``, for both the tiled and the single-tile code path.

    Raises:
        ValueError: if the output extension is not supported or ``n_proc`` is
            lower than 1.
        NotImplementedError: if ``roi`` is not ``None``.
    """
    # Check output file type
    graph, ext = os.path.splitext(out_seg)
    if ext in ('.tif',):
        vectorize = False
    elif ext in ('.shp', '.gpkg', '.gml'):
        vectorize = True
    else:
        raise ValueError('Output type {} not recognized/supported.'.format(ext))

    if roi is not None:
        raise NotImplementedError('ROI-restricted segmentation is not supported yet.')

    # Define default number of processes (half) and memory amount (3/4 of available)
    if n_proc is None:
        # round() alone yields 0 on a single-core machine.
        n_proc = max(1, round(mp.cpu_count() / 2))
    elif n_proc < 1:
        raise ValueError('n_proc must be at least 1, got {}.'.format(n_proc))
    if memory is None:
        memory = round(psutil.virtual_memory().available * 0.75)
    else:
        memory *= 1e6

    in_img = to_otb_pipeline(input_image)

    # Get image size
    W, H = in_img.GetImageSize('out')

    # Adapt memory amount to force fitting the number of cores
    if force_parallel and memory > (W * H * node_size) and W > params.margin and H > params.margin:
        memory = W * H * node_size

    # Compute optimal tile size
    pixels_per_tile = (memory / n_proc) / node_size  # approx. pixels per tile with margins
    tile_width_wmarg = floor(sqrt(pixels_per_tile * W / H))
    tile_height_wmarg = floor(sqrt(pixels_per_tile * H / W))
    nominal_tw = tile_width_wmarg - 2 * params.margin
    nominal_th = tile_height_wmarg - 2 * params.margin

    if nominal_tw <= 0 or nominal_th <= 0:
        # The budget leaves no room for a tile once margins are removed:
        # tiling is impossible, so process the image as a single tile.
        n_tiles_x = n_tiles_y = 1
    else:
        n_tiles_x = max(1, floor(W / nominal_tw))
        n_tiles_y = max(1, floor(H / nominal_th))

    if n_tiles_x == 1 and n_tiles_y == 1:
        # Fallback to classical GRM
        grm = otb.Registry.CreateApplication('GenericRegionMerging')
        grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
        grm.SetParameterFloat('threshold', params.threshold)
        grm.SetParameterFloat('cw', params.color_weight)
        grm.SetParameterFloat('sw', params.spatial_weight)
        _write_lsgrm_output(grm, out_seg, vectorize)
        return out_seg

    # One task per (x, y) tile index; the graph prefix is the output path
    # without its extension.
    tile_indices = product(range(n_tiles_x), range(n_tiles_y))
    arg_list = [(input_image, params, nominal_tw, nominal_th, idx, graph, roi)
                for idx in tile_indices]

    with mp.Pool(n_proc) as pool:
        graph_files = pool.starmap(lsgrm_process_tile, arg_list)

    agg_app = otb.Registry.CreateApplication('AssembleGRMGraphs')
    agg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    agg_app.SetParameterString('graph', graph)
    agg_app.SetParameterFloat('threshold', params.threshold)
    agg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
    agg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
    agg_app.SetParameterString('tiling', 'user')
    agg_app.SetParameterInt('tiling.user.sizex', nominal_tw)
    agg_app.SetParameterInt('tiling.user.sizey', nominal_th)
    _write_lsgrm_output(agg_app, out_seg, vectorize)

    if remove_graph:
        for graph_file in chain.from_iterable(graph_files):
            try:
                os.remove(graph_file)
            except FileNotFoundError:
                # Cleanup is best effort: a graph file that was never written
                # must not fail an otherwise successful segmentation.
                pass

    return out_seg

def write_qgis_seg_style(out_file, line_color='255,255,0,255', line_width=0.46):
    """Write a QGIS style file (.qml) drawing polygons as plain outlines.

    Parameters:
        out_file: path of the style file to create (overwritten if present).
        line_color: value written to the 'line_color' property.
        line_width: value written to the 'line_width' property; the file
            declares 'MM' as the line width unit.

    Returns:
        ``out_file``.
    """
    style_lines = [
        "<!DOCTYPE qgis PUBLIC 'http://mrcc.com/qgis.dtd' 'SYSTEM'>",
        '<qgis styleCategories="Symbology" version="3.14.1-Pi">',
        '  <renderer-v2 forceraster="0" type="singleSymbol" symbollevels="0" enableorderby="0">',
        '    <symbols>',
        '      <symbol alpha="1" clip_to_extent="1" name="0" type="fill" force_rhr="0">',
        '        <layer class="SimpleLine" locked="0" enabled="1" pass="0">',
        '          <prop k="capstyle" v="square"/>',
        '          <prop k="customdash" v="5;2"/>',
        '          <prop k="customdash_map_unit_scale" v="3x:0,0,0,0,0,0"/>',
        '          <prop k="customdash_unit" v="MM"/>',
        '          <prop k="draw_inside_polygon" v="0"/>',
        '          <prop k="joinstyle" v="bevel"/>',
        '          <prop k="line_color" v="{}"/>'.format(line_color),
        '          <prop k="line_style" v="solid"/>',
        '          <prop k="line_width" v="{}"/>'.format(line_width),
        '          <prop k="line_width_unit" v="MM"/>',
        '          <prop k="offset" v="0"/>',
        '          <prop k="offset_map_unit_scale" v="3x:0,0,0,0,0,0"/>',
        '          <prop k="offset_unit" v="MM"/>',
        '          <prop k="ring_filter" v="0"/>',
        '          <prop k="use_custom_dash" v="0"/>',
        '          <prop k="width_map_unit_scale" v="3x:0,0,0,0,0,0"/>',
        '          <data_defined_properties>',
        '            <Option type="Map">',
        '              <Option name="name" value="" type="QString"/>',
        '              <Option name="properties"/>',
        '              <Option name="type" value="collection" type="QString"/>',
        '            </Option>',
        '          </data_defined_properties>',
        '        </layer>',
        '      </symbol>',
        '    </symbols>',
        '    <rotation/>',
        '    <sizescale/>',
        '  </renderer-v2>',
        '  <blendMode>0</blendMode>',
        '  <featureBlendMode>0</featureBlendMode>',
        '  <layerGeometryType>2</layerGeometryType>',
        '</qgis>',
    ]
    # Text mode: the content is str, so a binary handle raises TypeError.
    # writelines() adds no separators, hence the explicit join.
    with open(out_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(style_lines) + '\n')
    return out_file