Commit eeecfa4e authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: Version of rigid coregistration using patch sampling and masks.

No related merge requests found
Showing with 92 additions and 14 deletions
+92 -14
import sys import sys
import numpy as np import numpy as np
from math import ceil, floor from math import ceil, floor, sqrt
from scipy.ndimage.morphology import distance_transform_cdt from scipy.ndimage.morphology import distance_transform_cdt
from skimage.feature import SIFT, match_descriptors from skimage.feature import SIFT, match_descriptors
from skimage.measure import find_contours from skimage.measure import find_contours
...@@ -100,6 +100,65 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application, ...@@ -100,6 +100,65 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application,
else: else:
return None, n_tiepoints return None, n_tiepoints
def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application,
_src_msk: otb.Application, _tgt_msk: otb.Application,
src_band=2, tgt_band=2, num_geobins=16,
out_param_src='out', out_param_tgt='out',
out_param_src_msk='out', out_param_tgt_msk='out',
geobin_radius=16, margin=32, n_proc=6):
geobin_size = 2 * geobin_radius + 1
mask = 1 - _src_msk.GetImageAsNumpyArray(out_param_src_msk).astype(bool)
mask = np.all((mask, 1 - _tgt_msk.GetImageAsNumpyArray(out_param_tgt_msk).astype(bool)), axis=0).astype(int)
H, W = mask.shape
cnt, cov, cov_ext = get_patch_centers(mask, geobin_size, num_geobins, margin=margin)
_src_msk.FreeRessources()
_tgt_msk.FreeRessources()
reg = otb.itkRegion()
args = []
for c in cnt:
reg['index'][0], reg['index'][1] = int(c[1]-geobin_radius), int(c[0]-geobin_radius)
reg['size'][0], reg['size'][1] = int(geobin_size), int(geobin_size)
_src.PropagateRequestedRegion(out_param_src, reg)
_src_img = _src.ExportImage(out_param_src)
src = _src_img['array']
src = src[:,:,src_band].copy()
_tgt.PropagateRequestedRegion(out_param_tgt, reg)
_tgt_img = _tgt.ExportImage(out_param_tgt)
tgt = _tgt_img['array']
tgt = tgt[:, :, tgt_band].copy()
args.append((src,tgt))
n_tiepoints = 0
shifts = []
with mp.Pool(n_proc) as p:
for res in p.starmap(local_shift, args):
if res[0] is not None:
shifts.append(res[0])
n_tiepoints += res[1]
"""
for xx in args:
yy = local_shift(xx[0],xx[1])
if yy[0] is not None:
shifts.append(yy[0])
n_tiepoints += yy[1]
"""
reg['index'][0], reg['index'][1] = 0, 0
reg['size'][0], reg['size'][1] = W, H
_src.PropagateRequestedRegion(out_param_src, reg)
_tgt.PropagateRequestedRegion(out_param_tgt, reg)
if n_tiepoints > 0:
return np.mean(np.array(shifts), axis=0), n_tiepoints
else:
return None, n_tiepoints
def get_descriptors(_src: otb.Application, src_band=2, out_param='out', def get_descriptors(_src: otb.Application, src_band=2, out_param='out',
geobin_size=32, geobin_spacing=256, margin=32): geobin_size=32, geobin_spacing=256, margin=32):
...@@ -156,7 +215,8 @@ def get_displacements_sequentially(_lst: List[otb.Application], ...@@ -156,7 +215,8 @@ def get_displacements_sequentially(_lst: List[otb.Application],
def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application], def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application],
ref_idx, band=2, out_param='out', ref_idx, band=2, out_param='out',
geobin_size=32, geobin_spacing=256, margin=32): geobin_radius=16, num_geobins=16, margin=32,
use_masks=True):
''' '''
ov = compute_overlap_matrix(_msk, out_param) ov = compute_overlap_matrix(_msk, out_param)
...@@ -171,16 +231,30 @@ def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Applic ...@@ -171,16 +231,30 @@ def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Applic
break break
j -= 1 j -= 1
''' '''
sz = _lst[0].GetImageSize(out_param)
sz = min(sz[0], sz[1])
idx = ref_idx idx = ref_idx
geobin_size = 2*geobin_radius+1
geobin_spacing = floor((sz-geobin_size*ceil(sqrt(num_geobins))-2*margin) / 4)
print('geobin size : {} geobin spc : {}'.format(geobin_size, geobin_spacing))
shifts = [] shifts = []
for i in tqdm(range(0,len(_lst))): for i in tqdm(range(0,len(_lst))):
if i != idx: if i != idx:
sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band, if use_masks:
out_param_src=out_param, out_param_tgt=out_param, sh = compute_displacement_with_masks(_lst[i], _lst[idx], _msk[i], _msk[idx],
geobin_size=geobin_size, geobin_spacing=geobin_spacing, src_band=band, tgt_band=band,
margin=margin) out_param_src=out_param, out_param_tgt=out_param,
shifts.append(sh[0]) out_param_src_msk=out_param, out_param_tgt_msk=out_param,
geobin_radius=geobin_radius, num_geobins=num_geobins,
margin=margin)
else:
sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band,
out_param_src=out_param, out_param_tgt=out_param,
geobin_size=geobin_size, geobin_spacing=geobin_spacing,
margin=margin)
if sh[0] is not None:
shifts.append(sh[0])
else: else:
shifts.append(np.array([0, 0])) shifts.append(np.array([0, 0]))
...@@ -254,7 +328,11 @@ def distance_constrained_2d_sampling(n, shape, min_dist): ...@@ -254,7 +328,11 @@ def distance_constrained_2d_sampling(n, shape, min_dist):
return np.round(coords).astype(np.int) return np.round(coords).astype(np.int)
def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_extent=0): def get_patch_centers(msk, patch_size, n_patches, min_coverage=0, min_cov_extent=0, margin=0):
if margin > 0:
mask = msk[margin:-margin, margin:-margin]
else:
mask = msk
npx_img = mask.shape[0] * mask.shape[1] npx_img = mask.shape[0] * mask.shape[1]
if np.sum(mask==0) > 0: if np.sum(mask==0) > 0:
...@@ -278,9 +356,6 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten ...@@ -278,9 +356,6 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten
npx_msk = npx_img npx_msk = npx_img
coverage, cov_extent = 1, 1 coverage, cov_extent = 1, 1
# approximate compensation for "filtered" centers --> OLD
# n_patches = floor(1.1 * n_patches)
if coverage > min_coverage and cov_extent > min_cov_extent: if coverage > min_coverage and cov_extent > min_cov_extent:
n_req_px = n_patches * (npx_img / npx_msk) n_req_px = n_patches * (npx_img / npx_msk)
coords = distance_constrained_2d_sampling(ceil(n_req_px), mask.shape, patch_size) coords = distance_constrained_2d_sampling(ceil(n_req_px), mask.shape, patch_size)
...@@ -292,6 +367,9 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten ...@@ -292,6 +367,9 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten
if q <= c[0] <= mask.shape[0] - q and q <= c[1] <= mask.shape[1] - q and in_arr[c[0], c[1]] == 1: if q <= c[0] <= mask.shape[0] - q and q <= c[1] <= mask.shape[1] - q and in_arr[c[0], c[1]] == 1:
out_coords.append(c) out_coords.append(c)
if margin>0:
out_coords = [[c[0]+margin, c[1]+margin] for c in out_coords]
return np.asarray(out_coords), coverage, cov_extent return np.asarray(out_coords), coverage, cov_extent
else: else:
return None, coverage, cov_extent return None, coverage, cov_extent
...@@ -460,7 +460,7 @@ class S2TheiaTilePipeline: ...@@ -460,7 +460,7 @@ class S2TheiaTilePipeline:
out_fld = self.temp_fld out_fld = self.temp_fld
self.write_outputs(out_fld, update_pipe=True, flag_nodata=True) self.write_outputs(out_fld, update_pipe=True, flag_nodata=True)
def rigid_align(self, cov_th=0.8, ref_date=None, match_band=2, geobin_size=32, geobin_spacing=256, margin=32, out_param='out'): def rigid_align(self, cov_th=0.8, ref_date=None, match_band=2, geobin_radius=16, num_geobins=16, margin=32, out_param='out'):
proc_idx = self.out_idx.copy() proc_idx = self.out_idx.copy()
self.out_idx = [] self.out_idx = []
img = [self.pipe[t] for t in range(self.pipe_start+match_band, (10+len(self.PTRN_msk))*len(self.input_dates), img = [self.pipe[t] for t in range(self.pipe_start+match_band, (10+len(self.PTRN_msk))*len(self.input_dates),
...@@ -474,8 +474,8 @@ class S2TheiaTilePipeline: ...@@ -474,8 +474,8 @@ class S2TheiaTilePipeline:
ctr = datetime.strptime(ref_date, '%Y%m%d') ctr = datetime.strptime(ref_date, '%Y%m%d')
ref_idx = dts.index(min(dts, key=lambda x: abs(ctr - x))) ref_idx = dts.index(min(dts, key=lambda x: abs(ctr - x)))
shifts = get_displacements_bestref(img,msk,ref_idx,0,out_param, shifts = get_displacements_bestref(img, msk, ref_idx, 0, out_param,
geobin_size, geobin_spacing, margin) geobin_radius, num_geobins, margin)
k = 0 k = 0
for i in proc_idx: for i in proc_idx:
if ((i - proc_idx[0]) % 2) == 0: if ((i - proc_idx[0]) % 2) == 0:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment