import numpy as np
from skimage.feature import SIFT, match_descriptors

def _extract_tile_features(extractor, tile):
    """Run ``extractor`` on ``tile``; return (keypoints, descriptors) or None.

    Any ordinary exception raised by the extractor (e.g. when no feature can
    be found in the tile) is treated as "no features here" so the caller can
    simply skip the tile. KeyboardInterrupt / SystemExit are NOT swallowed.
    """
    try:
        extractor.detect_and_extract(tile)
    except Exception:
        return None
    return extractor.keypoints, extractor.descriptors


def compute_displacement(src, tgt, geobin_size=32, geobin_spacing=256, margin=32):
    """Estimate the mean keypoint displacement between two images.

    Square tiles of side ``geobin_size`` are sampled on a regular grid whose
    step is ``geobin_size + geobin_spacing`` pixels, starting ``margin``
    pixels from the top/left border and stopping before the bottom/right
    margin. In each tile, SIFT keypoints are extracted from both images at the
    same window and matched; the mean ``src - tgt`` keypoint offset of the
    tile is recorded. The final estimate is the unweighted mean of these
    per-tile offsets.

    Parameters
    ----------
    src, tgt : array-like, indexed as [row, column]
        Images to compare. The same pixel windows are read from both, so they
        are expected to share the same grid; the sampling grid is derived
        from ``src.shape``.
    geobin_size : int
        Side length of each square tile, in pixels.
    geobin_spacing : int
        Gap, in pixels, between the end of one tile and the start of the next.
    margin : int
        Border width, in pixels, excluded from sampling.

    Returns
    -------
    shift : numpy.ndarray of shape (2,)
        Mean offset (source keypoint minus target keypoint), in the
        coordinate order used by the SIFT keypoints (row, column in
        scikit-image). A zero vector is returned when no tie point was found.
    n_tiepoints : int
        Total number of matched keypoint pairs over all tiles; 0 means the
        returned shift is the fallback value, not a measurement.
    """
    # Axis 0 is the row axis and axis 1 the column axis: each loop bound must
    # come from the axis that the loop variable actually indexes.
    n_rows, n_cols = src.shape[0], src.shape[1]
    step = geobin_size + geobin_spacing
    sft = SIFT()
    shifts = []
    n_tiepoints = 0

    for row in range(margin, n_rows - margin - geobin_size, step):
        for col in range(margin, n_cols - margin - geobin_size, step):
            window = (slice(row, row + geobin_size),
                      slice(col, col + geobin_size))

            # The extractor is reused, so the source results must be read
            # before the target tile is processed.
            src_features = _extract_tile_features(sft, src[window])
            if src_features is None:
                continue
            tgt_features = _extract_tile_features(sft, tgt[window])
            if tgt_features is None:
                continue

            ks, ds = src_features
            kt, dt = tgt_features
            if len(ks) == 0 or len(kt) == 0:
                continue

            mtch = match_descriptors(ds, dt, max_ratio=0.6, cross_check=True)
            if len(mtch) == 0:
                continue

            n_tiepoints += len(mtch)
            shifts.append(np.mean(ks[mtch[:, 0]] - kt[mtch[:, 1]], axis=0))

    if not shifts:
        # Averaging an empty list would yield a scalar NaN (plus a warning)
        # that callers cannot index as shift[0] / shift[1].
        return np.zeros(2), n_tiepoints

    return np.mean(np.array(shifts), axis=0), n_tiepoints

'''
FOR TESTING IN CONSOLE:

from Common.otb_numpy_proc import *
p = to_otb_pipeline('/DATA/Senegal/Koussanar/subset/SENTINEL2A_20210203-113730-491_L2A_T28PEA_C_V2-2/SENTINEL2A_20210203-113730-491_L2A_T28PEA_C_V2-2_FRE_B4.tif')
r = to_otb_pipeline('/DATA/Senegal/Koussanar/subset/SENTINEL2A_20210213-113734-979_L2A_T28PEA_C_V2-2/SENTINEL2A_20210213-113734-979_L2A_T28PEA_C_V2-2_FRE_B4.tif')
from Common.geometry import compute_displacement

src = p.ExportImage('out')
tgt = r.ExportImage('out')
sh, n = compute_displacement(src['array'][:,:,0], tgt['array'][:,:,0])
tgt['origin'][0] += sh[1]*src['spacing'][0]
tgt['origin'][1] += sh[0]*src['spacing'][1]
si = otb.Registry.CreateApplication('Superimpose')
si.SetParameterInputImage('inr', p.GetParameterOutputImage('out'))
si.ImportImage('inm', tgt)
# TODO: decide whether to set an interpolator explicitly (none is set here, so the application default is used)
si.SetParameterString('out', '/DATA/Senegal/Koussanar/ttt.tif')
si.SetParameterOutputImagePixelType('out', otb.ImagePixelType_int16)
si.ExecuteAndWriteOutput()
'''