diff --git a/Common/geometry.py b/Common/geometry.py
index 2cea80700b36106774d0c16053e68b0ef5aa4ae9..8ed2a4076b551e0eaf9a0b81ee937a9eaed56354 100644
--- a/Common/geometry.py
+++ b/Common/geometry.py
@@ -1,7 +1,7 @@
 import sys
 
 import numpy as np
-from math import ceil, floor
+from math import ceil, floor, sqrt
 from scipy.ndimage.morphology import distance_transform_cdt
 from skimage.feature import SIFT, match_descriptors
 from skimage.measure import find_contours
@@ -100,6 +100,65 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application,
     else:
         return None, n_tiepoints
 
+def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application,
+                                    _src_msk: otb.Application, _tgt_msk: otb.Application,
+                                    src_band=2, tgt_band=2, num_geobins=16,
+                                    out_param_src='out', out_param_tgt='out',
+                                    out_param_src_msk='out', out_param_tgt_msk='out',
+                                    geobin_radius=16, margin=32, n_proc=6):
+
+    geobin_size = 2 * geobin_radius + 1
+    mask = 1 - _src_msk.GetImageAsNumpyArray(out_param_src_msk).astype(bool)
+    mask = np.all((mask, 1 - _tgt_msk.GetImageAsNumpyArray(out_param_tgt_msk).astype(bool)), axis=0).astype(int)
+    H, W = mask.shape
+    cnt, cov, cov_ext = get_patch_centers(mask, geobin_size, num_geobins, margin=margin)
+    _src_msk.FreeRessources()
+    _tgt_msk.FreeRessources()
+
+    reg = otb.itkRegion()
+    args = []
+
+    for c in cnt:
+        reg['index'][0], reg['index'][1] = int(c[1]-geobin_radius), int(c[0]-geobin_radius)
+        reg['size'][0], reg['size'][1] = int(geobin_size), int(geobin_size)
+
+        _src.PropagateRequestedRegion(out_param_src, reg)
+        _src_img = _src.ExportImage(out_param_src)
+        src = _src_img['array']
+        src = src[:,:,src_band].copy()
+
+        _tgt.PropagateRequestedRegion(out_param_tgt, reg)
+        _tgt_img = _tgt.ExportImage(out_param_tgt)
+        tgt = _tgt_img['array']
+        tgt = tgt[:, :, tgt_band].copy()
+
+        args.append((src,tgt))
+
+    n_tiepoints = 0
+    shifts = []
+    with mp.Pool(n_proc) as p:
+        for res in p.starmap(local_shift, args):
+            if res[0] is not None:
+                shifts.append(res[0])
+            n_tiepoints += res[1]
+    """
+    for xx in args:
+        yy = local_shift(xx[0],xx[1])
+        if yy[0] is not None:
+            shifts.append(yy[0])
+        n_tiepoints += yy[1]
+    """
+
+    reg['index'][0], reg['index'][1] = 0, 0
+    reg['size'][0], reg['size'][1] = W, H
+    _src.PropagateRequestedRegion(out_param_src, reg)
+    _tgt.PropagateRequestedRegion(out_param_tgt, reg)
+
+    if n_tiepoints > 0:
+        return np.mean(np.array(shifts), axis=0), n_tiepoints
+    else:
+        return None, n_tiepoints
+
 def get_descriptors(_src: otb.Application, src_band=2, out_param='out',
                     geobin_size=32, geobin_spacing=256, margin=32):
 
@@ -156,7 +215,8 @@ def get_displacements_sequentially(_lst: List[otb.Application],
 
 def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application],
                               ref_idx, band=2, out_param='out',
-                              geobin_size=32, geobin_spacing=256, margin=32):
+                              geobin_radius=16, num_geobins=16, margin=32,
+                              use_masks=True):
 
     '''
     ov = compute_overlap_matrix(_msk, out_param)
@@ -171,16 +231,30 @@ def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Applic
                     break
                 j -= 1
     '''
+    sz = _lst[0].GetImageSize(out_param)
+    sz = min(sz[0], sz[1])
     idx = ref_idx
+    geobin_size = 2*geobin_radius+1
+    geobin_spacing = floor((sz-geobin_size*ceil(sqrt(num_geobins))-2*margin) / 4)
+    print('geobin size : {} geobin spc : {}'.format(geobin_size, geobin_spacing))
 
     shifts = []
     for i in tqdm(range(0,len(_lst))):
         if i != idx:
-            sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band,
-                                      out_param_src=out_param, out_param_tgt=out_param,
-                                      geobin_size=geobin_size, geobin_spacing=geobin_spacing,
-                                      margin=margin)
-            shifts.append(sh[0])
+            if use_masks:
+                sh = compute_displacement_with_masks(_lst[i], _lst[idx], _msk[i], _msk[idx],
+                                                     src_band=band, tgt_band=band,
+                                                     out_param_src=out_param, out_param_tgt=out_param,
+                                                     out_param_src_msk=out_param, out_param_tgt_msk=out_param,
+                                                     geobin_radius=geobin_radius, num_geobins=num_geobins,
+                                                     margin=margin)
+            else:
+                sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band,
+                                          out_param_src=out_param, out_param_tgt=out_param,
+                                          geobin_size=geobin_size, geobin_spacing=geobin_spacing,
+                                          margin=margin)
+            if sh[0] is not None:
+                shifts.append(sh[0])
         else:
             shifts.append(np.array([0, 0]))
 
@@ -254,7 +328,11 @@ def distance_constrained_2d_sampling(n, shape, min_dist):
 
     return np.round(coords).astype(np.int)
 
-def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_extent=0):
+def get_patch_centers(msk, patch_size, n_patches, min_coverage=0, min_cov_extent=0, margin=0):
+    if margin > 0:
+        mask = msk[margin:-margin, margin:-margin]
+    else:
+        mask = msk
     npx_img = mask.shape[0] * mask.shape[1]
 
     if np.sum(mask==0) > 0:
@@ -278,9 +356,6 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten
         npx_msk = npx_img
         coverage, cov_extent = 1, 1
 
-    # approximate compensation for "filtered" centers --> OLD
-    # n_patches = floor(1.1 * n_patches)
-
     if coverage > min_coverage and cov_extent > min_cov_extent:
         n_req_px = n_patches * (npx_img / npx_msk)
         coords = distance_constrained_2d_sampling(ceil(n_req_px), mask.shape, patch_size)
@@ -292,6 +367,9 @@ def get_patch_centers(mask, patch_size, n_patches, min_coverage=0, min_cov_exten
             if q <= c[0] <= mask.shape[0] - q and q <= c[1] <= mask.shape[1] - q and in_arr[c[0], c[1]] == 1:
                 out_coords.append(c)
 
+        if margin>0:
+            out_coords = [[c[0]+margin, c[1]+margin] for c in out_coords]
+
         return np.asarray(out_coords), coverage, cov_extent
     else:
         return None, coverage, cov_extent
diff --git a/TimeSeries/s2theia.py b/TimeSeries/s2theia.py
index 4a3d875c0a127336a697d8a864b59cb82e4f54f4..023d27716752ad07a085a18815757c29fecff57d 100644
--- a/TimeSeries/s2theia.py
+++ b/TimeSeries/s2theia.py
@@ -460,7 +460,7 @@ class S2TheiaTilePipeline:
             out_fld = self.temp_fld
         self.write_outputs(out_fld, update_pipe=True, flag_nodata=True)
 
-    def rigid_align(self, cov_th=0.8, ref_date=None, match_band=2, geobin_size=32, geobin_spacing=256, margin=32, out_param='out'):
+    def rigid_align(self, cov_th=0.8, ref_date=None, match_band=2, geobin_radius=16, num_geobins=16, margin=32, out_param='out'):
         proc_idx = self.out_idx.copy()
         self.out_idx = []
         img = [self.pipe[t] for t in range(self.pipe_start+match_band, (10+len(self.PTRN_msk))*len(self.input_dates),
@@ -474,8 +474,8 @@ class S2TheiaTilePipeline:
             ctr = datetime.strptime(ref_date, '%Y%m%d')
             ref_idx = dts.index(min(dts, key=lambda x: abs(ctr - x)))
 
-        shifts = get_displacements_bestref(img,msk,ref_idx,0,out_param,
-                                           geobin_size, geobin_spacing, margin)
+        shifts = get_displacements_bestref(img, msk, ref_idx, 0, out_param,
+                                           geobin_radius, num_geobins, margin)
         k = 0
         for i in proc_idx:
             if ((i - proc_idx[0]) % 2) == 0: