# Forked from reversaal / OhmPi
# Source project has a limited visibility.
# geometry.py 6.83 KiB
import sys

import numpy as np
from skimage.feature import SIFT, match_descriptors
import otbApplication as otb
from typing import List
from tqdm import tqdm
from datetime import *
import multiprocessing as mp

def compute_overlap_matrix(_lst: List[otb.Application], out_param='out'):
    """Compute the pairwise overlap ratios between the masks of several applications.

    The ``out_param`` image of every application is exported and converted to
    a boolean mask that is True wherever ``1 - value`` is non-zero, i.e.
    wherever the exported value differs from 1.  Entry ``[i, j]`` of the result
    is the number of pixels set in both mask ``i`` and mask ``j`` divided by
    the number of pixels set in mask ``i`` (so the matrix is generally not
    symmetric).

    Args:
        _lst: applications whose ``out_param`` output can be exported with
            ``ExportImage``.  Assumes all exported arrays share one shape
            (or at least broadcast together) -- TODO confirm with callers.
        out_param: name of the output image parameter to export.

    Returns:
        A ``(len(_lst), len(_lst))`` float array.  Rows whose mask has no set
        pixel are left at zero.
    """
    # ``np.bool`` was a deprecated alias removed in NumPy 1.24; the builtin
    # ``bool`` is the supported spelling and yields the same dtype.
    masks = [
        (1 - app.ExportImage(out_param)['array']).astype(bool)
        for app in _lst
    ]

    n_masks = len(masks)
    overlaps = np.zeros((n_masks, n_masks))
    for i, mask_i in enumerate(masks):
        valid = np.count_nonzero(mask_i)
        if valid == 0:
            # Nothing to normalise by: keep the whole row at zero.
            continue
        for j, mask_j in enumerate(masks):
            overlaps[i, j] = np.count_nonzero(mask_i & mask_j) / valid
    return overlaps

def _sift_features(img):
    """Run SIFT on ``img``; return ``(keypoints, descriptors)`` or ``(None, None)``."""
    # Use a fresh detector for every image: a shared instance keeps the
    # results of the previous image when detection fails on the next one.
    detector = SIFT()
    try:
        detector.detect_and_extract(img)
    except Exception:
        # Best effort: a patch on which SIFT fails (e.g. no feature found)
        # simply contributes no tie point.
        return None, None
    return detector.keypoints, detector.descriptors


def local_shift(src, tgt):
    """Estimate the mean offset between two image patches from SIFT matches.

    SIFT features are extracted independently from both patches and matched
    with ``match_descriptors`` (``max_ratio=0.6``, ``cross_check=True``).  The
    shift is the mean of ``target keypoint - source keypoint`` over all
    matches, expressed in the coordinate order of the SIFT keypoints.

    Args:
        src: single-band source patch.
        tgt: single-band target patch.

    Returns:
        ``(shift, n_tiepoints)``.  ``shift`` is ``None`` and ``n_tiepoints``
        is 0 when either patch yields no features or no match is found.
    """
    n_tiepoints = 0
    shift = None

    ks, ds = _sift_features(src)
    if ks is None or ds is None:
        return shift, n_tiepoints

    kt, dt = _sift_features(tgt)
    if kt is None or dt is None:
        return shift, n_tiepoints

    # Column 0 of the matches indexes the target features, column 1 the
    # source features (same argument order as the call below).
    mtch = match_descriptors(dt, ds, max_ratio=0.6, cross_check=True)
    if len(mtch) > 0:
        n_tiepoints = len(mtch)
        shift = np.mean(kt[mtch[:, 0]] - ks[mtch[:, 1]], axis=0)

    return shift, n_tiepoints

def compute_displacement(_src: otb.Application, _tgt: otb.Application,
                         src_band=2, tgt_band=2,
                         out_param_src='out', out_param_tgt='out',
                         geobin_size=32, geobin_spacing=256, margin=32, n_proc=6):
    """Estimate the mean displacement between two images on a sparse patch grid.

    Square patches of ``geobin_size`` pixels are extracted at the same
    positions from both applications.  The top-left corners start at
    ``margin`` and advance by ``geobin_size + geobin_spacing`` in both
    directions.  Each patch pair is passed to ``local_shift`` and the
    per-patch shifts are averaged.

    Args:
        _src: application providing the source image.
        _tgt: application providing the target image; its image size defines
            the grid.  The full region restored on ``_src`` at the end also
            uses this size, which assumes both images have the same
            dimensions -- TODO confirm.
        src_band: index along the third axis of the exported source array.
        tgt_band: index along the third axis of the exported target array.
        out_param_src: output image parameter name of ``_src``.
        out_param_tgt: output image parameter name of ``_tgt``.
        geobin_size: side length of a patch, in pixels.
        geobin_spacing: gap between consecutive patches, in pixels.
        margin: border left untouched on every side, in pixels.
        n_proc: number of worker processes; ``1`` runs in the calling process.

    Returns:
        ``(shift, n_tiepoints)`` where ``shift`` is the mean of the per-patch
        shifts (``None`` when no patch produced a match) and ``n_tiepoints``
        is the total number of matches over all patches.
    """
    sz = _tgt.GetImageSize(out_param_tgt)
    W, H = sz[0], sz[1]
    reg = otb.itkRegion()
    step = geobin_size + geobin_spacing

    args = []
    try:
        for h in range(margin, H - margin - geobin_size, step):
            for w in range(margin, W - margin - geobin_size, step):
                reg['index'][0], reg['index'][1] = w, h
                reg['size'][0], reg['size'][1] = geobin_size, geobin_size

                _src.PropagateRequestedRegion(out_param_src, reg)
                src = _src.ExportImage(out_param_src)['array']
                # Copy so the patch no longer depends on the exported buffer.
                src = src[:, :, src_band].copy()

                _tgt.PropagateRequestedRegion(out_param_tgt, reg)
                tgt = _tgt.ExportImage(out_param_tgt)['array']
                tgt = tgt[:, :, tgt_band].copy()

                args.append((src, tgt))
    finally:
        # Always hand the applications back with the full image requested,
        # even if an export failed half-way through the grid.
        reg['index'][0], reg['index'][1] = 0, 0
        reg['size'][0], reg['size'][1] = W, H
        _src.PropagateRequestedRegion(out_param_src, reg)
        _tgt.PropagateRequestedRegion(out_param_tgt, reg)

    if not args:
        # Image too small for a single patch: nothing to match.
        return None, 0

    if n_proc == 1:
        # Avoid the cost of spawning a pool for a single worker.
        results = [local_shift(src, tgt) for src, tgt in args]
    else:
        with mp.Pool(n_proc) as p:
            results = p.starmap(local_shift, args)

    shifts = [shift for shift, _ in results if shift is not None]
    n_tiepoints = sum(n for _, n in results)

    if shifts:
        return np.mean(np.array(shifts), axis=0), n_tiepoints
    return None, n_tiepoints

def get_descriptors(_src: otb.Application, src_band=2, out_param='out',
                    geobin_size=32, geobin_spacing=256, margin=32):
    """Collect SIFT keypoints and descriptors on a sparse grid of patches.

    Square patches of ``geobin_size`` pixels are exported from ``_src``.  The
    top-left corners start at ``margin`` and advance by
    ``geobin_size + geobin_spacing`` in both directions.  SIFT runs on band
    ``src_band`` of each patch; patches on which it fails are skipped.

    Args:
        _src: application providing the image.
        src_band: index along the third axis of the exported array.
        out_param: output image parameter name of ``_src``.
        geobin_size: side length of a patch, in pixels.
        geobin_spacing: gap between consecutive patches, in pixels.
        margin: border left untouched on every side, in pixels.

    Returns:
        ``(keypoints, descriptors)``, each the concatenation along axis 0 of
        the per-patch SIFT results.

    Raises:
        ValueError: if no patch produced any feature.
    """
    sz = _src.GetImageSize(out_param)
    W, H = sz[0], sz[1]
    sft = SIFT()
    reg = otb.itkRegion()
    step = geobin_size + geobin_spacing

    keypoints = []
    descriptors = []

    try:
        for h in range(margin, H - margin - geobin_size, step):
            for w in range(margin, W - margin - geobin_size, step):
                reg['index'][0], reg['index'][1] = w, h
                reg['size'][0], reg['size'][1] = geobin_size, geobin_size

                _src.PropagateRequestedRegion(out_param, reg)
                src = _src.ExportImage(out_param)['array']

                try:
                    sft.detect_and_extract(src[:, :, src_band])
                except Exception:
                    # Best effort: a patch without usable features is skipped.
                    continue

                # NOTE(review): keypoints are stored exactly as SIFT returns
                # them for the patch; the patch origin (h, w) is not added.
                # Presumably they are patch-local -- confirm callers expect it.
                keypoints.append(sft.keypoints)
                descriptors.append(sft.descriptors)
    finally:
        # Always hand the application back with the full image requested.
        reg['index'][0], reg['index'][1] = 0, 0
        reg['size'][0], reg['size'][1] = W, H
        _src.PropagateRequestedRegion(out_param, reg)

    if not keypoints:
        # np.concatenate([]) would raise the same exception type with an
        # unhelpful message; state the actual cause instead.
        raise ValueError("no SIFT features found in any patch")

    keypoints = np.concatenate(keypoints, axis=0)
    descriptors = np.concatenate(descriptors, axis=0)

    return keypoints, descriptors

def get_displacements_sequentially(_lst: List[otb.Application],
                                   band=2, out_param='out',
                                   geobin_size=32, geobin_spacing=256, margin=32):
    """Chain pairwise displacements between consecutive images.

    ``compute_displacement`` is run on every consecutive pair
    ``(_lst[i], _lst[i + 1])``.  The shift assigned to image ``k`` is the sum
    of the pairwise displacements from pair ``k`` to the last pair, the last
    image getting a zero shift.  The result is finally centred by subtracting
    the mean shift.

    Args:
        _lst: applications, in the order in which they should be chained.
        band: band index used for both images of every pair.
        out_param: output image parameter name used for every application.
        geobin_size: forwarded to ``compute_displacement``.
        geobin_spacing: forwarded to ``compute_displacement``.
        margin: forwarded to ``compute_displacement``.

    Returns:
        A float array with one zero-mean shift row per image.
    """
    import warnings

    shifts = []
    for i in tqdm(range(len(_lst) - 1)):
        step, _ = compute_displacement(_lst[i], _lst[i + 1],
                                       src_band=band, tgt_band=band,
                                       out_param_src=out_param, out_param_tgt=out_param,
                                       geobin_size=geobin_size,
                                       geobin_spacing=geobin_spacing,
                                       margin=margin)
        if step is None:
            # No tie point for this pair: adding None used to crash the
            # accumulation.  Treat the pair as not displaced and say so.
            warnings.warn(
                f"no tie point found between images {i} and {i + 1}; "
                "assuming zero displacement",
                RuntimeWarning,
            )
            step = np.zeros(2)
        shifts = [s + step for s in shifts]
        shifts.append(step)

    # Float zeros: an integer row makes the in-place mean subtraction below
    # fail when it is the only row (single-image list).
    shifts.append(np.zeros(2))
    shifts = np.array(shifts, dtype=float)
    shifts -= np.mean(shifts, axis=0)

    return shifts

def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application],
                              ref_idx, band=2, out_param='out',
                              geobin_size=32, geobin_spacing=256, margin=32):
    """Compute the displacement of every image relative to one reference image.

    ``compute_displacement`` is run with each image as source and
    ``_lst[ref_idx]`` as target; the reference itself gets a zero shift.  The
    result is finally centred by subtracting the mean shift.

    Args:
        _lst: applications providing the images.
        _msk: currently unused; kept for interface compatibility.
        ref_idx: index in ``_lst`` of the reference image.  Must be a valid
            integer index.
        band: band index used for both images of every pair.
        out_param: output image parameter name used for every application.
        geobin_size: forwarded to ``compute_displacement``.
        geobin_spacing: forwarded to ``compute_displacement``.
        margin: forwarded to ``compute_displacement``.

    Returns:
        A float array with one zero-mean shift row per image.
    """
    import warnings

    idx = ref_idx

    shifts = []
    for i in tqdm(range(len(_lst))):
        if i == idx:
            # Float zeros keep the array float even when the reference is the
            # only image, so the in-place mean subtraction below is valid.
            shifts.append(np.zeros(2))
            continue

        shift, _ = compute_displacement(_lst[i], _lst[idx],
                                        src_band=band, tgt_band=band,
                                        out_param_src=out_param, out_param_tgt=out_param,
                                        geobin_size=geobin_size,
                                        geobin_spacing=geobin_spacing,
                                        margin=margin)
        if shift is None:
            # No tie point with the reference: a None entry used to break the
            # array arithmetic below.  Treat the image as not displaced.
            warnings.warn(
                f"no tie point found between image {i} and reference {idx}; "
                "assuming zero displacement",
                RuntimeWarning,
            )
            shift = np.zeros(2)
        shifts.append(shift)

    shifts = np.array(shifts, dtype=float)
    shifts -= np.mean(shifts, axis=0)

    return shifts

def get_clearest_central_image(_lst: List[otb.Application], dates, threshold=0.8, out_param='out'):
    """Pick the sufficiently clear image closest to the middle of the time span.

    For every application the ``out_param`` image is read and its coverage is
    computed as the fraction of pixels whose value differs from 1
    (presumably 1 flags an invalid pixel -- confirm against the mask
    producer).  Starting from the image whose date is nearest to the midpoint
    between the first and last date, indices are visited in the order
    ``c, c+1, c-1, c+2, c-2, ...`` and the first one whose coverage exceeds
    ``threshold`` is returned.

    Args:
        _lst: applications providing the mask images; ``FreeRessources`` is
            called on each one after it has been read.
        dates: ``'%Y%m%d'`` strings, one per image.  The midpoint is taken
            between the first and last entries, so chronological order is
            assumed.
        threshold: minimum coverage (exclusive) for an image to be accepted.
        out_param: output image parameter name to read.

    Returns:
        The index of the selected image, or ``None`` if no image qualifies or
        the inputs are empty.
    """
    coverages = []
    for app in _lst:
        img = app.GetImageAsNumpyArray(out_param)
        # ``np.bool`` was removed in NumPy 1.24; the builtin is equivalent.
        msk = (1 - img).astype(bool)
        coverages.append(np.sum(msk) / (msk.shape[0] * msk.shape[1]))
        app.FreeRessources()

    if not coverages or not dates:
        return None

    dts = [datetime.strptime(x, '%Y%m%d') for x in dates]
    ctr = dts[0] + (dts[-1] - dts[0]) / 2
    center = min(range(len(dts)), key=lambda i: abs(ctr - dts[i]))

    n_images = len(coverages)
    # Walk outwards from the centre, later date first.  Out-of-range indices
    # are skipped instead of ending the search, so index 0 and the far end of
    # an off-centre list are examined too.
    for offset in range(max(n_images, len(dts))):
        candidates = (center,) if offset == 0 else (center + offset, center - offset)
        for idx in candidates:
            if 0 <= idx < n_images and coverages[idx] > threshold:
                return idx
    return None