# geometry.py
import multiprocessing as mp
from datetime import datetime
from math import ceil, floor, sqrt
from typing import List

import numpy as np
import otbApplication as otb
from scipy.ndimage import distance_transform_cdt
from skimage.feature import SIFT, match_descriptors
from skimage.measure import find_contours
from tqdm import tqdm

def compute_overlap_matrix(_lst: List[otb.Application], out_param='out'):
    """Return the matrix of pairwise overlap ratios between the valid-pixel masks
    of a list of mask applications (value 0 = valid pixel)."""
    masks = []
    for l in _lst:
        img = l.ExportImage(out_param)
        masks.append((1 - img['array']).astype(bool))
    overlaps = np.zeros((len(masks), len(masks)))
    for i in range(len(masks)):
        valid = np.sum(masks[i])
        if valid > 0:
            for j in range(len(masks)):
                # Fraction of the valid pixels of mask i that are also valid in mask j
                overlaps[i, j] = np.sum(masks[i] * masks[j]) / valid
    return overlaps
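
# Illustrative sketch, not part of the original pipeline: the same row-normalised
# overlap ratio on plain boolean masks, without OTB applications. The helper name
# `_demo_overlap_ratio` and the toy masks are hypothetical.
def _demo_overlap_ratio():
    m0 = np.zeros((4, 4), dtype=bool)
    m1 = np.zeros((4, 4), dtype=bool)
    m0[:, :2] = True   # left half valid
    m1[:2, :] = True   # top half valid
    # Fraction of m0's valid pixels that are also valid in m1 (expected 0.5)
    return np.sum(m0 * m1) / np.sum(m0)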

def local_shift(src, tgt):
    """Detect SIFT keypoints on both patches and return the list of tie-point
    displacements (target keypoint minus matched source keypoint)."""
    shifts = []
    sft = SIFT()

    ks = ds = kt = dt = None
    try:
        sft.detect_and_extract(src)
        ks, ds = sft.keypoints, sft.descriptors
    except Exception:
        # SIFT raises when it finds no usable feature in the patch
        pass
    try:
        sft.detect_and_extract(tgt)
        kt, dt = sft.keypoints, sft.descriptors
    except Exception:
        pass
    if ks is not None and kt is not None:
        mtch = match_descriptors(dt, ds, max_ratio=0.8, cross_check=True)
        if len(mtch) > 0:
            shifts.append(kt[mtch[:, 0]] - ks[mtch[:, 1]])

    return shifts
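
# Illustrative sketch, assuming skimage's sample data is installed: shift a textured
# patch by a known offset and check that the SIFT tie points recover it. The helper
# name `_demo_local_shift` is hypothetical.
def _demo_local_shift(dy=3, dx=5):
    from skimage.data import camera
    patch = camera()[100:356, 100:356].astype(np.float32)
    shifted = np.roll(patch, (dy, dx), axis=(0, 1))
    shifts = local_shift(patch, shifted)
    if shifts:
        # Median tie-point displacement, expected to be close to (dy, dx)
        return np.median(np.concatenate(shifts), axis=0)
    return None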

def compute_displacement(_src: otb.Application, _tgt: otb.Application,
                         src_band=2, tgt_band=2,
                         out_param_src='out', out_param_tgt='out',
                         geobin_size=32, geobin_spacing=256, margin=32,
                         filter=5, n_proc=6):
    """Estimate the mean (row, col) shift between two images from SIFT tie points
    computed on a regular grid of small patches (geobins)."""

    sz = _tgt.GetImageSize(out_param_tgt)
    W,H = sz[0], sz[1]
    reg = otb.itkRegion()

    args = []

    for h in range(margin, H - margin - geobin_size, geobin_size+geobin_spacing):
        for w in range(margin, W - margin - geobin_size, geobin_size + geobin_spacing):
            reg['index'][0], reg['index'][1] = w, h
            reg['size'][0], reg['size'][1] = geobin_size, geobin_size

            _src.PropagateRequestedRegion(out_param_src, reg)
            _src_img = _src.ExportImage(out_param_src)
            src = _src_img['array']
            src = src[:,:,src_band].copy()

            _tgt.PropagateRequestedRegion(out_param_tgt, reg)
            _tgt_img = _tgt.ExportImage(out_param_tgt)
            tgt = _tgt_img['array']
            tgt = tgt[:, :, tgt_band].copy()

            args.append((src,tgt))

    shifts = []
    with mp.Pool(n_proc) as p:
        for res in p.starmap(local_shift, args):
            shifts.extend(res)
    if len(shifts) > 0:
        shifts = np.concatenate(shifts)

    """
    for xx in args:
        yy = local_shift(xx[0],xx[1])
        if yy[0] is not None:
            shifts.append(yy[0])
        n_tiepoints += yy[1]
    """

    reg['index'][0], reg['index'][1] = 0, 0
    reg['size'][0], reg['size'][1] = W, H
    _src.PropagateRequestedRegion(out_param_src, reg)
    _tgt.PropagateRequestedRegion(out_param_tgt, reg)

    if len(shifts) > 0 and filter > 0:
        nrm = np.linalg.norm(shifts, axis=1)
        shifts = shifts[nrm < filter]

    if len(shifts) > 0:
        return np.mean(shifts, axis=0), len(shifts)
    else:
        return None, 0
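
# Illustrative sketch, hypothetical helper: the outlier rejection above simply drops
# tie-point shifts whose norm exceeds the `filter` threshold (in pixels) before averaging.
def _demo_shift_filtering(max_norm=5):
    shifts = np.array([[1.0, 2.0], [1.2, 1.8], [40.0, -3.0]])  # last vector is an outlier
    kept = shifts[np.linalg.norm(shifts, axis=1) < max_norm]
    return np.mean(kept, axis=0), len(kept)  # approximately ([1.1, 1.9], 2)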

def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application,
                                    _src_msk: otb.Application, _tgt_msk: otb.Application,
                                    src_band=2, tgt_band=2, num_geobins=16,
                                    out_param_src='out', out_param_tgt='out',
                                    out_param_src_msk='out', out_param_tgt_msk='out',
                                    geobin_radius=16, margin=32, filter=5, n_proc=6):
    """Estimate the mean (row, col) shift between two images from SIFT tie points
    computed on patches sampled inside the intersection of their validity masks."""

    geobin_size = 2 * geobin_radius + 1
    mask = 1 - _src_msk.GetImageAsNumpyArray(out_param_src_msk)
    ref_mask = 1 - _tgt_msk.GetImageAsNumpyArray(out_param_tgt_msk)
    mask = np.all((mask, ref_mask), axis=0).astype(int)
    H, W = mask.shape
    ratio = np.sum(mask)/(H*W) #scale number of bins based on coverage ratio

    if ratio == 0:
        return None, 0

    cnt, cov, cov_ext = get_patch_centers(mask, geobin_size, max(1,int(ratio*num_geobins)), margin=margin)
    _src_msk.FreeRessources()
    _tgt_msk.FreeRessources()

    reg = otb.itkRegion()
    args = []

    if cnt is None:
        return None, 0

    for c in cnt:
        reg['index'][0], reg['index'][1] = int(c[1]-geobin_radius), int(c[0]-geobin_radius)
        reg['size'][0], reg['size'][1] = int(geobin_size), int(geobin_size)

        _src.PropagateRequestedRegion(out_param_src, reg)
        _src_img = _src.ExportImage(out_param_src)
        src = _src_img['array']
        src = src[:,:,src_band].copy()

        _tgt.PropagateRequestedRegion(out_param_tgt, reg)
        _tgt_img = _tgt.ExportImage(out_param_tgt)
        tgt = _tgt_img['array']
        tgt = tgt[:, :, tgt_band].copy()

        args.append((src,tgt))

    shifts = []
    with mp.Pool(n_proc) as p:
        for res in p.starmap(local_shift, args):
            shifts.extend(res)
    if len(shifts) > 0:
        shifts = np.concatenate(shifts)

    """
    for xx in args:
        yy = local_shift(xx[0],xx[1])
        if yy[0] is not None:
            shifts.append(yy[0])
        n_tiepoints += yy[1]
    """

    reg['index'][0], reg['index'][1] = 0, 0
    reg['size'][0], reg['size'][1] = W, H
    _src.PropagateRequestedRegion(out_param_src, reg)
    _tgt.PropagateRequestedRegion(out_param_tgt, reg)

    if len(shifts) > 0 and filter > 0:
        nrm = np.linalg.norm(shifts, axis=1)
        shifts = shifts[nrm < filter]

    if len(shifts) > 0:
        return np.mean(shifts, axis=0), len(shifts)
    else:
        return None, 0

def get_descriptors(_src: otb.Application, src_band=2, out_param='out',
                    geobin_size=32, geobin_spacing=256, margin=32):
    """Collect SIFT keypoints and descriptors over a regular grid of patches of one
    band of the image."""

    sz = _src.GetImageSize(out_param)
    W, H = sz[0], sz[1]
    sft = SIFT()
    reg = otb.itkRegion()

    keypoints = []
    descriptors = []

    for h in range(margin, H - margin - geobin_size, geobin_size+geobin_spacing):
        for w in range(margin, W - margin - geobin_size, geobin_size + geobin_spacing):
            reg['index'][0], reg['index'][1] = w, h
            reg['size'][0], reg['size'][1] = geobin_size, geobin_size

            _src.PropagateRequestedRegion(out_param, reg)
            _src_img = _src.ExportImage(out_param)
            src = _src_img['array']

            try:
                sft.detect_and_extract(src[:, :, src_band])
            except Exception:
                # No usable keypoints in this patch, skip it
                continue

            keypoints.append(sft.keypoints)
            descriptors.append(sft.descriptors)

    reg['index'][0], reg['index'][1] = 0, 0
    reg['size'][0], reg['size'][1] = W, H
    _src.PropagateRequestedRegion(out_param, reg)
    keypoints = np.concatenate(keypoints, axis=0)
    descriptors = np.concatenate(descriptors, axis=0)

    return keypoints, descriptors

def get_displacements_sequentially(_lst: List[otb.Application],
                                   band=2, out_param='out',
                                   geobin_size=32, geobin_spacing=256, margin=32,
                                   filter=5):
    """Chain pairwise displacements so that every image is expressed relative to the
    last one, then centre the shifts on their mean."""
    shifts = []
    for i in tqdm(range(len(_lst) - 1)):
        sh = compute_displacement(_lst[i], _lst[i + 1], src_band=band, tgt_band=band,
                                  out_param_src=out_param, out_param_tgt=out_param,
                                  geobin_size=geobin_size, geobin_spacing=geobin_spacing,
                                  margin=margin, filter=filter)
        # If no tie points were found for this pair, assume a null displacement
        delta = sh[0] if sh[0] is not None else np.array([0.0, 0.0])
        shifts = [s + delta for s in shifts]
        shifts.append(delta)
    shifts.append(np.array([0.0, 0.0]))
    shifts = np.array(shifts)
    shifts -= np.mean(shifts, axis=0)

    return shifts
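
# Illustrative sketch, hypothetical helper: chaining pairwise shifts expresses every
# image relative to the last one; centring then removes the global mean offset.
def _demo_chained_shifts():
    pairwise = [np.array([1.0, 0.0]), np.array([0.0, 2.0])]  # image 0->1, image 1->2
    shifts = []
    for delta in pairwise:
        shifts = [s + delta for s in shifts]
        shifts.append(delta)
    shifts.append(np.array([0.0, 0.0]))
    shifts = np.array(shifts)  # [[1, 2], [0, 2], [0, 0]]
    return shifts - shifts.mean(axis=0)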

def get_displacements_to_ref(_lst: List[otb.Application], _msk: List[otb.Application],
                             ref_mode='index', ref_idx=0, ref_ext=None, ref_ext_msk=None, band=2, ext_band=0,
                             out_param='out', geobin_radius=16, num_geobins=16, margin=32, filter=5,
                             use_masks=True, min_points=10):
    """Compute the displacement of every image of the series with respect to a reference,
    which is either one of the series ('index' mode) or an external image ('external' mode)."""

    sz = _lst[0].GetImageSize(out_param)
    sz = min(sz[0], sz[1])
    idx = None
    if ref_mode == "index":
        idx = ref_idx
        ref_img = _lst[idx]
        ref_msk = _msk[idx]
        ext_band = band
    elif ref_mode == "external":
        ref_img = ref_ext
        ref_msk = ref_ext_msk
    else:
        raise ValueError("ref_mode must be either 'index' or 'external'")

    geobin_size = 2*geobin_radius+1
    geobin_spacing = floor((sz-geobin_size*ceil(sqrt(num_geobins))-2*margin) / 4)

    shifts = []
    for i in tqdm(range(0,len(_lst)), desc="Computing displacements"):
        if i != idx:
            if use_masks:
                sh = compute_displacement_with_masks(_lst[i], ref_img, _msk[i], ref_msk,
                                                     src_band=band, tgt_band=ext_band,
                                                     out_param_src=out_param, out_param_tgt=out_param,
                                                     out_param_src_msk=out_param, out_param_tgt_msk=out_param,
                                                     geobin_radius=geobin_radius, num_geobins=num_geobins,
                                                     margin=margin, filter=filter)
            else:
                sh = compute_displacement(_lst[i], ref_img, src_band=band, tgt_band=ext_band,
                                          out_param_src=out_param, out_param_tgt=out_param,
                                          geobin_size=geobin_size, geobin_spacing=geobin_spacing,
                                          margin=margin, filter=filter)
            if sh[1] >= min_points:
                shifts.append(sh[0])
            else:
                shifts.append(np.array([0.0, 0.0]))
        else:
            shifts.append(np.array([0.0, 0.0]))

    shifts = np.array(shifts)
    shifts -= np.mean(shifts, axis=0)

    return shifts

def get_clearest_central_image(_lst: List[otb.Application], dates, threshold=0.8, out_param='out'):
    """Return the index of the image closest to the centre of the time series whose
    clear (unmasked) coverage exceeds `threshold`, searching outwards from the centre."""
    coverages = []
    for l in _lst:
        img = l.GetImageAsNumpyArray(out_param)
        msk = (1 - img).astype(bool)
        coverages.append(np.sum(msk) / (msk.shape[0] * msk.shape[1]))
        l.FreeRessources()

    dts = [datetime.strptime(x, '%Y%m%d') for x in dates]
    ctr = dts[0] + (dts[-1] - dts[0]) / 2
    idx = dts.index(min(dts, key=lambda x: abs(ctr - x)))
    k = 0
    found = False
    while 0 <= idx < len(coverages):
        if coverages[idx] > threshold:
            found = True
            break
        # Alternate around the central index: +1, -1, +2, -2, ...
        k += 1
        idx += (2 * (k % 2) - 1) * k
    if found:
        return idx
    else:
        return None

def distance_constrained_2d_sampling(n, shape, min_dist):
    """Sample roughly `n` (row, col) points inside a 2D `shape`, approximately `min_dist`
    apart, by jittering a regular grid. Adapted from an answer by Samir:
    https://stackoverflow.com/users/5231231/samir"""

    # specify params
    shape = np.roll(np.array(shape), 1)
    d = floor(min_dist / 2)

    # compute grid shape based on number of points
    width_ratio = shape[1] / shape[0]
    num_y = np.int32(np.sqrt(n / width_ratio)) + 1
    num_x = np.int32(n / num_y) + 1

    # create regularly spaced points
    x = np.linspace(d, shape[1] - d, num_x, dtype=np.float32)
    y = np.linspace(d, shape[0] - d, num_y, dtype=np.float32)
    coords = np.stack(np.meshgrid(x, y), -1).reshape(-1, 2)

    # compute spacing
    init_dist = np.min((x[1] - x[0], y[1] - y[0]))

    if init_dist <= min_dist:
        # Grid too small for the requested number of patches: return the regular
        # grid without the distance constraint.
        return np.round(coords).astype(int)

    # perturb points
    max_movement = floor((init_dist - min_dist) / 2)
    noise = np.random.uniform(
        low=-max_movement,
        high=max_movement,
        size=(len(coords), 2))
    coords += noise

    # Push points close to border (< d) to distance d for patching purposes
    coords[coords < d] = d
    coords[coords[:, 0] > shape[1] - d, 0] = shape[1] - d
    coords[coords[:, 1] > shape[0] - d, 1] = shape[0] - d

    return np.round(coords).astype(int)
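
# Illustrative sketch, hypothetical helper: sample a few points and check that the
# smallest pairwise distance stays close to `min_dist`.
def _demo_distance_constrained_sampling(n=20, shape=(512, 768), min_dist=32):
    from scipy.spatial.distance import pdist
    pts = distance_constrained_2d_sampling(n, shape, min_dist)  # (row, col) coordinates
    return pts, float(np.min(pdist(pts)))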

def get_patch_centers(msk, patch_size, n_patches, min_coverage=0, min_cov_extent=0, margin=0):
    """Pick roughly `n_patches` patch centres inside the valid (non-zero) area of `msk`,
    far enough from the mask border for full `patch_size` patches and outside `margin`.
    Returns (centres, coverage, coverage_extent)."""
    if margin > 0:
        mask = msk[margin:-margin, margin:-margin]
    else:
        mask = msk
    npx_img = mask.shape[0] * mask.shape[1]

    if np.sum(mask==0) > 0:
        # Solution with binary erosion more precise but less efficient
        # Taxicab distance gives very close results, excludes very few close-to-mask positions
        # sel = np.ones((patch_size, patch_size))
        # in_arr = binary_erosion(arr, selem=sel)
        in_arr = (distance_transform_cdt(mask) > patch_size / 2)
        npx_msk = np.sum(in_arr == 1)
        coverage = (npx_msk / npx_img)
        cnt = find_contours(np.pad(in_arr,(1,1)),0)
        if len(cnt) > 0:
            bbox = (np.min(np.array([np.min(c,axis=0) for c in cnt]), axis=0),
                np.max(np.array([np.max(c,axis=0) for c in cnt]), axis=0))
            cov_extent = (bbox[1][0]-bbox[0][0]) * (bbox[1][1]-bbox[0][1]) / npx_img
        else:
            cov_extent = 0
    else:
        # Fully valid mask: every position is usable
        in_arr = np.ones(mask.shape)
        npx_msk = npx_img
        coverage, cov_extent = 1, 1

    if coverage > min_coverage and cov_extent > min_cov_extent:
        n_req_px = n_patches * (npx_img / npx_msk)
        coords = distance_constrained_2d_sampling(ceil(n_req_px), mask.shape, patch_size)

        out_coords = []

        q = floor(patch_size / 2)
        for c in coords:
            if q <= c[0] <= mask.shape[0] - q and q <= c[1] <= mask.shape[1] - q and in_arr[c[0], c[1]] == 1:
                out_coords.append(c)

        if margin>0:
            out_coords = [[c[0]+margin, c[1]+margin] for c in out_coords]

        return np.asarray(out_coords), coverage, cov_extent
    else:
        return None, coverage, cov_extent
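
# Illustrative sketch, hypothetical helper: pick patch centres inside a synthetic square
# validity mask; all centres should fall inside the valid area, with the margin added
# back to their coordinates.
def _demo_get_patch_centers(patch_size=33, n_patches=8):
    mask = np.zeros((512, 512), dtype=int)
    mask[96:416, 96:416] = 1  # valid pixels only in the central region
    centers, coverage, extent = get_patch_centers(mask, patch_size, n_patches, margin=32)
    return centers, coverage, extent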