import sys

import rasterio
from osgeo import gdal
from Common.otb_numpy_proc import to_otb_pipeline
import os.path
import otbApplication as otb
import multiprocessing as mp
from math import sqrt, floor, ceil
from itertools import product, chain
from dataclasses import dataclass
import psutil
from skimage.measure import regionprops, label
from tqdm import tqdm
import numpy as np

# GLOBAL
# Approximate memory cost, in bytes, of one GRM graph node (one node per pixel).
# Memory budgets (in bytes) are divided by this value to get a pixel count per tile.
node_size = 700 # size of a graph node (=pixel) in GRM

@dataclass
class LSGRMParams:
    """Segmentation parameters shared by the GRM / LSGRM helpers of this module."""
    threshold: float  # merging threshold, passed as 'threshold' to the OTB apps
    color_weight: float  # passed as 'cw' (GRM) / 'criterion.bs.cw' (LSGRM)
    spatial_weight: float  # passed as 'sw' (GRM) / 'criterion.bs.sw' (LSGRM)
    n_first_iter: int = 8  # passed as 'niter' (GRM) / 'tiling.user.nfirstiter' (LSGRM)
    margin: int = 100  # overlap, in pixels, added on each inner side of a tile

def grm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_img, roi):
    """Segment one tile of the image with GRM and write its label raster.

    The tile is the nominal ``tile_width`` x ``tile_height`` window at
    ``tile_idx``, extended by ``params.margin`` pixels on every side that does
    not touch the image border. After segmentation, objects whose bounding-box
    top-left corner lies outside the tile core are set to 0, unless their
    bottom-right corner lies in ``[2 * margin, tile size)`` on both axes. The
    surviving objects are then relabelled by connected components.

    Args:
        input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
            None, otherwise used as a file path by ``ExtractROI``.
        params: segmentation parameters (threshold, weights, iterations, margin).
        tile_width: nominal tile width in pixels, margins excluded.
        tile_height: nominal tile height in pixels, margins excluded.
        tile_idx: ``(x, y)`` index of the tile in the tiling layout.
        out_img: output path; the tile is written to ``<root>_<y>_<x>.tif``.
        roi: optional path to a vector file used to crop the input, or None.

    Returns:
        ``(tile_filename, n_labels)``.

    Raises:
        FileNotFoundError: if ``roi`` is given but the file does not exist.
            This function runs in a ``multiprocessing`` worker, where
            ``sys.exit`` would kill the worker without reporting back and
            leave the parent's ``starmap`` waiting forever.
    """
    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        raise FileNotFoundError('ROI provided but cannot find file: {}'.format(roi))

    W, H = in_img.GetImageSize('out')

    # Tile window in image coordinates: the nominal tile plus a margin on every
    # side that is not on the image border.
    startx = 0 if tile_idx[0] == 0 else tile_idx[0] * tile_width - params.margin
    starty = 0 if tile_idx[1] == 0 else tile_idx[1] * tile_height - params.margin
    endx = W if tile_width * (tile_idx[0] + 1) >= W else tile_width * (tile_idx[0] + 1) + params.margin
    endy = H if tile_height * (tile_idx[1] + 1) >= H else tile_height * (tile_idx[1] + 1) + params.margin

    tile = otb.Registry.CreateApplication('ExtractROI')
    tile.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    tile.SetParameterInt('startx', startx)
    tile.SetParameterInt('starty', starty)
    tile.SetParameterInt('sizex', endx - startx)
    tile.SetParameterInt('sizey', endy - starty)
    tile.Execute()

    out_fn = '{}_{}_{}.tif'.format(os.path.splitext(out_img)[0], tile_idx[1], tile_idx[0])

    seg_app = otb.Registry.CreateApplication('GenericRegionMerging')
    seg_app.SetParameterInputImage('in', tile.GetParameterOutputImage('out'))
    seg_app.SetParameterFloat('threshold', params.threshold)
    seg_app.SetParameterFloat('cw', params.color_weight)
    seg_app.SetParameterFloat('sw', params.spatial_weight)
    seg_app.SetParameterInt('niter', params.n_first_iter)
    seg_app.Execute()

    # Tile core in local (tile) coordinates: (first row, first col, row end, col end).
    # On the last row/column of tiles the end is H / W, i.e. effectively unbounded.
    tie_lines = (0 if starty == 0 else params.margin,
                 0 if startx == 0 else params.margin,
                 0, 0)

    tie_lines = (tie_lines[0], tie_lines[1],
                 H if endy == H else tile_height + tie_lines[0],
                 W if endx == W else tile_width + tie_lines[1])

    seg = seg_app.ExportImage('out')
    op = regionprops(seg['array'][:,:,0].astype(np.int32))
    to_del = []
    for o in op:
        # o.bbox is (min_row, min_col, max_row, max_col).
        if not (tie_lines[0] <= o.bbox[0] < tie_lines[2]
                and tie_lines[1] <= o.bbox[1] < tie_lines[3]):
            # Keep an object starting outside the core if it extends past the
            # overlap area (bottom-right corner in [2*margin, tile size)).
            if not (2*params.margin <= o.bbox[2] < tile_height and
                    2*params.margin <= o.bbox[3] < tile_width):
                to_del.append(o.label)
    seg['array'][np.isin(seg['array'], to_del)] = 0
    # Relabel the surviving objects with consecutive labels; nlab is their count.
    seg['array'], nlab = label(seg['array'], background=0, connectivity=1, return_num=True)
    seg['array'] = np.ascontiguousarray(seg['array']).astype(float)

    writer = otb.Registry.CreateApplication('ExtractROI')
    writer.ImportImage('in', seg)
    writer.SetParameterString('out', out_fn)
    writer.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
    writer.ExecuteAndWriteOutput()

    return out_fn, nlab


def lsgrm_process_tile(input_image, params : LSGRMParams, tile_width, tile_height, tile_idx, out_graph, roi=None):
    """Build the LSGRM graph of one tile with the ``SingleTileGRMGraph`` app.

    Args:
        input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
            None, otherwise used as a file path by ``ExtractROI``.
        params: segmentation parameters.
        tile_width: nominal tile width in pixels, margins excluded.
        tile_height: nominal tile height in pixels, margins excluded.
        tile_idx: ``(x, y)`` index of the tile in the tiling layout.
        out_graph: prefix of the graph files written by the app.
        roi: optional path to a vector file used to crop the input, or None.

    Returns:
        The four graph files of the tile, in the order node, nodeMargin,
        edge, edgeMargin, named ``<out_graph>_<kind>_<y>_<x>.bin``.

    Raises:
        FileNotFoundError: if ``roi`` is given but the file does not exist.
            Raised rather than ``sys.exit`` because this runs in a
            ``multiprocessing`` worker.
    """
    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        # Without this branch in_img would be unbound below.
        raise FileNotFoundError('ROI provided but cannot find file: {}'.format(roi))

    seg_app = otb.Registry.CreateApplication('SingleTileGRMGraph')
    seg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
    seg_app.SetParameterFloat('threshold', params.threshold)
    seg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
    seg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
    seg_app.SetParameterString('tiling', 'user')
    seg_app.SetParameterInt('tiling.user.sizex', tile_width)
    seg_app.SetParameterInt('tiling.user.sizey', tile_height)
    seg_app.SetParameterInt('tiling.user.nfirstiter', params.n_first_iter)
    seg_app.SetParameterInt('tiling.user.margin', params.margin)
    seg_app.SetParameterInt('xtileidx', tile_idx[0])
    seg_app.SetParameterInt('ytileidx', tile_idx[1])
    seg_app.SetParameterString('out', out_graph)
    seg_app.ExecuteAndWriteOutput()

    suffix = '{}_{}.bin'.format(tile_idx[1], tile_idx[0])
    return [out_graph + '_{}_{}'.format(kind, suffix)
            for kind in ('node', 'nodeMargin', 'edge', 'edgeMargin')]

def get_ls_seg_parameter(input_image, roi=None, margin=0, n_proc=None, memory=None, force_parallel=False):
    """Open the input image and compute a tiling layout for large-scale GRM.

    The memory budget is split evenly across ``n_proc`` workers. Each share is
    converted into a pixel count using ``node_size`` bytes per pixel. Tiles
    (margins included) then get the aspect ratio of the image.

    Args:
        input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
            None, otherwise used as a file path by ``ExtractROI``.
        roi: optional path to a vector file used to crop the input, or None.
        margin: tile overlap in pixels, counted on both sides of a tile.
        n_proc: number of workers; defaults to half the CPUs (at least 1).
        memory: memory budget in MB; defaults to 3/4 of the available memory.
        force_parallel: if the budget exceeds the image's estimated footprint,
            shrink it to that footprint so the image is split across workers.

    Returns:
        ``(in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th)``. ``in_img``
        is the OTB application providing the image on its ``'out'`` output.
        The nominal tile size excludes margins. If an axis's tile budget
        cannot even hold both margins, that axis is not split: it gets a single
        tile spanning the whole image extent.

    Exits the process (``sys.exit(-1)``) if ``roi`` is given but missing.
    """
    # Default number of threads (half, but at least one since round(0.5) == 0)
    # and memory amount (3/4 of available, in bytes).
    if n_proc is None:
        n_proc = max(1, round(mp.cpu_count() / 2))
    if memory is None:
        memory = round(psutil.virtual_memory().available * 0.75)
    else:
        memory *= 1e6  # MB -> bytes

    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        print('ROI provided but cannot find file.')
        sys.exit(-1)

    # Get image size
    W, H = in_img.GetImageSize('out')

    # adapt memory amount to force fitting the number of cores
    if force_parallel and memory > (W * H * node_size) and W > margin and H > margin:
        memory = W * H * node_size

    # Compute optimal tile size
    T = (memory / n_proc) / node_size  # approx. pixels per tile with margins
    tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W))
    nominal_tw, nominal_th = tile_width_wmarg - 2 * margin, tile_height_wmarg - 2 * margin

    # A non-positive nominal size would divide by zero below (== 0) or hand a
    # negative tile size to the workers (< 0): do not split that axis instead.
    if nominal_tw <= 0:
        nominal_tw = W
    if nominal_th <= 0:
        nominal_th = H
    n_tiles_x, n_tiles_y = max(1, ceil(W / nominal_tw)), max(1, ceil(H / nominal_th))

    return in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th

def lsgrm_light(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None,
                roi=None, force_parallel=False, remove_tiles_if_useless=True):
    """Segment a large image with tile-wise GRM and assemble the tile results.

    The tiling layout comes from ``get_ls_seg_parameter``. Each tile is
    segmented by ``grm_process_tile`` in a process pool. Tile labels are then
    offset so they are unique across tiles, and the tiles are assembled into
    ``out_seg``. When a single tile is enough, plain GRM runs on the whole
    image instead.

    Args:
        input_image: input image path.
        params: segmentation parameters.
        out_seg: output path; its extension selects the output type:
            ``.vrt`` (VRT over per-tile GeoTIFFs), ``.tif`` (mosaicked raster)
            or ``.shp`` / ``.gml`` / ``.gpkg`` (vectorized mosaic).
        n_proc: number of worker processes; defaults to half the CPUs (at least 1).
        memory: memory budget in MB; defaults to 3/4 of the available memory.
        roi: optional path to a vector file used to crop the input.
        force_parallel: forwarded to ``get_ls_seg_parameter``.
        remove_tiles_if_useless: delete the per-tile rasters after mosaicking
            (``.tif`` / vector outputs only; a ``.vrt`` references them).

    Returns:
        ``out_seg``.

    Raises:
        ValueError: if the output extension is not supported.
    """
    # Check output file type
    ext = os.path.splitext(out_seg)[-1]
    if ext == '.vrt':
        mode = 'vrt'
    elif ext in ['.tif']:
        mode = 'raster'
    elif ext in ['.shp', '.gml', '.gpkg']:
        mode = 'vector'
    else:
        raise ValueError('Output type {} not supported.'.format(ext))

    # Resolve the worker count here so the tile size and the pool size agree:
    # mp.Pool(None) would use every CPU while tiles were sized for half of them.
    if n_proc is None:
        n_proc = max(1, round(mp.cpu_count() / 2))

    in_img, n_tiles_x, n_tiles_y, nominal_tw, nominal_th = get_ls_seg_parameter(input_image, roi, params.margin, n_proc,
                                                                                memory, force_parallel)

    print('[INFO] Using a layout of {} x {} tiles.'.format(n_tiles_x,n_tiles_y))
    print('[INFO] Nominal tile size w/margin: {} x {} pixels'.format(nominal_tw+2*params.margin,
                                                                     nominal_th+2*params.margin))

    if (n_tiles_x == 1 and n_tiles_y == 1):
        # Fallback to classical GRM
        print('[INFO] Fallback to one-tile GRM.')
        grm = otb.Registry.CreateApplication('GenericRegionMerging')
        grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
        grm.SetParameterFloat('threshold', params.threshold)
        grm.SetParameterFloat('cw', params.color_weight)
        grm.SetParameterFloat('sw', params.spatial_weight)
        if mode == 'vector':
            # GRM outputs an image: vectorize it rather than writing raster data
            # into a vector file.
            grm.Execute()
            vec = otb.Registry.CreateApplication('SimpleVectorization')
            vec.SetParameterInputImage('in', grm.GetParameterOutputImage('out'))
            vec.SetParameterString('out', out_seg)
            vec.ExecuteAndWriteOutput()
        elif mode == 'vrt':
            # Write a GeoTIFF named like tile (0, 0) and expose it through the VRT.
            single_tile = '{}_0_0.tif'.format(os.path.splitext(out_seg)[0])
            grm.SetParameterString('out', single_tile)
            grm.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
            grm.ExecuteAndWriteOutput()
            gdal.BuildVRT(out_seg, [single_tile])
        else:
            grm.SetParameterString('out', out_seg)
            grm.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
            grm.ExecuteAndWriteOutput()
    else:
        tile_index_list = product(range(n_tiles_x), range(n_tiles_y))
        arg_list = [(input_image, params, nominal_tw, nominal_th, x, out_seg, roi) for x in tile_index_list]
        with mp.Pool(n_proc) as p:
            out_tiles = p.starmap(grm_process_tile, arg_list)

        # Shift each tile's labels by the number of labels in the previous
        # tiles so labels are unique once tiles are combined; 0 stays background.
        cumul_label = 0
        for f, l in out_tiles:
            if cumul_label > 0:
                with rasterio.open(f, 'r+') as tile:
                    arr = tile.read(1)
                    arr[arr > 0] += cumul_label
                    tile.write(arr, indexes=1)
                    tile.nodata = 0
            cumul_label += l

        tile_files = [x[0] for x in out_tiles]
        print('[INFO] Output mode {}.'.format(mode))
        if mode == 'vrt':
            vrtopt = gdal.BuildVRTOptions(separate=False, srcNodata=0, VRTNodata=0)
            gdal.BuildVRT(out_seg, tile_files, options=vrtopt)
        elif mode in ['raster', 'vector']:
            mos = otb.Registry.CreateApplication('Mosaic')
            mos.SetParameterStringList('il', tile_files)
            mos.SetParameterInt('nodata', 0)
            if mode == 'raster':
                mos.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
                mos.SetParameterString('out', out_seg)
                mos.ExecuteAndWriteOutput()
            elif mode == 'vector':
                mos.Execute()
                vec = otb.Registry.CreateApplication('SimpleVectorization')
                vec.SetParameterInputImage('in', mos.GetParameterOutputImage('out'))
                vec.SetParameterString('out', out_seg)
                vec.ExecuteAndWriteOutput()
            if remove_tiles_if_useless:
                for tile_file in tile_files:
                    os.remove(tile_file)

    return out_seg




def lsgrm(input_image, params : LSGRMParams, out_seg, n_proc=None, memory=None, roi=None, remove_graph=True, force_parallel=False):
    """Segment a large image with the LSGRM OTB applications.

    Per-tile graphs are built in parallel by ``lsgrm_process_tile`` and merged
    by ``AssembleGRMGraphs``. When a single tile is enough, plain GRM runs on
    the whole image instead.

    Args:
        input_image: input image; given to ``to_otb_pipeline`` when ``roi`` is
            None, otherwise used as a file path by ``ExtractROI``.
        params: segmentation parameters.
        out_seg: output path; ``.tif`` writes a uint32 label raster,
            ``.shp`` / ``.gpkg`` / ``.gml`` write vectors plus a ``.qml`` style.
        n_proc: number of worker processes; defaults to half the CPUs (at least 1).
        memory: memory budget in MB; defaults to 3/4 of the available memory.
        roi: optional path to a vector file used to crop the input.
        remove_graph: delete the per-tile graph files after assembly.
        force_parallel: if the budget exceeds the image's estimated footprint,
            shrink it to that footprint so the image is split across workers.

    Returns:
        ``out_seg`` (on both the tiled and the single-tile path).

    Raises:
        ValueError: if the output extension is not supported.
    """
    # Check output file type
    ext = os.path.splitext(out_seg)[-1]
    if ext in ['.tif']:
        vectorize = False
    elif ext in ['.shp', '.gpkg', '.gml']:
        vectorize = True
    else:
        raise ValueError('Output type {} not recognized/supported.'.format(ext))

    # Define default number of threads (half, at least one since round(0.5) == 0)
    # and memory amount (3/4 of available, in bytes)
    if n_proc is None:
        n_proc = max(1, round(mp.cpu_count() / 2))
    if memory is None:
        memory = round(psutil.virtual_memory().available * 0.75)
    else:
        memory *= 1e6  # MB -> bytes

    if roi is None:
        in_img = to_otb_pipeline(input_image)
    elif os.path.exists(roi):
        in_img = otb.Registry.CreateApplication('ExtractROI')
        in_img.SetParameterString('in', input_image)
        in_img.SetParameterString('mode', 'fit')
        in_img.SetParameterString('mode.fit.vect', roi)
        in_img.Execute()
    else:
        print('ROI provided but cannot find file.')
        sys.exit(-1)

    # Get image size
    W, H = in_img.GetImageSize('out')

    # adapt memory amount to force fitting the number of cores
    if force_parallel and memory > (W * H * node_size) and W > params.margin and H > params.margin:
        memory = W * H * node_size

    # Compute optimal tile size
    T = (memory/n_proc)/node_size # approx. pixels per tile with margins
    tile_width_wmarg, tile_height_wmarg = floor(sqrt(T * W / H)), floor(sqrt(T * H / W))
    nominal_tw, nominal_th = tile_width_wmarg-2*params.margin, tile_height_wmarg-2*params.margin
    # A non-positive nominal size would divide by zero (== 0) or pass a negative
    # tile size to LSGRM (< 0): do not split that axis instead.
    if nominal_tw <= 0:
        nominal_tw = W
    if nominal_th <= 0:
        nominal_th = H
    n_tiles_x, n_tiles_y = max(1,floor(W/nominal_tw)), max(1,floor(H/nominal_th))

    def _write_output(app):
        """Write the label image produced by ``app`` to ``out_seg`` (raster or vectors)."""
        if vectorize:
            app.Execute()
            vec_app = otb.Registry.CreateApplication('SimpleVectorization')
            vec_app.SetParameterInputImage('in', app.GetParameterOutputImage('out'))
            vec_app.SetParameterString('out', out_seg)
            vec_app.ExecuteAndWriteOutput()
            # splitext rather than str.replace, which would also rewrite ext in directory names.
            write_qgis_seg_style(os.path.splitext(out_seg)[0] + '.qml')
        else:
            app.SetParameterString('out', out_seg)
            app.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint32)
            app.ExecuteAndWriteOutput()

    if (n_tiles_x == 1 and n_tiles_y == 1):
        # Fallback to classical GRM
        grm = otb.Registry.CreateApplication('GenericRegionMerging')
        grm.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
        grm.SetParameterFloat('threshold', params.threshold)
        grm.SetParameterFloat('cw', params.color_weight)
        grm.SetParameterFloat('sw', params.spatial_weight)
        _write_output(grm)

    else:
        tile_index_list = product(range(n_tiles_x), range(n_tiles_y))
        graph = os.path.splitext(out_seg)[0]
        arg_list = [(input_image, params, nominal_tw, nominal_th, x, graph, roi) for x in tile_index_list]

        with mp.Pool(n_proc) as p:
            graph_files = p.starmap(lsgrm_process_tile, arg_list)

        agg_app = otb.Registry.CreateApplication('AssembleGRMGraphs')
        agg_app.SetParameterInputImage('in', in_img.GetParameterOutputImage('out'))
        agg_app.SetParameterString('graph', graph)
        agg_app.SetParameterFloat('threshold', params.threshold)
        agg_app.SetParameterFloat('criterion.bs.cw', params.color_weight)
        agg_app.SetParameterFloat('criterion.bs.sw', params.spatial_weight)
        agg_app.SetParameterString('tiling', 'user')
        agg_app.SetParameterInt('tiling.user.sizex', nominal_tw)
        agg_app.SetParameterInt('tiling.user.sizey', nominal_th)
        _write_output(agg_app)

        if remove_graph:
            for f in chain(*graph_files):
                os.remove(f)

    return out_seg

def write_qgis_seg_style(out_file, line_color='255,255,0,255', line_width=0.46):
    """Write a QGIS style file (.qml) that draws polygon outlines only.

    The style is a single-symbol renderer whose fill symbol holds one
    ``SimpleLine`` layer.

    Args:
        out_file: path of the .qml file to write (overwritten if present).
        line_color: outline color as a QGIS color string, e.g. ``'255,255,0,255'``.
        line_width: outline width, in millimetres.

    Returns:
        ``out_file``.
    """
    lines = [
        "<!DOCTYPE qgis PUBLIC 'http://mrcc.com/qgis.dtd' 'SYSTEM'>",
        "<qgis styleCategories=\"Symbology\" version=\"3.14.1-Pi\">",
        "  <renderer-v2 forceraster=\"0\" type=\"singleSymbol\" symbollevels=\"0\" enableorderby=\"0\">",
        "    <symbols>",
        "      <symbol alpha=\"1\" clip_to_extent=\"1\" name=\"0\" type=\"fill\" force_rhr=\"0\">",
        "        <layer class=\"SimpleLine\" locked=\"0\" enabled=\"1\" pass=\"0\">",
        "          <prop k=\"capstyle\" v=\"square\"/>",
        "          <prop k=\"customdash\" v=\"5;2\"/>",
        "          <prop k=\"customdash_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <prop k=\"customdash_unit\" v=\"MM\"/>",
        "          <prop k=\"draw_inside_polygon\" v=\"0\"/>",
        "          <prop k=\"joinstyle\" v=\"bevel\"/>",
        "          <prop k=\"line_color\" v=\"{}\"/>".format(line_color),
        "          <prop k=\"line_style\" v=\"solid\"/>",
        "          <prop k=\"line_width\" v=\"{}\"/>".format(line_width),
        "          <prop k=\"line_width_unit\" v=\"MM\"/>",
        "          <prop k=\"offset\" v=\"0\"/>",
        "          <prop k=\"offset_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <prop k=\"offset_unit\" v=\"MM\"/>",
        "          <prop k=\"ring_filter\" v=\"0\"/>",
        "          <prop k=\"use_custom_dash\" v=\"0\"/>",
        "          <prop k=\"width_map_unit_scale\" v=\"3x:0,0,0,0,0,0\"/>",
        "          <data_defined_properties>",
        "            <Option type=\"Map\">",
        "              <Option name=\"name\" value=\"\" type=\"QString\"/>",
        "              <Option name=\"properties\"/>",
        "              <Option name=\"type\" value=\"collection\" type=\"QString\"/>",
        "            </Option>",
        "          </data_defined_properties>",
        "        </layer>",
        "      </symbol>",
        "    </symbols>",
        "    <rotation/>",
        "    <sizescale/>",
        "  </renderer-v2>",
        "  <blendMode>0</blendMode>",
        "  <featureBlendMode>0</featureBlendMode>",
        "  <layerGeometryType>2</layerGeometryType>",
        "</qgis>",
    ]
    with open(out_file, 'w', encoding='utf-8') as f:
        # writelines() adds no separators; join explicitly so each entry is its own line.
        f.write('\n'.join(lines) + '\n')
    return out_file

def vectorize_tile(obj, region, to_keep, out):
    """Vectorize the selected labels found in one window of a label image.

    Args:
        obj: OTB application whose ``'out'`` output is the label image.
        region: ``[x, y, width, height]`` of the window to export, in pixels.
        to_keep: labels to vectorize; every other pixel of the window is zeroed.
        out: path of the vector file to write.

    Returns:
        ``out``.
    """
    window_region = otb.itkRegion()
    window_region['index'][0] = region[0]
    window_region['index'][1] = region[1]
    window_region['size'][0] = region[2]
    window_region['size'][1] = region[3]
    obj.PropagateRequestedRegion('out', window_region)
    window = obj.ExportImage('out')

    # Multiplying by the boolean membership mask zeroes the unwanted labels in place.
    keep_mask = np.isin(window['array'], to_keep)
    window['array'] *= keep_mask

    vectorizer = otb.Registry.CreateApplication('SimpleVectorization')
    vectorizer.ImportImage('in', window)
    vectorizer.SetParameterString('out', out)
    vectorizer.ExecuteAndWriteOutput()
    return out

def tiled_vectorization(input_segm, nominal_tile_size, output_template):
    """Vectorize a label image tile by tile.

    Each object is assigned to the tile that contains the top-left corner of
    its bounding box. A tile's exported window is the union of its objects'
    bounding boxes, and only those objects are kept in it. As a result, every
    object is vectorized whole, in exactly one file.

    Args:
        input_segm: label image, given to ``to_otb_pipeline``.
        nominal_tile_size: ``(width, height)`` of the tiling grid, in pixels.
        output_template: format string taking the tile index, e.g. ``'seg_{}.shp'``.

    Returns:
        Paths of the written files, one per tile holding at least one object.
    """
    in_seg = to_otb_pipeline(input_segm)
    full = in_seg.GetImageAsNumpyArray('out')
    rp = regionprops(np.squeeze(full.astype(np.uint32)))

    W, H = in_seg.GetImageSize('out')
    tx, ty = int(W / nominal_tile_size[0]) + 1, int(H / nominal_tile_size[1]) + 1

    obj_to_tile = {i: [] for i in range(tx * ty)}
    # Per-tile window as [min_x, min_y, max_x, max_y] (max exclusive), grown below.
    tiles = {i: [np.inf, np.inf, 0, 0] for i in range(tx * ty)}

    for o in rp:
        if o.label != 0:
            # o.bbox is (min_row, min_col, max_row, max_col).
            ix, iy = int(o.bbox[1] / nominal_tile_size[0]), int(o.bbox[0] / nominal_tile_size[1])
            idx = ix * ty + iy
            obj_to_tile[idx].append(o.label)
            tiles[idx][0] = min(o.bbox[1], tiles[idx][0])
            tiles[idx][1] = min(o.bbox[0], tiles[idx][1])
            tiles[idx][2] = max(o.bbox[3], tiles[idx][2])
            tiles[idx][3] = max(o.bbox[2], tiles[idx][3])

    out_files = []
    for i, labels in obj_to_tile.items():
        if not labels:
            # Empty tile: nothing to export, so no pipeline is built for it.
            continue
        window = tiles[i]
        window[2] -= window[0]  # max_x -> width
        window[3] -= window[1]  # max_y -> height
        # A fresh pipeline per exported tile.
        # NOTE(review): presumably to avoid reusing a pipeline whose requested region was changed — confirm.
        tile_seg = to_otb_pipeline(input_segm)
        out_files.append(output_template.format(i))
        vectorize_tile(tile_seg, window, labels, out_files[-1])

    return out_files

def get_bounding_boxes(input_segm):
    """Compute the bounding box of every label by streaming the image row by row.

    Args:
        input_segm: label image, given to ``to_otb_pipeline``.

    Returns:
        dict mapping each label value present in the image (as ``np.uint32``,
        background 0 included) to ``[min_x, min_y, max_x, max_y]``, with
        inclusive pixel coordinates.
    """
    in_seg = to_otb_pipeline(input_segm)

    W, H = in_seg.GetImageSize('out')

    # One full-width, one-pixel-high request region, moved down the image.
    r = otb.itkRegion()
    r['index'][0] = 0
    r['size'][0], r['size'][1] = W, 1

    bboxes = {}
    for y in tqdm(range(H)):
        r['index'][1] = y
        in_seg.PropagateRequestedRegion('out', r)
        # ravel (not squeeze) so a one-pixel-wide image still yields a 1-D row;
        # a single-band label image is assumed.
        row = np.ravel(in_seg.ExportImage('out')['array'])
        lbl, first_x = np.unique(row, return_index=True)
        # unique() on the reversed row gives the same sorted labels, so the
        # indices line up with lbl; convert them back to forward positions.
        _, rev_pos = np.unique(row[::-1], return_index=True)
        last_x = len(row) - rev_pos - 1
        lbl = lbl.astype(np.uint32)
        for value, x_min, x_max in zip(lbl, first_x, last_x):
            box = bboxes.get(value)
            if box is None:
                bboxes[value] = [x_min, y, x_max, y]
            else:
                box[0] = min(box[0], x_min)
                box[2] = max(box[2], x_max)
                box[3] = y  # rows are visited in increasing order

    return bboxes