import numpy as np
import sys
from time import sleep
import os, psutil
from functools import partial
import string, random
from osgeo.gdal import BuildVRT
import otbApplication as otb

def randomword(length):
    """Return a random string of ``length`` lowercase ASCII letters.

    Used to build unique prefixes for intermediate files.
    """
    letters = []
    for _ in range(length):
        letters.append(random.choice(string.ascii_lowercase))
    return ''.join(letters)

def generate_stream_rois(img_size, n_lines):
    """Split an image into full-width horizontal strips of ``n_lines`` rows.

    Args:
        img_size: image size ``(width, height)`` in pixels.
        n_lines: number of rows per strip.

    Returns:
        A list of dicts with ExtractROI-style keys ``sizex``, ``sizey``,
        ``startx`` and ``starty``. The last strip is shorter when the height
        is not a multiple of ``n_lines``.
    """
    width, height = img_size[0], img_size[1]
    n_full = int(height / n_lines)
    leftover = height % n_lines

    def _strip(top, rows):
        # Every strip spans the whole width, starting at column 0.
        return {'sizex': width, 'sizey': rows, 'startx': 0, 'starty': top}

    strips = [_strip(k * n_lines, n_lines) for k in range(n_full)]
    if leftover:
        strips.append(_strip(n_full * n_lines, leftover))
    return strips

def pad_format_tile(tile, img_size, pad):
    """Grow a tile by ``pad`` pixels on every side, clipped to the image.

    Args:
        tile: inclusive bounding box ``[x0, y0, x1, y1]`` in pixels.
        img_size: image size ``(width, height)`` in pixels.
        pad: margin to add on every side, in pixels.

    Returns:
        ``([startx, starty, sizex, sizey], (pad_left, pad_right),
        (pad_top, pad_bottom))``. The first item is the part of the padded
        box that lies inside the image, as an ExtractROI start/size. The two
        pairs give, per side, how many pixels of the padded box fall OUTSIDE
        the image and must therefore be filled artificially (e.g. with
        ``np.pad``) to recover the full padded extent. The first pair is
        along x (columns), the second along y (rows).
    """
    x0, y0, x1, y1 = (int(v) for v in tile)
    width, height = int(img_size[0]), int(img_size[1])

    # Desired padded extent (inclusive); may extend beyond the image.
    px0, py0, px1, py1 = x0 - pad, y0 - pad, x1 + pad, y1 + pad

    # Only the part that actually lies outside the image has to be filled.
    # This can be less than ``pad`` when the tile is close to, but not flush
    # with, the border (e.g. the tile before a narrow last tile). Filling a
    # full ``pad`` there would make the processed tile larger than the tile.
    pad_left = max(0, -px0)
    pad_top = max(0, -py0)
    pad_right = max(0, px1 - (width - 1))
    pad_bottom = max(0, py1 - (height - 1))

    # Clip the padded box to the image.
    sx0, sy0 = max(px0, 0), max(py0, 0)
    sx1, sy1 = min(px1, width - 1), min(py1, height - 1)

    roi = [int(sx0), int(sy0), int(sx1 - sx0 + 1), int(sy1 - sy0 + 1)]
    return (roi,
            (int(pad_left), int(pad_right)),
            (int(pad_top), int(pad_bottom)))


def generate_tile_rois(img_size, tile_size, pad=0):
    """Cut an image into a grid of tiles, each grown by ``pad`` pixels.

    Args:
        img_size: image size ``(width, height)`` in pixels.
        tile_size: tile size ``(width, height)`` in pixels; tiles on the
            right/bottom edges are smaller when the image size is not a
            multiple of the tile size.
        pad: margin added around each tile (see ``pad_format_tile``).

    Returns:
        A list of ``pad_format_tile`` results, ordered column by column
        (outer loop over x, inner loop over y).
    """
    width, height = img_size[0], img_size[1]
    tile_w, tile_h = tile_size[0], tile_size[1]

    # Ceiling division: a partial tile is added when the size is not a multiple.
    n_cols = int(-(-width // tile_w))
    n_rows = int(-(-height // tile_h))

    return [
        pad_format_tile(
            [c * tile_w,
             r * tile_h,
             min(c * tile_w + tile_w - 1, width - 1),
             min(r * tile_h + tile_h - 1, height - 1)],
            img_size,
            pad,
        )
        for c in range(n_cols)
        for r in range(n_rows)
    ]

# NOTE: disabled legacy code. It is kept as a module-level string literal, so
# it is never executed. It built the same strips as generate_stream_rois, but
# as otb.itkRegion objects instead of plain dicts.
'''
def generate_stream_regions(img_size, n_lines):
    streams = []
    n_streams = int(img_size[1]/n_lines)
    for i in range(n_streams):
        stream = otb.itkRegion()
        stream['size'][0] = img_size[0]
        stream['size'][1] = n_lines
        stream['index'][0] = 0
        stream['index'][1] = i * n_lines
        streams.append(stream)

    if (img_size[1] % n_lines) :
        stream = otb.itkRegion()
        stream['size'][0] = img_size[0]
        stream['size'][1] = img_size[1] % n_lines
        stream['index'][0] = 0
        stream['index'][1] = n_streams * n_lines
        streams.append(stream)

    return streams
'''

# NOTE: disabled legacy code. It is kept as a module-level string literal, so
# it is never executed. It split a pipeline's output into 1000-row strips
# (via generate_stream_rois), one executed ExtractROI application per strip.
'''
def generate_streams(otb_pipeline):
    otb_pipeline.Execute()
    img_size = tuple(otb_pipeline.GetImageSize("out"))
    stream_rois = generate_stream_rois(img_size,1000)
    ext_roi_apps = []
    for roi in stream_rois :
        print("setting roi : " + str(roi))
        app = otb.Registry.CreateApplication('ExtractROI')
        app.SetParameterInputImage("in", otb_pipeline.GetParameterOutputImage("out"))
        app.SetParameterInt("sizex", roi['sizex'])
        app.SetParameterInt("sizey", roi['sizey'])
        app.SetParameterInt("startx", roi['startx'])
        app.SetParameterInt("starty", roi['starty'])
        ext_roi_apps.append(app)
        ext_roi_apps[-1].Execute()

    return ext_roi_apps
'''

def stream_function(otb_in, function, tile_size, write_output_to=None, pad=None, out_key='out', work_dir='.', prefix=None):
    """Apply a NumPy function tile-by-tile to the output of an OTB pipeline.

    The ``out_key`` output of the pipeline is cut into tiles of ``tile_size``
    pixels. Each tile is extracted with a margin of ``pad`` pixels; parts of
    that margin that fall outside the image are zero-filled with ``np.pad``.
    The padded array is passed to ``function`` and the result is written to a
    GeoTIFF in ``work_dir``. All tiles are finally assembled into a VRT with
    GDAL ``BuildVRT``.

    Args:
        otb_in: input accepted by ``to_otb_pipeline`` (a raster path or an
            OTB application).
        function: callable receiving the padded tile array and returning the
            processed array (expected to strip the ``pad`` margin).
        tile_size: tile size ``(width, height)`` in pixels.
        write_output_to: optional output path. When given, the mosaic is
            written there and all intermediate files are removed.
        pad: margin in pixels. When None, it is estimated with
            ``check_function_padding``.
        out_key: output parameter of the pipeline to read from.
        work_dir: directory for the intermediate tile files and the VRT.
        prefix: file-name prefix for intermediate files (random when None).

    Returns:
        ``(write_output_to, [])`` when ``write_output_to`` is given.
        Otherwise ``(app, files)``, where ``app`` is an ExtractROI application
        reading the VRT and ``files`` lists the intermediate files (tiles +
        VRT) the caller must delete.
    """
    otb_pipeline = to_otb_pipeline(otb_in)

    if prefix is None:
        prefix = randomword(16)

    if pad is None:
        pad = check_function_padding(function)

    otb_pipeline.Execute()
    img_size = tuple(otb_pipeline.GetImageSize(out_key))

    streams = generate_tile_rois(img_size, tile_size, pad)

    output_streams = []
    for t, (roi_box, pads_x, pads_y) in enumerate(streams, start=1):
        startx, starty, sizex, sizey = roi_box
        roi = otb.Registry.CreateApplication('ExtractROI')
        roi.SetParameterInputImage('in', otb_pipeline.GetParameterOutputImage(out_key))
        roi.SetParameterInt('startx', startx)
        roi.SetParameterInt('starty', starty)
        roi.SetParameterInt('sizex', sizex)
        roi.SetParameterInt('sizey', sizey)
        roi.Execute()
        img = roi.ExportImage('out')

        # Launch numpy function!
        # The exported array is indexed (rows, cols[, bands]), i.e. y first,
        # while the tile margins come as (x pads, y pads). The y pads must
        # therefore go on axis 0 and the x pads on axis 1.
        if img['array'].ndim == 2:
            padw = (pads_y, pads_x)
        else:
            padw = (pads_y, pads_x, (0, 0))
        img['array'] = np.ascontiguousarray(function(np.pad(img['array'], padw)))
        # NOTE(review): img['origin'] still describes the extracted ROI. Away
        # from the left/top border, that ROI begins up to `pad` pixels before
        # the tile, while `function` is expected to strip that margin. Confirm
        # the written tiles are positioned correctly in the VRT; the origin may
        # need shifting by the stripped margin.

        cpy = otb.Registry.CreateApplication('ExtractROI')
        cpy.ImportVectorImage('in', img)
        tile_path = os.path.join(work_dir, prefix + '_' + str(t) + '.tif')
        output_streams.append(tile_path)
        cpy.SetParameterString('out', tile_path)
        cpy.ExecuteAndWriteOutput()
        print(cpy.GetImageSize('out'))
        # Release the per-tile applications and buffers before the next tile.
        roi = cpy = img = None

    vrt = os.path.join(work_dir, prefix + '.vrt')
    BuildVRT(vrt, output_streams)
    output_streams.append(vrt)

    new_pipeline = otb.Registry.CreateApplication('ExtractROI')
    new_pipeline.SetParameterString('in', vrt)

    if write_output_to is not None:
        new_pipeline.SetParameterString('out', write_output_to)
        new_pipeline.ExecuteAndWriteOutput()
        for f in output_streams:
            os.remove(f)
        return write_output_to, []
    return new_pipeline, output_streams

# NOTE: disabled legacy code. It is kept as a module-level string literal, so
# it is never executed. It imported every strip's array (divided by 10) into a
# single Mosaic application. The psutil RSS print and the sleep(5) look like
# memory-debugging leftovers.
'''
def merge_streams(ext_roi_apps):
    app = otb.Registry.CreateApplication('Mosaic')
    i=0
    for roi in ext_roi_apps:
        img = roi.ExportImage('out')
        img['array'] /= 10
        print(psutil.Process(os.getpid()).memory_info().rss)
        sleep(5)
        app.ImportVectorImage('il',img,i)
        i += 1
    return app
'''

def test_function(arr):
    """Example tile function: sum over axis 2, then crop 6 pixels per side.

    Expects a 3-D array (axis 2 is presumably the band axis). Returns a 2-D
    array that is 12 pixels smaller along each of the first two axes.
    """
    margin = 6
    band_total = np.sum(arr, axis=2)
    return band_total[margin:-margin, margin:-margin]

def check_function_padding(function, test_tile_size=(100,100)):
    """Estimate how many pixels ``function`` trims from each side of a tile.

    ``function`` is first applied to a 2-D array of ones of shape
    ``test_tile_size``. If that raises, it is retried on a 3-band array of
    shape ``(test_tile_size[0], test_tile_size[1], 3)``.

    Returns:
        Half of the larger size change over the first two axes, rounded down.
    """
    try:
        out = function(np.ones(test_tile_size))
    except Exception:
        # Presumably the function needs a band axis; retry with 3 bands.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate immediately instead of re-running the function.
        out = function(np.ones((test_tile_size[0], test_tile_size[1], 3)))
    shrink = max(abs(test_tile_size[0] - out.shape[0]),
                 abs(test_tile_size[1] - out.shape[1]))
    return int(shrink / 2)

def to_otb_pipeline(obj):
    """Normalize ``obj`` into something usable as an OTB pipeline.

    - ``str``: treated as a raster path. An ExtractROI application (without
      logger) reading it is created, executed and returned.
    - ``otb.Application``: returned unchanged.
    - ``list``: returned unchanged (its contents are not validated).

    Any other type terminates via ``sys.exit`` (raises ``SystemExit``).
    """
    if isinstance(obj, str):
        ppl = otb.Registry.CreateApplicationWithoutLogger('ExtractROI')
        ppl.SetParameterString('in', obj)
        ppl.Execute()
        return ppl
    if isinstance(obj, (otb.Application, list)):
        # TODO: for lists, check that every element is an OTB application.
        return obj
    sys.exit("Impossible to convert object to OTB pipeline")

def extract_roi_as_numpy(img: otb.Application, startx=None, starty=None, sizex=None, sizey=None, bands=None, out_param='out'):
    """Extract a region (and optionally a subset of bands) as a NumPy array.

    Args:
        img: OTB application whose ``out_param`` output is read.
        startx, starty, sizex, sizey: ROI parameters forwarded to ExtractROI;
            each is set only when not None.
        bands: optional iterable of 0-based band indices, mapped to
            ExtractROI's 1-based ``Channel<n>`` names.
        out_param: name of the output parameter of ``img`` to read.

    Returns:
        A copy of the extracted array with singleton dimensions squeezed.
    """
    extractor = otb.Registry.CreateApplication('ExtractROI')
    extractor.SetParameterInputImage('in', img.GetParameterOutputImage(out_param))
    roi_params = (('startx', startx), ('starty', starty),
                  ('sizex', sizex), ('sizey', sizey))
    for key, value in roi_params:
        if value is not None:
            extractor.SetParameterInt(key, value)
    if bands is not None:
        channel_names = ['Channel{}'.format(b + 1) for b in bands]
        extractor.SetParameterStringList('cl', channel_names)
    extractor.Execute()
    exported = extractor.ExportImage('out')
    # Copy so the result does not alias the application's internal buffer.
    return exported['array'].copy().squeeze()

def do_something(fn):
    """Demo: run ``test_function`` tile-by-tile over raster ``fn``.

    The result is written to ``pippo.tif`` in the current directory. The
    intermediate tiles and VRT returned by ``stream_function`` are then
    deleted.
    """
    # Input pipeline: a plain ExtractROI reading the whole raster.
    app = otb.Registry.CreateApplication("ExtractROI")
    app.SetParameterString("in", fn)

    app2, to_del = stream_function(app, test_function, tile_size=(1000,1000))

    app2.SetParameterString('out', 'pippo.tif')
    app2.ExecuteAndWriteOutput()

    for f in to_del:
        os.remove(f)

if __name__ == "__main__":
    # An optional first command-line argument overrides the default input raster.
    default_input = "/DATA/Koumbia/Tong/SENTINEL2_Aug_Oct_2017_RGB_NIR_NDVI.tif"
    do_something(sys.argv[1] if len(sys.argv) > 1 else default_input)