import sys
import warnings

from osgeo import gdal,ogr
import otbApplication as otb
from theia_picker import TheiaCatalog
from pyproj import Transformer as T
import tqdm

from Common.otb_numpy_proc import to_otb_pipeline
import numpy as np
import glob
import os
import xml.etree.ElementTree as et
import zipfile
from osgeo import osr
import datetime
import uuid
import shutil

from Common.geometry import get_displacements_to_ref, get_clearest_central_image

def fetch(shp, dt, output_fld, credentials):
    """Download THEIA Sentinel-2 Level-2A products covering a vector layer and wrap them in a pipeline.

    :param shp: path to a vector file readable by OGR; the extent of its first layer
        (converted to EPSG:4326) is used as search bounding box.
    :param dt: date interval written as ``'<start>-<end>'``; the two parts are passed
        unchanged to ``TheiaCatalog.search`` as ``start_date`` / ``end_date``.
    :param output_fld: folder where the product files are downloaded.
    :param credentials: credentials forwarded to ``TheiaCatalog``.
    :return: an ``S2TheiaPipeline`` built over ``output_fld``.
    :raises ValueError: if the layer has no spatial reference or no EPSG code.
    """
    ds = ogr.Open(shp)
    ly = ds.GetLayer(0)
    xmin, xmax, ymin, ymax = ly.GetExtent()
    srs = ly.GetSpatialRef()
    code = srs.GetAuthorityCode(None) if srs is not None else None
    if code is None:
        raise ValueError('Cannot determine the EPSG code of the first layer of {}'.format(shp))
    shp_srs = int(code)
    ds = None
    qry_srs = 4326

    if shp_srs == qry_srs:
        bbox = [xmin, ymin, xmax, ymax]
    else:
        # Project all four corners: in a projected CRS the lon/lat extremes of a
        # rectangle are not necessarily reached at its lower-left/upper-right corners.
        tr = T.from_crs(shp_srs, qry_srs, always_xy=True)
        lons, lats = tr.transform([xmin, xmin, xmax, xmax], [ymin, ymax, ymin, ymax])
        bbox = [min(lons), min(lats), max(lons), max(lats)]

    dates = dt.split('-')
    theia = TheiaCatalog(credentials)
    features = theia.search(
        start_date=dates[0],
        end_date=dates[1],
        bbox=bbox,
        level='LEVEL2A'
    )

    # Reflectance bands plus the R1 edge / saturation / cloud masks used by the pipeline.
    lst = ['FRE_B2', 'FRE_B3', 'FRE_B4', 'FRE_B5', 'FRE_B6', 'FRE_B7', 'FRE_B8', 'FRE_B8A',
           'FRE_B11', 'FRE_B12', 'EDG_R1', 'SAT_R1', 'CLM_R1']

    for f in features:
        f.download_files(matching=lst, download_dir=output_fld)

    return S2TheiaPipeline(output_fld)

class S2TheiaTilePipeline:
    """Chain OTB applications over the THEIA Sentinel-2 Level-2A products of a single tile.

    Products are looked up in a folder either as extracted directories or as ``.zip``
    archives whose base name contains ``PTRN_dir`` (without its wildcard). The
    pipeline is kept in four parallel lists indexed like ``self.pipe``:

    * ``self.pipe``  -- OTB applications (or objects returned by ``to_otb_pipeline``),
    * ``self.files`` -- output file name, joined to the folder given to ``write_outputs``,
    * ``self.types`` -- OTB output pixel type,
    * ``self.out_p`` -- name of the output image parameter (``'out'``, ``'io.out'``...).

    ``self.out_idx`` holds the indices of the current outputs. Each processing step
    (``merge_same_dates``, ``clip``, ``preprocess``, ``reproject``, ``rigid_align``,
    ``gapfill``, ...) reads those outputs, appends new applications and replaces
    ``self.out_idx`` with the indices of its own outputs.
    """

    # --- BEGIN SENSOR PROTOTYPE ---

    NAME = 'S2-THEIA'
    REF_TYPE = otb.ImagePixelType_int16  # pixel type of reflectance outputs
    MSK_TYPE = otb.ImagePixelType_uint8  # pixel type of mask outputs
    PTRN_dir = 'SENTINEL2*'  # glob pattern of product directories / archives
    PTRN_ref = '_FRE_'
    B2_name = 'B2'  # token replaced by 'STACK' / 'BINARY_MASK' in output names
    PTRN_10m = ['*_FRE_B2.tif', '*_FRE_B3.tif', '*_FRE_B4.tif', '*_FRE_B8.tif']
    PTRN_20m = ['*_FRE_B5.tif', '*_FRE_B6.tif', '*_FRE_B7.tif', '*_FRE_B8A.tif', '*_FRE_B11.tif', '*_FRE_B12.tif']
    PTRN_msk = ['MASKS/*_EDG_R1.tif', 'MASKS/*_SAT_R1.tif', 'MASKS/*_CLM_R1.tif']
    # BandMath reducer applied to each mask (same order as PTRN_msk) when merging same-date products.
    MERG_msk = ['min', 'min', 'max']
    # Band order of the reflectance stack: B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12.
    PTRN_ful = PTRN_10m[0:3] + PTRN_20m[0:3] + [PTRN_10m[3]] + PTRN_20m[3:]
    # BandMathX expressions evaluated on a stack in PTRN_ful order (im1b1 = B2 ... im1b10 = B12).
    FEAT_exp = {
        'B2': 'im1b1',
        'B3': 'im1b2',
        'B4': 'im1b3',
        'B5': 'im1b4',
        'B6': 'im1b5',
        'B7': 'im1b6',
        'B8': 'im1b7',
        'B8A': 'im1b8',
        'B11': 'im1b9',
        'B12': 'im1b10',
        'NDVI': '(im1b7-im1b3)/(im1b7+im1b3+1e-6)',
        'NDWI': '(im1b2-im1b7)/(im1b2+im1b7+1e-6)',
        'BRI': 'sqrt(' + '+'.join(['im1b%d*im1b%d' % (i, i) for i in range(1, 11)]) + ')',
        'MNDWI': '(im1b2-im1b9)/(im1b2+im1b9+1e-6)',
        'SWNDVI': '(im1b9-im1b7)/(im1b9+im1b7+1e-6)',
        'NDRE': '(im1b7-im1b4)/(im1b7+im1b4+1e-6)'
    }
    NDT = -10000  # no-data value of reflectance outputs

    @classmethod
    def _check(cls, x):
        """Return True if the base name of *x* contains ``PTRN_dir`` without its wildcard."""
        return cls.PTRN_dir.replace('*', '') in os.path.basename(x)

    @classmethod
    def _name_field(cls, x, dir_pos, zip_pos):
        """Return one ``'_'``-separated field of the product path *x*.

        The field index is *dir_pos* when *x* is a directory and *zip_pos* when it is
        a ``.zip`` archive (negative indices, counted from the end of the path).
        Returns None when *x* is not a recognised product.
        """
        if not cls._check(x):
            return None
        if os.path.isdir(x):
            return x.split('_')[dir_pos]
        if os.path.splitext(x)[-1] == '.zip':
            return x.split('_')[zip_pos]
        return None

    @classmethod
    def _img_id(cls, x):
        """Return the acquisition identifier field of product *x*, or None."""
        return cls._name_field(x, -5, -4)

    @classmethod
    def _img_date(cls, x):
        """Return the part of the acquisition identifier before its first ``'-'``, or None."""
        img_id = cls._img_id(x)
        return None if img_id is None else img_id.split('-')[0]

    @classmethod
    def _tile_id(cls, x):
        """Return the tile identifier field of product *x*, or None."""
        return cls._name_field(x, -3, -2)

    @classmethod
    def _tile_cloud_percentage(cls, x):
        """Return the ``CloudPercent`` value from the ``*_MTD_ALL.xml`` metadata of product *x*.

        Returns None when *x* is not a recognised product directory or archive.
        """
        if not cls._check(x):
            return None
        if os.path.isdir(x):
            with open(glob.glob(os.path.join(x, '*_MTD_ALL.xml'))[0], 'r') as fid:
                mtd = fid.read()
        elif os.path.splitext(x)[-1] == '.zip':
            with zipfile.ZipFile(x) as arch:
                names = [name for name in arch.namelist() if name.endswith('_MTD_ALL.xml')]
                mtd = arch.read(names[0])
        else:
            return None
        root = et.fromstring(mtd)
        matches = [e for e in root.findall('*/*/*/*/*/*') if e.get('name') == 'CloudPercent']
        return float(matches[0].text)

    @classmethod
    def _check_roi(cls, x, roi, min_surf=0.0, temp_fld='/tmp'):
        """Tell whether product *x* has enough valid pixels over the extent of *roi*.

        The first 20 m band (``PTRN_20m[0]``) is cropped to the extent of the vector
        file *roi* (``ExtractROI`` in ``fit`` mode). The product is accepted when the
        fraction of pixels different from ``NDT`` is strictly greater than *min_surf*.
        For ``.zip`` products the band is first extracted into *temp_fld*, and the
        extracted copy is removed afterwards.

        :return: True / False, or None if *x* is not a recognised product.
        """
        if not cls._check(x):
            return None
        tmp_file = None
        if os.path.isdir(x):
            bnd = glob.glob(os.path.join(x, cls.PTRN_20m[0]))[0]
        elif os.path.splitext(x)[-1] == '.zip':
            idf = cls.PTRN_20m[0].replace('*', '')
            os.makedirs(temp_fld, exist_ok=True)
            with zipfile.ZipFile(x) as arch:
                member = [name for name in arch.namelist() if idf in name][0]
                tmp_file = os.path.join(temp_fld, os.path.basename(member))
                with arch.open(member) as src, open(tmp_file, 'wb') as tgt:
                    shutil.copyfileobj(src, tgt)
            bnd = tmp_file
        else:
            return None

        try:
            er = otb.Registry.CreateApplication('ExtractROI')
            er.SetParameterString('in', bnd)
            er.SetParameterString('mode', 'fit')
            er.SetParameterString('mode.fit.vect', roi)
            er.Execute()
            arr = er.GetImageAsNumpyArray('out')
            valid_fraction = np.sum(arr != cls.NDT) / (arr.shape[0] * arr.shape[1])
        finally:
            if tmp_file is not None:
                try:
                    os.remove(tmp_file)
                except OSError:
                    pass  # best-effort removal of the extracted copy
        return not valid_fraction <= min_surf

    def _process_mask(self, msks):
        """Merge the mask applications *msks* into a single binary mask.

        The BandMath expression is true (1) where any input mask is non-zero.
        Returns the list of created applications; the last one produces the mask.
        """
        bm = otb.Registry.CreateApplication('BandMath')
        for x in msks:
            bm.AddImageToParameterInputImageList('il', x.GetParameterOutputImage('out'))
        bm.SetParameterString('exp', ' || '.join('im{}b1!=0'.format(i + 1) for i in range(len(msks))))
        bm.Execute()
        return [bm]

    # ---- END SENSOR PROTOTYPE ----

    def __init__(self, fld, tile, temp_fld='/tmp', input_date_interval=None, max_clouds_percentage=None,
                 filter_by_roi=None, roi_min_surf=0.0, dummy_read=False):
        """Collect the products of *tile* found in *fld* and build the input pipeline.

        :param fld: folder containing product directories and/or ``.zip`` archives.
        :param tile: tile identifier, compared with the tile field of product names.
        :param temp_fld: parent of the per-instance temporary folder
            ``temp_fld/<uuid4>``, which is removed by ``__del__``.
        :param input_date_interval: optional ``(first, last)`` pair; products whose
            date string is outside it (inclusive, string comparison) are dropped.
        :param max_clouds_percentage: optional maximum ``CloudPercent`` read from the
            product metadata; cloudier products are dropped.
        :param filter_by_roi: optional vector file; products whose valid fraction over
            its extent is ``<= roi_min_surf`` are dropped (see ``_check_roi``).
        :param roi_min_surf: threshold used with *filter_by_roi*.
        :param dummy_read: if True, only list products and dates: no ROI / cloud
            filtering, no EPSG lookup and no pipeline construction.
        """
        self.pipe = []
        self.files = []
        self.types = []
        self.out_p = []
        self.out_idx = []

        self.id = str(uuid.uuid4())
        self.folder = os.path.abspath(fld)
        self.tile_id = tile
        self.temp_fld = temp_fld + os.sep + self.id
        self.image_list = self.parse_folder(self.folder, self.tile_id)
        self.pipe_start = 0

        if len(self.image_list) == 0:
            warnings.warn('Empty pipeline. Need to set preprocessed inputs?')
            return

        self.tile_id = self._tile_id(self.image_list[0])
        self.input_dates = [self._img_date(x) for x in self.image_list]
        self.tile_cloud_percentage = []

        if input_date_interval is not None:
            first, last = input_date_interval[0], input_date_interval[1]
            self._select_images([i for i, d in enumerate(self.input_dates) if first <= d <= last])

        if dummy_read:
            return

        if filter_by_roi is not None:
            assert os.path.exists(filter_by_roi), 'ROI file not found: {}'.format(filter_by_roi)
            self._select_images([i for i, x in enumerate(self.image_list)
                                 if self._check_roi(x, filter_by_roi, roi_min_surf, temp_fld=self.temp_fld)])

        if max_clouds_percentage is not None:
            self.tile_cloud_percentage = [self._tile_cloud_percentage(x) for x in self.image_list]
            self._select_images([i for i, c in enumerate(self.tile_cloud_percentage)
                                 if c <= max_clouds_percentage])

        self.set_input_epsg()
        self.output_epsg = self.input_epsg
        self.output_dates = self.input_dates
        self._build_input_pipeline()

    def _select_images(self, keep):
        """Restrict ``image_list``, ``input_dates`` (and ``tile_cloud_percentage`` when filled) to positions *keep*."""
        self.image_list = [self.image_list[i] for i in keep]
        self.input_dates = [self.input_dates[i] for i in keep]
        if self.tile_cloud_percentage:
            self.tile_cloud_percentage = [self.tile_cloud_percentage[i] for i in keep]

    def _build_input_pipeline(self):
        """Append one reader per band of ``PTRN_ful`` then per mask of ``PTRN_msk`` for every image, all as outputs."""
        for img in self.image_list:
            for p in self.PTRN_ful:
                ifn, ofn = self.get_file(img, p)
                self.append(to_otb_pipeline(ifn), ofn, self.REF_TYPE, 'out', is_output=True)
            for p in self.PTRN_msk:
                ifn, ofn = self.get_file(img, p)
                self.append(to_otb_pipeline(ifn), ofn, self.MSK_TYPE, 'out', is_output=True)

    def __del__(self):
        """Remove the per-instance temporary folder, if it was created."""
        temp_fld = getattr(self, 'temp_fld', None)  # absent if __init__ failed early
        if temp_fld is not None and os.path.exists(temp_fld):
            shutil.rmtree(temp_fld, ignore_errors=True)

    def reset(self):
        """Discard every processing step and rebuild the input pipeline from ``image_list``."""
        self.pipe = []
        self.files = []
        self.types = []
        self.out_p = []
        self.out_idx = []

        self.pipe_start = 0
        self._build_input_pipeline()

    def get_file(self, img, ptrn):
        """Locate the file matching *ptrn* in product *img* (directory or ``.zip``).

        :return: ``(in_fn, out_fn)`` where *in_fn* is the path to read (a GDAL
            ``/vsizip`` path for archives) and *out_fn* is the output name
            ``<tile_id>/<path inside the product>`` with a ``.tif`` extension.
        """
        in_fn, out_fn = None, None
        if os.path.isdir(img):
            in_fn = glob.glob(os.path.join(img, ptrn))[0]
            out_fn = in_fn.replace(self.folder, '').lstrip(os.sep)
        elif os.path.splitext(img)[-1] == '.zip':
            # Patterns (and zip member names) always use '/' whatever the OS.
            key = ptrn.split('/')[-1].replace('*', '')
            with zipfile.ZipFile(img) as z:
                out_fn = [x for x in z.namelist() if key in x][0]
            in_fn = os.sep + 'vsizip' + os.sep + os.path.abspath(img) + os.sep + out_fn
        return in_fn, os.path.splitext(self.tile_id + os.sep + out_fn)[0] + '.tif'

    def parse_folder(self, fld, tile):
        """List the products of *tile* in *fld*, sorted by acquisition identifier.

        When a product exists both as a directory and as a ``.zip`` archive, the
        directory is kept.
        """
        candidates = glob.glob(os.path.join(fld, self.PTRN_dir))
        dirs = [os.path.abspath(x) for x in candidates
                if os.path.isdir(x) and self._tile_id(x) == tile]
        zips = [os.path.abspath(x) for x in candidates
                if os.path.splitext(x)[-1] == '.zip' and self._tile_id(x) == tile]
        im_dict = {self._img_id(x): x for x in dirs}
        for x in zips:
            im_dict.setdefault(self._img_id(x), x)
        return sorted(im_dict.values(), key=self._img_id)

    def merge_same_dates(self):
        """Merge the input products that share the same acquisition date.

        Expects the current outputs to be laid out per image as in the input
        pipeline (``PTRN_ful`` bands then ``PTRN_msk`` masks). Consecutive images
        with equal dates get their bands mosaicked (``Mosaic``, no-data ``NDT``) and
        their masks reduced with the matching ``MERG_msk`` operator. ``input_dates``
        (and ``output_dates`` when still equal to it) become one date per group.
        """
        if not getattr(self, 'input_dates', None):
            return

        # Group consecutive positions sharing the same date.
        groups = [[0]]
        for i in range(1, len(self.input_dates)):
            if self.input_dates[i] == self.input_dates[i - 1]:
                groups[-1].append(i)
            else:
                groups.append([i])

        n_ref = len(self.PTRN_ful)
        n_per_img = n_ref + len(self.PTRN_msk)
        new_dates = []
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        self.pipe_start = len(self.pipe)
        for grp in groups:
            if len(grp) == 1:
                for k in range(n_per_img):
                    idx = proc_idx[n_per_img * grp[0] + k]
                    self.append(self.pipe[idx], self.files[idx], self.types[idx], self.out_p[idx], is_output=True)
            else:
                for k in range(n_ref):
                    members = [proc_idx[n_per_img * i + k] for i in grp]
                    mos = otb.Registry.CreateApplication('Mosaic')
                    for m in members:
                        mos.AddImageToParameterInputImageList('il', self.pipe[m].GetParameterOutputImage(self.out_p[m]))
                    mos.SetParameterInt('nodata', self.NDT)
                    mos.Execute()
                    self.append(mos, self.files[members[-1]], self.types[members[-1]], 'out', is_output=True)
                for k in range(n_ref, n_per_img):
                    members = [proc_idx[n_per_img * i + k] for i in grp]
                    bm = otb.Registry.CreateApplication('BandMath')
                    for m in members:
                        bm.AddImageToParameterInputImageList('il', self.pipe[m].GetParameterOutputImage(self.out_p[m]))
                    operands = ','.join('im{}b1'.format(w + 1) for w in range(len(members)))
                    bm.SetParameterString('exp', self.MERG_msk[k - n_ref] + '({})'.format(operands))
                    bm.Execute()
                    self.append(bm, self.files[members[-1]], self.types[members[-1]], 'out', is_output=True)
            new_dates.append(self.input_dates[grp[-1]])

        if self.output_dates == self.input_dates:
            self.output_dates = new_dates.copy()
        self.input_dates = new_dates.copy()

    def get_coverage(self):
        """Return, for every mask output (odd positions of ``out_idx``), the mean of the mask read as uint8.

        For a 0/1 mask this is the fraction of flagged pixels. Each mask application
        releases its resources after being read.
        """
        cc = []
        for t in self.out_idx[1::2]:
            arr = self.pipe[t].GetImageAsNumpyArray(self.out_p[t], otb.ImagePixelType_uint8)
            cc.append(np.sum(arr) / (arr.shape[0] * arr.shape[1]))
            self.pipe[t].FreeRessources()
            del arr
        return cc

    def set_input_epsg(self):
        """Set ``input_epsg`` to ``GetAuthorityCode('PROJCS')`` of the first 20 m band of the first image."""
        f, _ = self.get_file(self.image_list[0], self.PTRN_20m[0])
        ds = gdal.Open(f)
        self.input_epsg = osr.SpatialReference(wkt=ds.GetProjection()).GetAuthorityCode('PROJCS')
        ds = None

    def append(self, app, fname=None, ftype=None, outp=None, is_output=False):
        """Append *app* with its output name, pixel type and output parameter; register it as output if asked."""
        if is_output:
            self.out_idx.append(len(self.pipe))
        self.pipe.append(app)
        self.files.append(fname)
        self.types.append(ftype)
        self.out_p.append(outp)

    def clip(self, roi):
        """Crop every current output to the extent of the vector file *roi* (``ExtractROI`` fit mode)."""
        assert os.path.exists(roi), 'ROI file not found: {}'.format(roi)
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        self.pipe_start = len(self.pipe)
        for t in proc_idx:
            er = otb.Registry.CreateApplication('ExtractROI')
            er.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            er.SetParameterString('mode', 'fit')
            er.SetParameterString('mode.fit.vect', roi)
            er.Execute()
            self.append(er, self.files[t], self.types[t], 'out', is_output=True)

    def preprocess(self):
        """Build, for each image, a reflectance stack and a binary mask.

        The current outputs are read in groups of ``len(PTRN_ful) + len(PTRN_msk)``
        (bands in ``PTRN_ful`` order, then masks). Bands from ``PTRN_20m`` are
        resampled onto the grid of the group's first band (``Superimpose``) and
        their negative values other than ``NDT`` are set to 0; all bands are then
        concatenated. Masks are merged with ``_process_mask``. The new outputs
        alternate stack / mask for every image.
        """
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        n_ref = len(self.PTRN_ful)
        n_per_img = n_ref + len(self.PTRN_msk)
        # Positions (in PTRN_ful order) of the bands that are not on the reference grid.
        to_resample = {pos for pos, p in enumerate(self.PTRN_ful) if p in self.PTRN_20m}
        for g in range(0, len(proc_idx), n_per_img):
            group = proc_idx[g:g + n_per_img]
            ref = group[0]
            stack_inputs = []
            for pos, k in enumerate(group[:n_ref]):
                band_img = self.pipe[k].GetParameterOutputImage(self.out_p[k])
                if pos not in to_resample:
                    stack_inputs.append(band_img)
                    continue
                si = otb.Registry.CreateApplication('Superimpose')
                si.SetParameterInputImage('inm', band_img)
                si.SetParameterInputImage('inr', self.pipe[ref].GetParameterOutputImage(self.out_p[ref]))
                si.Execute()
                self.append(si)
                cr = otb.Registry.CreateApplication('BandMath')
                cr.AddImageToParameterInputImageList('il', si.GetParameterOutputImage('out'))
                cr.SetParameterString('exp', '(im1b1 < 0 && im1b1 !=' + str(self.NDT) + ') ? 0 : im1b1')
                cr.Execute()
                self.append(cr)
                stack_inputs.append(cr.GetParameterOutputImage('out'))

            cct = otb.Registry.CreateApplication('ConcatenateImages')
            for band_img in stack_inputs:
                cct.AddImageToParameterInputImageList('il', band_img)
            cct.Execute()
            fn = self.files[ref].replace(self.B2_name, 'STACK')
            self.append(cct, fn, self.REF_TYPE, 'out', is_output=True)

            msk_pipe = self._process_mask([self.pipe[k] for k in group[n_ref:]])
            for app in msk_pipe[:-1]:
                self.append(app)
            fn = self.files[ref].replace(self.B2_name, 'BINARY_MASK')
            self.append(msk_pipe[-1], fn, self.MSK_TYPE, 'out', is_output=True)

    def write_src_quicklooks(self, fld, bnds=(3, 2, 1), scale_factor=0.2):
        """Write one PNG quicklook per (stack, mask) output pair into *fld*.

        The stack and mask are downscaled by *scale_factor*, the mask is inverted
        and used as ``DynamicConvert`` mask, and bands *bnds* are mapped to RGB.
        Files are named ``<NAME>_<img_id>.png`` following ``image_list`` order.
        """
        os.makedirs(fld, exist_ok=True)
        proc_idx = self.out_idx.copy()
        fns = [fld + os.sep + self.NAME + '_' + self._img_id(x) + '.png' for x in self.image_list]
        msk_fn = self.temp_fld + os.sep + 'msk.tif'

        for t, m, fn in zip(proc_idx[::2], proc_idx[1::2], fns):
            rt = otb.Registry.CreateApplication('RigidTransformResample')
            rt.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            rt.SetParameterFloat('transform.type.id.scalex', scale_factor)
            rt.SetParameterFloat('transform.type.id.scaley', scale_factor)
            rt.Execute()

            rtm = otb.Registry.CreateApplication('RigidTransformResample')
            rtm.SetParameterInputImage('in', self.pipe[m].GetParameterOutputImage(self.out_p[m]))
            rtm.SetParameterFloat('transform.type.id.scalex', scale_factor)
            rtm.SetParameterFloat('transform.type.id.scaley', scale_factor)
            rtm.SetParameterString('interpolator', 'nn')
            rtm.Execute()

            bmm = otb.Registry.CreateApplication('BandMath')
            bmm.AddImageToParameterInputImageList('il', rtm.GetParameterOutputImage('out'))
            bmm.SetParameterString('exp', '1 - im1b1')
            os.makedirs(self.temp_fld, exist_ok=True)
            bmm.SetParameterString('out', msk_fn)
            bmm.ExecuteAndWriteOutput()

            dc = otb.Registry.CreateApplication('DynamicConvert')
            dc.SetParameterInputImage('in', rt.GetParameterOutputImage('out'))
            dc.SetParameterString('mask', msk_fn)
            dc.SetParameterString('channels', 'rgb')
            dc.SetParameterInt('channels.rgb.red', bnds[0])
            dc.SetParameterInt('channels.rgb.green', bnds[1])
            dc.SetParameterInt('channels.rgb.blue', bnds[2])
            dc.SetParameterString('quantile.high', '2')
            dc.SetParameterString('quantile.low', '2')
            dc.SetParameterString('outmin', '1')
            dc.SetParameterString('outmax', '255')

            dc.SetParameterString('out', fn)
            dc.SetParameterOutputImagePixelType('out', otb.ImagePixelType_uint8)
            dc.ExecuteAndWriteOutput()

    def reproject(self, epsg):
        """Reproject every current output to EPSG *epsg* with ``OrthoRectification``.

        Nothing happens when *epsg* equals ``input_epsg``. Outputs are expected to
        alternate stack / mask: stacks get ``NDT`` as default value, masks use
        nearest-neighbour interpolation, and the EPSG code is added to file names.

        :param epsg: EPSG code as str or int (normalised to str).
        """
        if epsg is not None:
            epsg = str(epsg)
        if epsg == self.input_epsg:
            return
        self.output_epsg = epsg
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        for i, t in enumerate(proc_idx):
            rp = otb.Registry.CreateApplication('OrthoRectification')
            rp.SetParameterInputImage('io.in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            rp.SetParameterString('map', 'epsg')
            rp.SetParameterString('map.epsg.code', self.output_epsg)
            rp.SetParameterString('opt.gridspacing', '40')
            if i % 2 == 0:
                rp.SetParameterString('outputs.default', str(self.NDT))
                fn = self.files[t].replace('_STACK.tif', '_STACK_' + self.output_epsg + '.tif')
                ty = self.REF_TYPE
            else:
                rp.SetParameterString('interpolator', 'nn')
                fn = self.files[t].replace('_BINARY_MASK.tif', '_BINARY_MASK_' + self.output_epsg + '.tif')
                ty = self.MSK_TYPE
            rp.Execute()
            self.append(rp, fn, ty, 'io.out', is_output=True)

    def coregister(self, ref_img, ref_bnd, tgt_bnd, out_fld=None):
        """Co-register every (stack, mask) output pair on the image *ref_img*.

        PointMatchCoregistration is a pipeline-breaking application, so the current
        outputs are first written to ``temp_fld`` and re-read. Each stack is matched
        (its band *tgt_bnd* against band *ref_bnd* of *ref_img*), the estimated model
        is saved as JSON in ``temp_fld`` and re-applied to the following mask with
        nearest-neighbour interpolation. Results are written to *out_fld* (default
        ``temp_fld``) with ``NDT`` flagged as no-data and become the new outputs.
        """
        self.write_outputs(self.temp_fld, update_pipe=True)
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        curr_json = ''
        for pos, t in enumerate(proc_idx):
            cr = otb.Registry.CreateApplication('PointMatchCoregistration')
            cr.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            cr.SetParameterInt('fallback', 1)
            if pos % 2 == 0:
                curr_json = os.path.join(self.temp_fld, os.path.splitext(self.files[t])[0] + '.json')
                cr.SetParameterInt('band', tgt_bnd)
                cr.SetParameterString('inref', ref_img)
                cr.SetParameterInt('bandref', ref_bnd)
                cr.SetParameterString('outjson', curr_json)
                fn = self.files[t].replace('_STACK.tif', '_STACK_COREG.tif')
                ty = self.REF_TYPE
            else:
                cr.SetParameterString('inmodel', curr_json)
                cr.SetParameterString('interpolator', 'nn')
                fn = self.files[t].replace('_BINARY_MASK.tif', '_BINARY_MASK_COREG.tif')
                ty = self.MSK_TYPE
            self.append(cr, fn, ty, 'out', is_output=True)
        if out_fld is None:
            out_fld = self.temp_fld
        self.write_outputs(out_fld, update_pipe=True, flag_nodata=True)

    def rigid_align(self, ext_ref=None, ref_date=None, cov_th=0.8, match_band=2, ext_band=0,
                    geobin_radius=32, num_geobins=128, margin=32, filter=5, out_param='out'):
        """Translate every (stack, mask) output pair by a shift estimated on one band.

        The shifts are computed by ``get_displacements_to_ref`` on band *match_band*
        of the per-image input readers starting at ``pipe_start``, against either
        an internal reference image (chosen by ``get_clearest_central_image`` or as
        the date closest to *ref_date*, ``'%Y%m%d'``) or the external image *ext_ref*
        resampled on the first output. Shifts (in pixels) are multiplied by the
        pixel spacing and applied with ``RigidTransformResample`` (linear for
        stacks, nearest-neighbour for masks).
        """
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        n_per_img = len(self.PTRN_ful) + len(self.PTRN_msk)
        img = [self.pipe[t] for t in range(self.pipe_start + match_band,
                                           self.pipe_start + n_per_img * len(self.input_dates),
                                           n_per_img)]
        msk = [self.pipe[t] for t in proc_idx[1::2]]
        RX, RY = img[0].GetImageSpacing(out_param)
        if ext_ref is None:
            if ref_date is None:
                ref_idx = get_clearest_central_image(msk, self.input_dates, cov_th)
            else:
                dts = [datetime.datetime.strptime(x, '%Y%m%d') for x in self.input_dates]
                ctr = datetime.datetime.strptime(ref_date, '%Y%m%d')
                ref_idx = dts.index(min(dts, key=lambda x: abs(ctr - x)))

            shifts = get_displacements_to_ref(img, msk, ref_mode="index", ref_idx=ref_idx, band=0,
                                              out_param=out_param, geobin_radius=geobin_radius,
                                              num_geobins=num_geobins, margin=margin, filter=filter)
        else:
            assert os.path.exists(ext_ref), 'External reference not found: {}'.format(ext_ref)
            ext_ref_si = otb.Registry.CreateApplication("Superimpose")
            ext_ref_si.SetParameterInputImage('inr', self.pipe[proc_idx[0]].GetParameterOutputImage(self.out_p[proc_idx[0]]))
            ext_ref_si.SetParameterString('inm', ext_ref)
            ext_ref_si.Execute()
            ext_msk_si = otb.Registry.CreateApplication("BandMath")
            ext_msk_si.AddImageToParameterInputImageList('il', ext_ref_si.GetParameterOutputImage('out'))
            ext_msk_si.SetParameterString('exp', 'im1b1 != im1b1')  # To create a mask of zeros...
            ext_msk_si.Execute()

            shifts = get_displacements_to_ref(img, msk, ref_mode="external", ref_ext=ext_ref_si, ref_ext_msk=ext_msk_si,
                                              band=0, ext_band=ext_band, out_param=out_param,
                                              geobin_radius=geobin_radius, num_geobins=num_geobins,
                                              margin=margin, filter=filter)

        k = 0  # index of the current (stack, mask) pair in `shifts`
        for pos, i in enumerate(proc_idx):
            is_mask = pos % 2 == 1
            rt = otb.Registry.CreateApplication('RigidTransformResample')
            rt.SetParameterInputImage('in', self.pipe[i].GetParameterOutputImage(self.out_p[i]))
            rt.SetParameterString('transform.type', 'translation')
            rt.SetParameterFloat('transform.type.translation.tx', RX * shifts[k][1])
            rt.SetParameterFloat('transform.type.translation.ty', RY * shifts[k][0])
            rt.SetParameterString('interpolator', 'nn' if is_mask else 'linear')
            rt.Execute()
            if is_mask:
                fn = self.files[i].replace('BINARY_MASK', 'BINARY_MASK_ALIGNED')
            else:
                fn = self.files[i].replace('FRE_STACK', 'FRE_STACK_ALIGNED')
            self.append(rt, fn, self.types[i], 'out', is_output=True)
            if is_mask:
                k += 1

    def parse_dates(self, fn):
        """Return the non-empty lines of the text file *fn*."""
        with open(fn) as f:
            return [l for l in f.read().splitlines() if l]

    def skip_gapfill(self):
        """Keep only the stack outputs (even positions of ``out_idx``), dropping the masks."""
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        for t in proc_idx[::2]:
            self.append(self.pipe[t], self.files[t], self.types[t], 'out', is_output=True)

    def gapfill(self, output_dates=None, on_disk=False):
        """Linearly gap-fill the stack time series and split it into one stack per output date.

        Stacks and masks (alternating outputs) are concatenated and passed to
        ``ImageTimeSeriesGapFilling``. Input dates are written to
        ``<folder>/<tile_id>_indates.txt``. Output dates come from the file
        *output_dates* if given, otherwise the input dates are used.

        :param on_disk: if True, the full gap-filled series is written to (or reused
            from) ``temp_fld`` instead of being kept in memory.
        """
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        n_comp = len(self.PTRN_ful)  # bands per date in each stack

        stk = otb.Registry.CreateApplication('ConcatenateImages')
        for x in proc_idx[::2]:
            stk.AddImageToParameterInputImageList('il', self.pipe[x].GetParameterOutputImage(self.out_p[x]))
        stk.Execute()
        self.append(stk)

        mst = otb.Registry.CreateApplication('ConcatenateImages')
        for x in proc_idx[1::2]:
            mst.AddImageToParameterInputImageList('il', self.pipe[x].GetParameterOutputImage(self.out_p[x]))
        mst.Execute()
        self.append(mst)

        os.makedirs(self.temp_fld, exist_ok=True)

        in_dates_fn = os.path.join(self.folder, self.tile_id + '_indates.txt')
        with open(in_dates_fn, 'w') as df:
            df.writelines(x + '\n' for x in self.input_dates)

        if output_dates is not None:
            od = self.parse_dates(output_dates)
        else:
            output_dates = in_dates_fn
            od = self.input_dates
        self.output_dates = od

        gf = otb.Registry.CreateApplication('ImageTimeSeriesGapFilling')
        gf.SetParameterInputImage('in', stk.GetParameterOutputImage('out'))
        gf.SetParameterInputImage('mask', mst.GetParameterOutputImage('out'))
        gf.SetParameterInt('comp', n_comp)
        gf.SetParameterString('it', 'linear')
        gf.SetParameterString('id', in_dates_fn)
        gf.SetParameterString('od', output_dates)
        if not on_disk:
            gf.Execute()
            gf_out = gf
        else:
            gf_fn = self.temp_fld + os.sep + 'SENTINEL2_' + self.tile_id + '_GAPFILLED_FULL.tif'
            if not os.path.exists(gf_fn):
                gf.SetParameterString("out", gf_fn)
                gf.ExecuteAndWriteOutput()
            gf_out = to_otb_pipeline(gf_fn)
        self.append(gf_out)

        for n, d in enumerate(od):
            first = n * n_comp + 1
            ch_list = ['Channel%d' % c for c in range(first, first + n_comp)]
            er = otb.Registry.CreateApplication('ExtractROI')
            er.SetParameterInputImage('in', gf_out.GetParameterOutputImage('out'))
            er.UpdateParameters()
            er.SetParameterStringList('cl', ch_list)
            er.Execute()
            dn = 'SENTINEL2_' + self.tile_id + '_GAPFILL_' + d
            fn = self.tile_id + os.sep + dn + os.sep + dn + '_STACK.tif'
            self.append(er, fn, self.REF_TYPE, 'out', is_output=True)

    def generate_feature_stack(self, feat_list=None):
        """Compute, for every current output stack, the ``FEAT_exp`` features in *feat_list*.

        :param feat_list: keys of ``FEAT_exp``; all features when None.
        :return: the name used in output files (``'FEAT'`` or the joined feature names).
        """
        proc_idx = self.out_idx.copy()
        self.out_idx = []

        exp_list = list(self.FEAT_exp.values())
        stack_name = 'FEAT'
        if feat_list is not None:
            exp_list = [self.FEAT_exp[x] for x in feat_list]
            stack_name = '_'.join(feat_list)

        for t in proc_idx:
            bm = otb.Registry.CreateApplication('BandMathX')
            bm.AddImageToParameterInputImageList('il', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            bm.SetParameterString('exp', '{' + ';'.join(exp_list) + '}')
            bm.Execute()
            fn = self.files[t].replace('_STACK.tif', '_' + stack_name + '.tif')
            self.append(bm, fn, otb.ImagePixelType_float, 'out', is_output=True)

        return stack_name

    def generate_time_series(self, feat_list):
        """Build one multi-date image per feature in *feat_list*, one band per current output stack."""
        proc_idx = self.out_idx.copy()
        self.out_idx = []

        for feat in feat_list:
            bm = otb.Registry.CreateApplication('BandMathX')
            expr = []
            for i, t in enumerate(proc_idx, start=1):
                # Rewrite the single-image expression so that it reads input image i.
                expr.append(self.FEAT_exp[feat].replace('im1', 'im%d' % i))
                bm.AddImageToParameterInputImageList('il', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            bm.SetParameterString('exp', '{' + ';'.join(expr) + '}')
            bm.Execute()
            fn = self.tile_id + os.sep + 'SENTINEL2_' + self.tile_id + '_GAPFILL_' + feat + '.tif'
            self.append(bm, fn, otb.ImagePixelType_float, 'out', is_output=True)

    def set_preprocess_output(self, fld):
        """Use previously written preprocess outputs under ``fld/<tile_id>`` as current outputs.

        The pipeline is only changed when one product folder is found per image.
        Output names are stored relative to *fld*, like the names produced by
        ``preprocess``, so that ``write_outputs`` honours its own target folder.
        """
        lst = self.parse_folder(os.path.join(fld, self.tile_id), self.tile_id)
        if len(lst) == len(self.image_list):
            self.out_idx = []
            for img in lst:
                f = glob.glob(os.path.join(img, '*STACK.tif'))[0]
                self.append(to_otb_pipeline(f), os.path.relpath(f, fld), self.REF_TYPE, 'out', is_output=True)
                f = glob.glob(os.path.join(img, '*BINARY_MASK.tif'))[0]
                self.append(to_otb_pipeline(f), os.path.relpath(f, fld), self.MSK_TYPE, 'out', is_output=True)
        else:
            warnings.warn("No matching files found as preprocess output. No modification to pipeline.")

    def set_gapfilled_output(self, file_pattern):
        """Use the files matching *file_pattern* (sorted) as current outputs.

        The pipeline is left untouched when nothing matches.
        """
        lst = sorted(glob.glob(file_pattern))
        if not lst:
            warnings.warn("No matching files found as gapfilled output. No modification to pipeline.")
            return
        self.out_idx = []
        for f in lst:
            pf = os.path.join(self.tile_id, os.path.basename(f))
            self.append(to_otb_pipeline(f), pf, self.REF_TYPE, 'out', is_output=True)

    def write_outputs(self, fld, roi=None, update_pipe=False, compress=False, flag_nodata=False):
        """Execute the current outputs and write them under *fld*.

        :param fld: root folder; output ``t`` goes to ``fld/self.files[t]``
            (missing folders are created).
        :param roi: optional vector file; outputs are clipped to its extent for this
            write only and the pipeline is restored afterwards.
        :param update_pipe: if True, the written files become the new outputs.
            Cannot be combined with *roi*.
        :param compress: write with the ``?gdal:co:compress=deflate`` extended filename.
        :param flag_nodata: False/None to skip; True to set ``NDT`` as no-data on every
            band of the written files; an int or float to use that value instead.
        :return: list of written file paths (without extended-filename options).
        :raises AssertionError: if *roi* and *update_pipe* are both given.
        :raises TypeError: if *flag_nodata* has an unsupported type.
        """
        # Validate before touching the pipeline (clip() mutates it).
        if update_pipe:
            assert roi is None, 'Cannot set output files as pipe input over a ROI, use clip function instead.'
        if flag_nodata is None or flag_nodata is False:
            nodata_val = None
        elif flag_nodata is True:
            nodata_val = self.NDT
        elif isinstance(flag_nodata, (int, float)):
            nodata_val = flag_nodata
        else:
            raise TypeError('flag_nodata must be a bool or a number, got {!r}'.format(flag_nodata))

        out = []
        proc_idx = self.out_idx.copy()
        if roi is not None:
            out_idx_bck = self.out_idx.copy()
            pipe_length = len(self.pipe)
            self.clip(roi)
            proc_idx = self.out_idx.copy()
        if update_pipe:
            self.out_idx = []
        for t in proc_idx:
            out_file = os.path.join(fld, self.files[t])
            out_dir = os.path.dirname(out_file)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # Writer options go only to OTB; readers and GDAL get the plain path.
            target = out_file + '?gdal:co:compress=deflate' if compress else out_file
            self.pipe[t].SetParameterString(self.out_p[t], target)
            self.pipe[t].SetParameterOutputImagePixelType(self.out_p[t], self.types[t])
            self.pipe[t].ExecuteAndWriteOutput()
            out.append(out_file)
            if update_pipe:
                self.append(to_otb_pipeline(out_file), self.files[t], self.types[t], 'out', is_output=True)
        if roi is not None:
            self.out_idx = out_idx_bck
            self.pipe = self.pipe[:pipe_length]
            self.files = self.files[:pipe_length]
            self.types = self.types[:pipe_length]
            self.out_p = self.out_p[:pipe_length]

        if nodata_val is not None:
            for f in out:
                ds = gdal.Open(f, 1)
                for b in range(1, ds.RasterCount + 1):
                    ds.GetRasterBand(b).SetNoDataValue(nodata_val)
                ds = None

        return out


class S2TheiaPipeline:
    """Apply the per-tile THEIA Sentinel-2 processing to every tile found in a folder.

    One ``S2TilePipeline`` instance is built per tile identifier found among the
    products of the input folder; the ``extract_*`` methods run the same processing
    chain on each of them and can optionally mosaic the results as VRTs.
    """

    # Per-tile pipeline class instantiated once for each tile (see __init__).
    S2TilePipeline = S2TheiaTilePipeline
    # Helpers borrowed from the tile pipeline: __init__ uses _check to filter the
    # candidate products and _tile_id to derive each product's tile identifier.
    _check = S2TilePipeline._check
    _tile_id = S2TilePipeline._tile_id

    def __init__(self, fld, temp_fld='/tmp', input_date_interval=None, max_clouds_percentage=None):
        """Discover the tiles available in ``fld`` and build one tile pipeline per tile.

        Candidate products are the entries of ``fld`` matching
        ``S2TilePipeline.PTRN_dir`` that are directories or ``.zip`` archives and pass
        ``_check``; they are grouped by ``_tile_id``.

        :param fld: folder containing the Sentinel-2 products.
        :param temp_fld: folder for temporary files (also forwarded to each tile pipeline).
        :param input_date_interval: forwarded to each tile pipeline.
        :param max_clouds_percentage: forwarded to each tile pipeline.
        :raises IndexError: if no valid product (hence no tile) is found in ``fld``.
        """
        self.folder = fld
        self.temp_fld = temp_fld
        self.input_date_interval = input_date_interval
        self.max_clouds_percentage = max_clouds_percentage
        img_list = [os.path.abspath(x) for x in glob.glob(os.path.join(self.folder, self.S2TilePipeline.PTRN_dir))
                    if os.path.isdir(x) or os.path.splitext(x)[-1] == '.zip']
        self.tile_list = {self._tile_id(x) for x in img_list if self._check(x)}
        # Iterate tiles in sorted order: set iteration order of string ids depends on hash
        # randomization, which would make the default output EPSG (first tile) and the VRT
        # mosaic source order change from one run to the next.
        self.tiles = [self.S2TilePipeline(fld, t, self.temp_fld, self.input_date_interval, self.max_clouds_percentage)
                      for t in sorted(self.tile_list)]
        self.output_dates = None
        self.roi = None
        # Default output projection: that of the first tile (IndexError if no tile found).
        self.output_epsg = self.tiles[0].input_epsg

    def __del__(self):
        """Release the references to the per-tile pipelines."""
        # getattr(): __del__ also runs when __init__ raised before ``self.tiles`` was
        # assigned; a plain attribute access would then emit an ignored AttributeError.
        if getattr(self, 'tiles', None):
            self.tiles = []

    def set_output_dates_by_file(self, od):
        """Use ``od`` as the output dates given to each tile's ``gapfill``.

        ``od`` is presumably a text file with one ``YYYYMMDD`` date per line, like the one
        written by ``set_output_dates`` — confirm against the tiles' ``gapfill``.
        """
        dates_source = od
        self.output_dates = dates_source

    def set_output_dates(self, start, end, step=10):
        """Generate a regular sequence of output dates and store it in a temporary file.

        Dates go from ``start`` (included) to ``end`` (excluded) every ``step`` days and are
        written one ``YYYYMMDD`` per line to a file in ``self.temp_fld``, whose path is
        stored in ``self.output_dates``.

        :param start: first date, a string starting with ``YYYYMMDD``.
        :param end: end date (excluded), same format.
        :param step: interval between consecutive dates, in days; must be positive.
        :raises ValueError: if ``step`` is not positive or a date string is malformed.
        """
        if step <= 0:
            # A non-positive step never reaches ``end``: the loop below would run forever
            # while filling the output file.
            raise ValueError('step must be a positive number of days, got {}.'.format(step))
        start_date = datetime.datetime(int(start[0:4]), int(start[4:6]), int(start[6:8]))
        end_date = datetime.datetime(int(end[0:4]), int(end[4:6]), int(end[6:8]))
        delta = datetime.timedelta(days=step)
        tmstp = str(datetime.datetime.timestamp(datetime.datetime.now())).replace('.', '')
        ofn = os.path.join(self.temp_fld, 's2ppl_' + tmstp + '_output_dates.txt')
        current = start_date
        with open(ofn, 'w') as f:
            while current < end_date:
                f.write(current.strftime('%Y%m%d') + '\n')
                current += delta
        self.output_dates = ofn

    def set_roi(self, roi):
        """Set the region of interest passed to each tile's ``clip`` before processing."""
        region = roi
        self.roi = region

    def set_output_epsg(self, epsg):
        """Override the EPSG code passed to each tile's ``reproject`` (default: first tile's)."""
        target_epsg = epsg
        self.output_epsg = target_epsg

    def extract_feature_set(self, out_fld, feat_list=None, mosaicking=None, store_gapfill=False,
                            align=False, align_to=None, align_to_band=3, align_using_band=3,
                            warp_to=None, warp_to_band=1, warp_using_band=3):
        """Compute a gap-filled feature stack for every tile and write it to ``out_fld``.

        For each tile: optional clip to ``self.roi``, preprocessing, optional rigid
        alignment, reprojection to ``self.output_epsg``, optional coregistration on
        ``warp_to``, gap-filling on ``self.output_dates`` and feature stack generation.

        :param out_fld: output folder.
        :param feat_list: forwarded to the tiles' ``generate_feature_stack``.
        :param mosaicking: ``'vrt'`` to build one VRT mosaic per output date when more
            than one tile is processed.
        :param store_gapfill: forwarded to the tiles' ``gapfill``.
        :param align: if True, rigidly align each tile according to ``align_to``.
        :param align_to: ``None`` (``rigid_align`` default reference), an existing image
            file (``str`` or path-like) or a ``YYYYMMDD`` date string.
        :param align_to_band: 1-based band of the external ``align_to`` image.
        :param align_using_band: 1-based band of the tile used for matching.
        :param warp_to: optional image to coregister onto; ``warp_to_band`` and
            ``warp_using_band`` are forwarded to the tiles' ``coregister``.
        :return: ``[]`` if no output dates are set; the list of VRT mosaic paths when
            mosaicking; otherwise one list of written files per tile.
        :raises ValueError: if ``align_to`` is a string that is neither an existing file
            nor a ``YYYYMMDD`` date.
        :raises TypeError: if ``align_to`` has an unsupported type.
        """
        out = []
        stack_name = ''
        if self.output_dates is None:
            return out
        # Resolved once, before any tile is processed, so that an invalid option fails
        # fast instead of after the first tile's preprocessing.
        align_kwargs = self._rigid_align_kwargs(align_to, align_to_band, align_using_band) if align else None
        for t in self.tiles:
            if self.roi is not None:
                t.clip(self.roi)
            t.preprocess()
            if align_kwargs is not None:
                t.rigid_align(**align_kwargs)
            t.reproject(self.output_epsg)
            if warp_to is not None:
                t.coregister(warp_to, warp_to_band, warp_using_band, t.temp_fld)
            t.gapfill(self.output_dates, store_gapfill)
            stack_name = t.generate_feature_stack(feat_list)
            out.append(t.write_outputs(out_fld))
            t.reset()
        if len(self.tiles) > 1 and mosaicking == 'vrt':
            vrtopt = gdal.BuildVRTOptions()
            out_mos = []
            # One mosaic per output date, gathering the i-th written file of every tile.
            for i, date in enumerate(self.tiles[0].output_dates):
                fn = out_fld + os.sep + 'SENTINEL2_MOSAIC_GAPFILL_' + date + '_' + stack_name + '.vrt'
                gdal.BuildVRT(fn, [x[i] for x in out], options=vrtopt)
                out_mos.append(fn)
            return out_mos
        return out

    @staticmethod
    def _rigid_align_kwargs(align_to, align_to_band, align_using_band):
        """Translate the ``align_to`` option into keyword arguments for ``rigid_align``."""
        kwargs = {'match_band': align_using_band - 1}
        if align_to is None:
            return kwargs
        if isinstance(align_to, os.PathLike):
            align_to = os.fspath(align_to)
        if not isinstance(align_to, str):
            raise TypeError('align_to must be None, a file path or a YYYYMMDD date string, '
                            'got {!r}.'.format(align_to))
        if os.path.exists(align_to):
            kwargs.update(ext_ref=align_to, ext_band=align_to_band - 1)
            return kwargs
        # Only the date parsing is guarded: errors raised by rigid_align itself must not
        # be reported as an invalid align_to value.
        try:
            datetime.datetime.strptime(align_to, '%Y%m%d')
        except ValueError as err:
            raise ValueError("Provided string is not a valid date nor a valid file.") from err
        kwargs['ref_date'] = align_to
        return kwargs

    def extract_time_series(self, out_fld, feat_list, mosaicking=None, store_gapfill=False):
        """Compute gap-filled time series of the requested features for every tile.

        For each tile: same-date merging, optional clip to ``self.roi``, preprocessing,
        reprojection to ``self.output_epsg``, gap-filling on ``self.output_dates`` and
        time-series generation, then the outputs are written to ``out_fld``.

        :param out_fld: output folder.
        :param feat_list: features forwarded to the tiles' ``generate_time_series``.
        :param mosaicking: ``'vrt'`` to build one VRT mosaic per feature.
        :param store_gapfill: forwarded to the tiles' ``gapfill``.
        :return: ``[]`` if no output dates are set; the list of VRT mosaic paths when
            mosaicking; otherwise one list of written files per tile.
        """
        if self.output_dates is None:
            return []
        per_tile = []
        for tile in self.tiles:
            tile.merge_same_dates()
            if self.roi is not None:
                tile.clip(self.roi)
            tile.preprocess()
            tile.reproject(self.output_epsg)
            tile.gapfill(self.output_dates, store_gapfill)
            tile.generate_time_series(feat_list)
            per_tile.append(tile.write_outputs(out_fld))
            tile.reset()
        if mosaicking != 'vrt':
            return per_tile
        vrt_options = gdal.BuildVRTOptions()
        mosaics = []
        # One mosaic per feature, gathering the idx-th written file of every tile.
        for idx, feat in enumerate(feat_list):
            vrt_name = out_fld + os.sep + 'SENTINEL2_MOSAIC_GAPFILL_' + feat + '.vrt'
            sources = [files[idx] for files in per_tile]
            gdal.BuildVRT(vrt_name, sources, options=vrt_options)
            mosaics.append(vrt_name)
        return mosaics