import numpy as np
from skimage.feature import SIFT, match_descriptors
import otbApplication as otb
from typing import List
from tqdm import tqdm

def _export_region(app, out_param, region):
    """Request `region` on `app`'s output `out_param` and return it as a NumPy array."""
    app.PropagateRequestedRegion(out_param, region)
    return app.ExportImage(out_param)['array']


def _sift_features(sft, image):
    """Run `sft` on a 2-D `image`; return (keypoints, descriptors), or None if SIFT raised."""
    try:
        sft.detect_and_extract(image)
    except Exception:
        # A tile on which SIFT fails is skipped rather than aborting the whole
        # estimate. Catching Exception (not a bare `except`) still lets
        # KeyboardInterrupt / SystemExit propagate.
        return None
    return sft.keypoints, sft.descriptors


def compute_displacement(_src: otb.Application, _tgt: otb.Application,
                         src_band=2, tgt_band=2,
                         out_param_src='out', out_param_tgt='out',
                         geobin_size=32, geobin_spacing=256, margin=32):
    """Estimate the mean translation from `_src` to `_tgt` with SIFT tie points.

    The target image extent is sampled with square tiles ("geobins") of
    `geobin_size` pixels. Tile origins are `geobin_size + geobin_spacing`
    apart and start `margin` pixels from the border. Each tile is not allowed
    to come closer than `margin` pixels to the far border either.

    For each tile:
    - the same region is exported from both pipelines;
    - SIFT features are extracted on the selected band;
    - descriptors are matched (ratio test 0.6 plus cross-check);
    - the tile's shift is the mean of (target keypoint - source keypoint)
      over its matches.

    The result is the unweighted mean of the per-tile shifts.

    Args:
        _src, _tgt: OTB applications whose outputs are compared. The grid is
            built from `_tgt`'s size only, so both are assumed to share the
            same size/grid (NOTE(review): not checked here).
        src_band, tgt_band: band index (last axis of the exported array) fed
            to SIFT.
        out_param_src, out_param_tgt: output parameter keys of the apps.
        geobin_size: tile side length in pixels.
        geobin_spacing: gap in pixels between consecutive tiles.
        margin: border, in pixels, excluded on every side.

    Returns:
        ``(shift, n_tiepoints)``.
        - ``shift`` is a length-2 array in SIFT keypoint coordinate order
          (scikit-image documents this as (row, col)).
        - ``n_tiepoints`` is the total number of matches used.
        - If no tile produced any match, returns ``(array([nan, nan]), 0)``.

    On exit, normal or exceptional, both pipelines are re-requested on the
    full ``(0, 0, W, H)`` region of the target size.
    """
    sz = _tgt.GetImageSize(out_param_tgt)
    # The region code below uses index/size [0] for the w loop and [1] for the h loop.
    W, H = sz[0], sz[1]
    sft = SIFT()
    shifts = []
    n_tiepoints = 0
    reg = otb.itkRegion()
    step = geobin_size + geobin_spacing

    try:
        # `+ 1`: a tile starting at (dim - margin - geobin_size) ends exactly
        # `margin` pixels from the border, so it is still a valid tile.
        for h in range(margin, H - margin - geobin_size + 1, step):
            for w in range(margin, W - margin - geobin_size + 1, step):
                reg['index'][0], reg['index'][1] = w, h
                reg['size'][0], reg['size'][1] = geobin_size, geobin_size

                src = _export_region(_src, out_param_src, reg)
                tgt = _export_region(_tgt, out_param_tgt, reg)

                src_feats = _sift_features(sft, src[:, :, src_band])
                if src_feats is None:
                    continue
                tgt_feats = _sift_features(sft, tgt[:, :, tgt_band])
                if tgt_feats is None:
                    continue
                ks, ds = src_feats
                kt, dt = tgt_feats
                if len(ks) == 0 or len(kt) == 0:
                    continue

                # Column 0 indexes the target keypoints (first argument),
                # column 1 the source keypoints.
                mtch = match_descriptors(dt, ds, max_ratio=0.6, cross_check=True)
                if len(mtch) > 0:
                    n_tiepoints += len(mtch)
                    shifts.append(np.mean(kt[mtch[:, 0]] - ks[mtch[:, 1]], axis=0))
    finally:
        # Always reset both pipelines to the full extent, even if a tile
        # raised, so they are not left requesting the last tile only.
        reg['index'][0], reg['index'][1] = 0, 0
        reg['size'][0], reg['size'][1] = W, H
        _src.PropagateRequestedRegion(out_param_src, reg)
        _tgt.PropagateRequestedRegion(out_param_tgt, reg)

    if not shifts:
        # Avoid np.mean over an empty array (RuntimeWarning and a scalar NaN).
        # Return a NaN 2-vector so callers always get the same shape.
        return np.full(2, np.nan), 0
    return np.mean(np.array(shifts), axis=0), n_tiepoints

def get_normalized_displacements(_lst: List[otb.Application],
                                 band=2, out_param='out',
                                 geobin_size=32, geobin_spacing=256, margin=32) -> np.ndarray:
    """Compute zero-mean per-image shifts for a sequence of images.

    For each consecutive pair ``(_lst[i], _lst[i+1])`` the pairwise shift
    ``d[i]`` is estimated with `compute_displacement`, using the same
    band / output key / tiling parameters for both images. Image ``j`` is then
    assigned the cumulative shift ``d[j] + d[j+1] + ... + d[-1]``, i.e.
    towards the last image, which itself gets zero. Finally the mean over all
    images is subtracted, so the returned shifts average to zero.

    Args:
        _lst: ordered OTB applications, one per image.
        band: band index passed as both `src_band` and `tgt_band`.
        out_param: output key passed as both `out_param_src` and `out_param_tgt`.
        geobin_size, geobin_spacing, margin: forwarded to `compute_displacement`.

    Returns:
        Float array of shape ``(len(_lst), 2)``. The component order is that of
        `compute_displacement`. A pair without tie points yields NaN, which
        propagates through the cumulative sum and the mean.
    """
    if not _lst:
        return np.empty((0, 2))

    pair_shifts = []
    for src, tgt in tqdm(zip(_lst[:-1], _lst[1:]), total=len(_lst) - 1):
        shift, _ = compute_displacement(src, tgt, src_band=band, tgt_band=band,
                                        out_param_src=out_param, out_param_tgt=out_param,
                                        geobin_size=geobin_size, geobin_spacing=geobin_spacing,
                                        margin=margin)
        shift = np.asarray(shift, dtype=float)
        if shift.ndim == 0:
            # A scalar (NaN when no tie points were found) is broadcast to a
            # 2-vector so the stacked array stays rectangular.
            shift = np.full(2, float(shift))
        pair_shifts.append(shift)

    # reshape keeps shape (0, 2) when there is a single image (no pairs).
    d = np.array(pair_shifts, dtype=float).reshape(-1, 2)
    # Reverse cumulative sum: cumulative[j] = d[j] + d[j+1] + ... + d[-1].
    cumulative = np.cumsum(d[::-1], axis=0)[::-1]
    # Float zeros for the last image. An int array here would make the
    # in-place float subtraction below fail with a casting error.
    shifts = np.vstack([cumulative, np.zeros((1, 2))])
    shifts -= shifts.mean(axis=0)

    return shifts