Commit dcce9984 authored by Gaetano Raffaele's avatar Gaetano Raffaele
Browse files

ENH: Version of rigid coregistration for VHR.

parent 9fbc0215
No related merge requests found
Showing with 79 additions and 19 deletions
+79 -19
...@@ -49,7 +49,8 @@ def local_shift(src, tgt): ...@@ -49,7 +49,8 @@ def local_shift(src, tgt):
def compute_displacement(_src: otb.Application, _tgt: otb.Application, def compute_displacement(_src: otb.Application, _tgt: otb.Application,
src_band=2, tgt_band=2, src_band=2, tgt_band=2,
out_param_src='out', out_param_tgt='out', out_param_src='out', out_param_tgt='out',
geobin_size=32, geobin_spacing=256, margin=32, n_proc=6): geobin_size=32, geobin_spacing=256, margin=32,
filter=5, n_proc=6):
sz = _tgt.GetImageSize(out_param_tgt) sz = _tgt.GetImageSize(out_param_tgt)
W,H = sz[0], sz[1] W,H = sz[0], sz[1]
...@@ -78,6 +79,8 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application, ...@@ -78,6 +79,8 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application,
with mp.Pool(n_proc) as p: with mp.Pool(n_proc) as p:
for res in p.starmap(local_shift, args): for res in p.starmap(local_shift, args):
shifts.extend(res) shifts.extend(res)
shifts = np.concatenate(shifts)
""" """
for xx in args: for xx in args:
yy = local_shift(xx[0],xx[1]) yy = local_shift(xx[0],xx[1])
...@@ -92,6 +95,9 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application, ...@@ -92,6 +95,9 @@ def compute_displacement(_src: otb.Application, _tgt: otb.Application,
_tgt.PropagateRequestedRegion(out_param_tgt, reg) _tgt.PropagateRequestedRegion(out_param_tgt, reg)
if len(shifts) > 0: if len(shifts) > 0:
if filter > 0:
nrm = np.linalg.norm(shifts, axis=1)
shifts = shifts[nrm < filter]
return np.mean(np.array(shifts), axis=0), len(shifts) return np.mean(np.array(shifts), axis=0), len(shifts)
else: else:
return None, 0 return None, 0
...@@ -101,7 +107,7 @@ def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application ...@@ -101,7 +107,7 @@ def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application
src_band=2, tgt_band=2, num_geobins=16, src_band=2, tgt_band=2, num_geobins=16,
out_param_src='out', out_param_tgt='out', out_param_src='out', out_param_tgt='out',
out_param_src_msk='out', out_param_tgt_msk='out', out_param_src_msk='out', out_param_tgt_msk='out',
geobin_radius=16, margin=32, n_proc=6): geobin_radius=16, margin=32, filter=5, n_proc=6):
geobin_size = 2 * geobin_radius + 1 geobin_size = 2 * geobin_radius + 1
mask = 1 - _src_msk.GetImageAsNumpyArray(out_param_src_msk).astype(bool) mask = 1 - _src_msk.GetImageAsNumpyArray(out_param_src_msk).astype(bool)
...@@ -150,6 +156,9 @@ def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application ...@@ -150,6 +156,9 @@ def compute_displacement_with_masks(_src: otb.Application, _tgt: otb.Application
_tgt.PropagateRequestedRegion(out_param_tgt, reg) _tgt.PropagateRequestedRegion(out_param_tgt, reg)
if len(shifts) > 0: if len(shifts) > 0:
if filter > 0:
nrm = np.linalg.norm(shifts, axis=1)
shifts = shifts[nrm < filter]
return np.mean(np.array(shifts), axis=0), len(shifts) return np.mean(np.array(shifts), axis=0), len(shifts)
else: else:
return None, 0 return None, 0
...@@ -192,14 +201,15 @@ def get_descriptors(_src: otb.Application, src_band=2, out_param='out', ...@@ -192,14 +201,15 @@ def get_descriptors(_src: otb.Application, src_band=2, out_param='out',
def get_displacements_sequentially(_lst: List[otb.Application], def get_displacements_sequentially(_lst: List[otb.Application],
band=2, out_param='out', band=2, out_param='out',
geobin_size=32, geobin_spacing=256, margin=32): geobin_size=32, geobin_spacing=256, margin=32,
filter=5):
shifts = [] shifts = []
for i in tqdm(range(len(_lst)-1)): for i in tqdm(range(len(_lst)-1)):
sh = compute_displacement(_lst[i],_lst[i+1], src_band=band, tgt_band=band, sh = compute_displacement(_lst[i],_lst[i+1], src_band=band, tgt_band=band,
out_param_src=out_param, out_param_tgt=out_param, out_param_src=out_param, out_param_tgt=out_param,
geobin_size=geobin_size, geobin_spacing=geobin_spacing, geobin_size=geobin_size, geobin_spacing=geobin_spacing,
margin=margin) margin=margin, filter=filter)
shifts = [s + sh[0] for s in shifts] shifts = [s + sh[0] for s in shifts]
shifts.append(sh[0]) shifts.append(sh[0])
shifts.append(np.array([0,0])) shifts.append(np.array([0,0]))
...@@ -210,7 +220,7 @@ def get_displacements_sequentially(_lst: List[otb.Application], ...@@ -210,7 +220,7 @@ def get_displacements_sequentially(_lst: List[otb.Application],
def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application], def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Application],
ref_idx, band=2, out_param='out', ref_idx, band=2, out_param='out',
geobin_radius=16, num_geobins=16, margin=32, geobin_radius=16, num_geobins=16, margin=32, filter=5,
use_masks=True): use_masks=True):
''' '''
...@@ -241,12 +251,12 @@ def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Applic ...@@ -241,12 +251,12 @@ def get_displacements_bestref(_lst: List[otb.Application], _msk: List[otb.Applic
out_param_src=out_param, out_param_tgt=out_param, out_param_src=out_param, out_param_tgt=out_param,
out_param_src_msk=out_param, out_param_tgt_msk=out_param, out_param_src_msk=out_param, out_param_tgt_msk=out_param,
geobin_radius=geobin_radius, num_geobins=num_geobins, geobin_radius=geobin_radius, num_geobins=num_geobins,
margin=margin) margin=margin, filter=filter)
else: else:
sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band, sh = compute_displacement(_lst[i], _lst[idx], src_band=band, tgt_band=band,
out_param_src=out_param, out_param_tgt=out_param, out_param_src=out_param, out_param_tgt=out_param,
geobin_size=geobin_size, geobin_spacing=geobin_spacing, geobin_size=geobin_size, geobin_spacing=geobin_spacing,
margin=margin) margin=margin, filter=filter)
if sh[0] is not None: if sh[0] is not None:
shifts.append(sh[0]) shifts.append(sh[0])
else: else:
......
...@@ -4,7 +4,9 @@ import os ...@@ -4,7 +4,9 @@ import os
import glob import glob
from pathlib import Path from pathlib import Path
import re import re
from osgeo import ogr from osgeo import ogr, gdal
from Common.geometry import compute_displacement
from math import ceil, floor, sqrt
class SPOT67RasterPipeline: class SPOT67RasterPipeline:
# BEGIN SPOT6/7 VHR PROTOTYPE # BEGIN SPOT6/7 VHR PROTOTYPE
...@@ -74,6 +76,8 @@ class SPOT67RasterPipeline: ...@@ -74,6 +76,8 @@ class SPOT67RasterPipeline:
ty = self.REF_TYPE ty = self.REF_TYPE
self.append(tile_ms, fn, ty, 'out', is_output=True) self.append(tile_ms, fn, ty, 'out', is_output=True)
self.shift = None
def to_toa(self, clamp=True): def to_toa(self, clamp=True):
proc_idx = self.out_idx.copy() proc_idx = self.out_idx.copy()
self.out_idx = [] self.out_idx = []
...@@ -125,18 +129,45 @@ class SPOT67RasterPipeline: ...@@ -125,18 +129,45 @@ class SPOT67RasterPipeline:
ty = self.REF_TYPE ty = self.REF_TYPE
self.append(btps, fn, ty, 'out', is_output=True) self.append(btps, fn, ty, 'out', is_output=True)
def rigid_align(self, ref_img, this_band=0, ref_band=2, geobin_radius=32, num_geobins=128, margin=32):
si = otb.Registry.CreateApplication("Superimpose")
si.SetParameterInputImage('inm', self.pipe[-1].GetParameterOutputImage(self.out_p[-1]))
si.SetParameterString('inr', ref_img)
si.Execute()
sz = si.GetImageSize('out')
sz = min(sz[0], sz[1])
geobin_size = 2 * geobin_radius + 1
geobin_spacing = floor((sz - geobin_size * ceil(sqrt(num_geobins)) - 2 * margin) / 4)
sh = compute_displacement(si, to_otb_pipeline(ref_img), src_band=this_band, tgt_band=ref_band,
out_param_src='out', out_param_tgt='out',
geobin_size=geobin_size, geobin_spacing=geobin_spacing, margin=margin)
print(sh[0])
if sh is not None:
ref_spc = si.GetImageSpacing('out')
self.shift = (ref_spc[0]*sh[0][1], ref_spc[1]*sh[0][0])
return
def clip(self, roi, buffer=0): def clip(self, roi, buffer=0):
assert(os.path.exists(roi)) assert(os.path.exists(roi))
proc_idx = self.out_idx.copy() proc_idx = self.out_idx.copy()
self.out_idx = [] self.out_idx = []
for t in proc_idx: for t in proc_idx:
spc = self.pipe[t].GetImageSpacing(self.out_p[t])
er = otb.Registry.CreateApplication('ExtractROI') er = otb.Registry.CreateApplication('ExtractROI')
er.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t])) er.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
er.SetParameterString('mode', 'extent') er.SetParameterString('mode', 'extent')
ds = ogr.Open(roi) ds = ogr.Open(roi)
ly = ds.GetLayer(0) ly = ds.GetLayer(0)
extent = ly.GetExtent() extent = list(ly.GetExtent())
extent = (extent[0] - buffer, extent[1] + buffer, extent[2] - buffer, extent[3] + buffer) extent = [extent[0] - buffer, extent[1] + buffer, extent[2] - buffer, extent[3] + buffer]
if spc[0] < 0:
tmp = extent[0]
extent[0] = extent[1]
extent[1] = tmp
if spc[1] < 0:
tmp = extent[2]
extent[2] = extent[3]
extent[3] = tmp
if ly.GetSpatialRef().GetAuthorityCode('PROJCS') == '4326': if ly.GetSpatialRef().GetAuthorityCode('PROJCS') == '4326':
er.SetParameterString('mode.extent.unit', 'lonlat') er.SetParameterString('mode.extent.unit', 'lonlat')
else: else:
...@@ -161,14 +192,15 @@ class SPOT67RasterPipeline: ...@@ -161,14 +192,15 @@ class SPOT67RasterPipeline:
def write_outputs(self, fld, roi=None, update_pipe=False, compress=False): def write_outputs(self, fld, roi=None, update_pipe=False, compress=False):
out = [] out = []
out_idx_bck = self.out_idx.copy() out_idx_bck = proc_idx = self.out_idx.copy()
if roi is not None: if roi is not None:
pipe_length = len(self.pipe) pipe_length = len(self.pipe)
self.clip(roi) self.clip(roi)
proc_idx = self.out_idx.copy()
if update_pipe: if update_pipe:
assert roi is None, 'Cannot set output files as pipe input over a ROI.' assert roi is None, 'Cannot set output files as pipe input over a ROI.'
self.out_idx = [] self.out_idx = []
for t in out_idx_bck: for t in proc_idx:
out_file = os.path.join(fld, self.files[t]) out_file = os.path.join(fld, self.files[t])
if compress: if compress:
out_file += '?gdal:co:compress=deflate&gdal:co:bigtiff=yes' out_file += '?gdal:co:compress=deflate&gdal:co:bigtiff=yes'
...@@ -177,13 +209,23 @@ class SPOT67RasterPipeline: ...@@ -177,13 +209,23 @@ class SPOT67RasterPipeline:
self.pipe[t].SetParameterString(self.out_p[t], out_file) self.pipe[t].SetParameterString(self.out_p[t], out_file)
self.pipe[t].SetParameterOutputImagePixelType(self.out_p[t], self.types[t]) self.pipe[t].SetParameterOutputImagePixelType(self.out_p[t], self.types[t])
self.pipe[t].ExecuteAndWriteOutput() self.pipe[t].ExecuteAndWriteOutput()
if self.shift is not None:
print("INFO: Applying shift {} to current outputs.".format(self.shift))
ds = gdal.Open(out_file, 1)
geot = list(ds.GetGeoTransform())
geot[0] += self.shift[0]
geot[3] += self.shift[1]
ds.SetGeoTransform(tuple(geot))
ds = None
out.append(out_file) out.append(out_file)
if update_pipe: if update_pipe:
self.append(to_otb_pipeline(out_file), self.files[t], self.types[t], 'out', is_output=True) self.append(to_otb_pipeline(out_file), self.files[t], self.types[t], 'out', is_output=True)
if update_pipe:
self.shift = None
if roi is not None: if roi is not None:
self.out_idx = out_idx_bck self.out_idx = out_idx_bck
self.pipe = self.pipe[:pipe_length] self.pipe = self.pipe[:pipe_length]
self.files = self.files[:pipe_length] self.files = self.files[:pipe_length]
self.types = self.types[:pipe_length] self.types = self.types[:pipe_length]
self.out_p = self.out_p[:pipe_length] self.out_p = self.out_p[:pipe_length]
return out return out
\ No newline at end of file
...@@ -12,16 +12,19 @@ def run_segmentation(img, threshold, cw, sw , out_seg, ...@@ -12,16 +12,19 @@ def run_segmentation(img, threshold, cw, sw , out_seg,
OBIA.segmentation.lsgrm(img, params, out_seg, n_proc, memory, None, remove_graph, force_parallel) OBIA.segmentation.lsgrm(img, params, out_seg, n_proc, memory, None, remove_graph, force_parallel)
return return
def preprocess_spot67(in_fld, out_fld, dem_fld, geoid_file, skip_ps, compress): def preprocess_spot67(in_fld, out_fld, dem_fld, geoid_file, skip_ps, compress,
clip, align_to, align_to_band, align_using_band):
sp = VHR.vhrbase.SPOT67RasterPipeline(in_fld) sp = VHR.vhrbase.SPOT67RasterPipeline(in_fld)
sp.to_toa() sp.to_toa()
sp.orthorectify(dem_fld, geoid_file) sp.orthorectify(dem_fld, geoid_file)
if clip is not None and os.path.exists(clip):
sp.clip(clip)
if not skip_ps: if not skip_ps:
sp.write_outputs(out_fld, update_pipe=True, compress=compress) sp.write_outputs(out_fld, update_pipe=True, compress=compress)
sp.pansharp() sp.pansharp()
sp.write_outputs(out_fld, compress=compress) if align_to is not None and os.path.exists(align_to):
else: sp.rigid_align(align_to, this_band=align_using_band-1, ref_band=align_to_band-1)
sp.write_outputs(out_fld, compress=compress) sp.write_outputs(out_fld, compress=compress)
return return
def preprocess_s2(in_fld, out_fld, output_dates_file=None, roi=None, coregister_to=None, coregister_to_band=1, def preprocess_s2(in_fld, out_fld, output_dates_file=None, roi=None, coregister_to=None, coregister_to_band=1,
...@@ -56,7 +59,7 @@ def main(args): ...@@ -56,7 +59,7 @@ def main(args):
prepr.add_argument("--roi", type=str, default=None, help="Path to the ROI vector file.") prepr.add_argument("--roi", type=str, default=None, help="Path to the ROI vector file.")
prepr.add_argument("--coregister_to", type=str, default=None, help="Path to a reference image to which the stacks must be coregistered.") prepr.add_argument("--coregister_to", type=str, default=None, help="Path to a reference image to which the stacks must be coregistered.")
prepr.add_argument("--coregister_to_band", type=int, nargs='?', default=1, help="Band of reference image used for co-registration.") prepr.add_argument("--coregister_to_band", type=int, nargs='?', default=1, help="Band of reference image used for co-registration.")
prepr.add_argument("--coregister_using_band", type=int, nargs='?', default=3, help="Band of reference image used for co-registration.") prepr.add_argument("--coregister_using_band", type=int, nargs='?', default=3, help="Band of current stack used for co-registration.")
prepr.add_argument("--provider", type=str, default='theia', help="S2 image provider. Supported: 'theia', 'theial3a', 'sen2cor', 'planetary'") prepr.add_argument("--provider", type=str, default='theia', help="S2 image provider. Supported: 'theia', 'theial3a', 'sen2cor', 'planetary'")
segmt = subpar.add_parser("segment", help="Performs (large scale Baatz-Shape) segmentation of an input image.", segmt = subpar.add_parser("segment", help="Performs (large scale Baatz-Shape) segmentation of an input image.",
...@@ -79,6 +82,10 @@ def main(args): ...@@ -79,6 +82,10 @@ def main(args):
vhrprep.add_argument("out_fld", type=str, help="Path to the output folder for preprocessed images.") vhrprep.add_argument("out_fld", type=str, help="Path to the output folder for preprocessed images.")
vhrprep.add_argument("dem_fld", type=str, help="Path to the folder containing DEM covering the scene in WGS84 projection.") vhrprep.add_argument("dem_fld", type=str, help="Path to the folder containing DEM covering the scene in WGS84 projection.")
vhrprep.add_argument("geoid", type=str, help="Path to the geoid file.") vhrprep.add_argument("geoid", type=str, help="Path to the geoid file.")
vhrprep.add_argument("--clip", type=str, default=None, help="Path to a vector file for clipping.")
vhrprep.add_argument("--align_to", type=str, default=None, help="Path to a reference image to which the image must be aligned (rigid).")
vhrprep.add_argument("--align_to_band", type=int, nargs='?', default=3, help="Band of reference image used for alignment.")
vhrprep.add_argument("--align_using_band", type=int, nargs='?', default=1, help="Band of current image used for alignment.")
vhrprep.add_argument("--skip_ps", help="Skip pansharpening step.", action='store_true') vhrprep.add_argument("--skip_ps", help="Skip pansharpening step.", action='store_true')
vhrprep.add_argument("--compress", help="Use lossless compresion on outputs.", action='store_true') vhrprep.add_argument("--compress", help="Use lossless compresion on outputs.", action='store_true')
...@@ -99,7 +106,8 @@ def main(args): ...@@ -99,7 +106,8 @@ def main(args):
arg.n_proc, arg.mem_limit, not arg.keep_graph, arg.force_parallel) arg.n_proc, arg.mem_limit, not arg.keep_graph, arg.force_parallel)
if arg.cmd == "preprocess_spot67": if arg.cmd == "preprocess_spot67":
preprocess_spot67(arg.fld, arg.out_fld, arg.dem_fld, arg.geoid, arg.skip_ps, arg.compress) preprocess_spot67(arg.fld, arg.out_fld, arg.dem_fld, arg.geoid, arg.skip_ps, arg.compress,
arg.clip, arg.align_to, arg.align_to_band, arg.align_using_band)
return 0 return 0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment