# s1base.py (13.73 KiB)
import otbApplication as otb
from Common.otb_numpy_proc import to_otb_pipeline
import os
import glob
import numpy as np
import uuid
import xml.etree.ElementTree as ET
import shutil
from itertools import groupby
from Common.geotools import get_query_bbox
from eodag import EODataAccessGateway

def fetch(shp, dt, output_fld, credentials):
    """Search and download Sentinel-1 IW GRD products from the PEPS provider.

    Args:
        shp: Input handed to ``get_query_bbox``; the returned sequence is read
            as (lonmin, latmin, lonmax, latmax).
        dt: Date interval formatted as ``"<start>/<end>"``.
        output_fld: Folder receiving the downloaded and extracted products. It
            is created only when at least one product matches the query.
        credentials: Path to the EODAG user configuration file.

    Returns:
        None. The products are written to ``output_fld`` as a side effect.
    """
    bbox = get_query_bbox(shp)
    dag = EODataAccessGateway(user_conf_file_path=credentials)
    dag.set_preferred_provider("peps")
    bounds = dt.split("/")
    search_criteria = {
        "productType": "S1_SAR_GRD",
        "start": bounds[0],
        "end": bounds[1],
        "geom": {"lonmin": bbox[0], "latmin": bbox[1], "lonmax": bbox[2], "latmax": bbox[3]}
    }
    res = dag.search_all(**search_criteria)
    # filter_property() returns a new SearchResult instead of filtering in
    # place, so the result must be rebound or the IW filter is silently lost.
    res = res.filter_property(sensorMode='IW')
    if len(res) > 0:
        os.makedirs(output_fld, exist_ok=True)
        dag.download_all(res, outputs_prefix=output_fld, extract=True)

    return


class S1GRDPipeline:
    """Chain of OTB applications over a folder of Sentinel-1 IW GRD products.

    Every processing method appends OTB applications to ``self.pipe`` and
    records which entries are the current outputs in ``self.out_idx``. The
    methods below address those outputs in pairs, in the order given by
    ``PTRN`` (first pattern, then second pattern, for each image).
    """
    # --- BEGIN SENSOR PROTOTYPE ---
    # Identifier string of this sensor/product type.
    NAME = 'S1-IW-GRD'
    # Pixel type attached to the raw input images and to feature outputs made
    # only of '*db*' features (see compute_features).
    VAL_TYPE = otb.ImagePixelType_int16
    # Pixel type attached to intermediate products (calibration, ortho, ...).
    TMP_TYPE = otb.ImagePixelType_float
    # Glob pattern, relative to the input folder, locating product directories.
    # The literal fragments of its last component are also used by _check().
    PTRN_dir = 'S1*_IW_GRD*/S1*_IW_GRD*.SAFE'
    # Not referenced by the methods of this class.
    PTRN_ref = '-iw-grd-'
    # Not referenced by the methods of this class.
    VH_name = 'vh'
    # Glob patterns, relative to a product directory, of the measurement files.
    # Order matters: index 0 (vh) and index 1 (vv) define the pair layout.
    PTRN = ['measurement/s1*-iw-grd-vh-*.tiff', 'measurement/s1*-iw-grd-vv-*.tiff']
    # Expressions usable in compute_features(), where im1 is the first image of
    # a pair and im2 the second one.
    FEAT_exp = {
        'VH': 'im1b1',
        'VV': 'im2b1',
        'VH_db': '1000*log10(abs(im1b1)+1e-6)',
        'VV_db': '1000*log10(abs(im2b1)+1e-6)',
        'POL_RATIO': 'im1b1/im2b1'
    }
    # Pixel value counted as "no data" by _check_roi().
    NDT = 0.0

    @classmethod
    def _check(cls, x):
        lst = os.path.basename(cls.PTRN_dir).split('*')
        return all([t in os.path.basename(x) for t in lst])

    @classmethod
    def _img_id(cls, x):
        if cls._check(x):
            return '_'.join(os.path.basename(x).split('_')[4:7])
        else:
            return None

    @classmethod
    def _img_date(cls, x):
        if cls._check(x):
            return os.path.basename(x).split('_')[4].split('T')[0]
        else:
            return None

    @classmethod
    def _check_roi(cls, x, roi, min_surf=0.0):
        if cls._check(x):
            if os.path.isdir(x):
                er = otb.Registry.CreateApplication('ExtractROI')
                bnd = glob.glob(os.path.join(x, cls.PTRN[0]))[0]
                er.SetParameterString('in', bnd)
                er.SetParameterString('mode', 'fit')
                er.SetParameterString('mode.fit.im', roi)
                try:
                    er.Execute()
                except:
                    return False
                arr = er.GetImageAsNumpyArray('out')
                if (np.sum(arr != cls.NDT) / (arr.shape[0] * arr.shape[1])) <= min_surf:
                    return False
                else:
                    return True

    @classmethod
    def _check_direction(cls, x):
        root = ET.parse(os.path.join(x, 'manifest.safe')).getroot()
        for item in root.iter():
            if item.tag.endswith('pass'):
                if item.text == 'ASCENDING':
                    return 0
                elif item.text == 'DESCENDING':
                    return 1
                else:
                    return None

    @classmethod
    def _get_stiched_filename(cls, fn):
        tm1 = os.path.basename(fn).split('-')[-4].split('t')[1]
        tm2 = os.path.basename(fn).split('-')[-5].split('t')[1]
        fn = fn.replace(tm1, 'xxxxxx').replace(tm2, 'xxxxxx')
        return fn

    def _check_satellite(self, x):
        return os.path.basename(x).split('_')[0].lower()

    def __init__(self, fld, roi, temp_fld='/tmp', input_date_interval=None, roi_min_surf=0.0,
                 direction=None, satellite=None):
        self.pipe = []
        self.files = []
        self.types = []
        self.out_p = []
        self.out_idx = []

        self.id = str(uuid.uuid4())
        self.folder = os.path.abspath(fld)
        self.temp_fld = temp_fld + os.sep + self.id
        self.direction = direction
        self.satellite = satellite
        self.roi = roi
        self.image_list = self.parse_folder(self.folder, roi_min_surf)
        self.input_dates = [self._img_date(x) for x in self.image_list]

        if input_date_interval is not None:
            idx = [i for i in range(len(self.input_dates)) if
                   input_date_interval[0] <= self.input_dates[i] <= input_date_interval[1]]
            self.image_list = [self.image_list[i] for i in idx]
            self.input_dates = [self.input_dates[i] for i in idx]

        for img in self.image_list:
            for p in self.PTRN:
                ifn, ofn = self.get_file(img, p)
                self.append(to_otb_pipeline(ifn), ofn, self.VAL_TYPE, 'out', is_output=True)

    def __del__(self):
        if os.path.exists(self.temp_fld):
            shutil.rmtree(self.temp_fld)

    def parse_folder(self, fld, roi_min_surf=0.0):
        img_list = [os.path.abspath(x) for x in glob.glob(os.path.join(fld, self.PTRN_dir))
                    if os.path.isdir(x) and self._check_roi(x, self.roi, roi_min_surf)]

        if self.satellite is not None:
            img_list = [x for x in img_list if self._check_satellite(x) == self.satellite]

        asc, desc = [], []
        for x in img_list:
            if self._check_direction(x) == 0:
                asc.append(x)
            elif self._check_direction(x) == 1:
                desc.append(x)

        out = None
        if self.direction == 'ascending':
            out = asc
        elif self.direction == 'descending':
            out = desc
        elif self.direction == None:
            if len(asc) >= len(desc):
                self.direction = 'ascending'
                out = asc
            else:
                self.direction = 'descending'
                out = desc
            print('[INFO] No direction selected, returning majority : ' + self.direction)
        return sorted(out, key=lambda x: self._img_id(x))

    def get_file(self, img, ptrn):
        in_fn = glob.glob(os.path.join(img, ptrn))[0]
        out_fn = in_fn.replace(self.folder, '').lstrip(os.sep)
        return in_fn, os.path.splitext(out_fn)[0] + '.tiff'

    def append(self, app, fname=None, ftype=None, outp=None, is_output=False):
        if is_output:
            self.out_idx.append(len(self.pipe))
        self.pipe.append(app)
        self.files.append(fname)
        self.types.append(ftype)
        self.out_p.append(outp)

    def calibrate(self, lut='sigma'):
        """Chain an OTB ``SARCalibration`` after every current output.

        Args:
            lut: Value given to the ``lut`` parameter of the application.

        The calibrated images replace the current outputs; their file names
        get a ``_calib`` suffix and their pixel type is ``TMP_TYPE``.
        """
        pending, self.out_idx = self.out_idx.copy(), []
        for idx in pending:
            source = self.pipe[idx].GetParameterOutputImage(self.out_p[idx])
            sarcal = otb.Registry.CreateApplication('SARCalibration')
            sarcal.SetParameterInputImage('in', source)
            sarcal.SetParameterString('lut', lut)
            sarcal.Execute()
            calib_name = self.files[idx].replace('.tiff', '_calib.tiff')
            self.append(sarcal, calib_name, self.TMP_TYPE, 'out', is_output=True)

    def orthorectify(self, dem_fld, geoid=None, grid_spacing=40):
        assert(os.path.isdir(dem_fld))
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        for t in proc_idx:
            ortho = otb.Registry.CreateApplication('OrthoRectification')
            ortho.SetParameterString('elev.dem', dem_fld)
            if geoid is not None:
                assert (os.path.exists(geoid))
                ortho.SetParameterString('elev.geoid', geoid)
            ortho.SetParameterInputImage('io.in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            ortho.SetParameterInt('opt.gridspacing', grid_spacing)
            ortho.SetParameterString('outputs.mode', 'orthofit')
            ortho.SetParameterString('outputs.ortho', self.roi)
            ortho.Execute()
            fn = self.files[t].replace('.tiff', '_ortho.tiff')
            ty = self.TMP_TYPE
            self.append(ortho, fn, ty, 'io.out', is_output=True)

    def superimpose(self):
        """Chain an OTB ``Superimpose`` after every current output.

        ``self.roi`` is the reference image (``inr``) and each current output
        is the image to reproject (``inm``). The results replace the current
        outputs; their file names get a ``_roi`` suffix and their pixel type is
        ``TMP_TYPE``.
        """
        pending, self.out_idx = self.out_idx.copy(), []
        for idx in pending:
            moving = self.pipe[idx].GetParameterOutputImage(self.out_p[idx])
            si = otb.Registry.CreateApplication('Superimpose')
            si.SetParameterString('inr', self.roi)
            si.SetParameterInputImage('inm', moving)
            si.Execute()
            roi_name = self.files[idx].replace('.tiff', '_roi.tiff')
            self.append(si, roi_name, self.TMP_TYPE, 'out', is_output=True)

    def stitch(self):
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        q = [[self.image_list.index(v) for v in list(i)]
             for j, i in groupby(self.image_list, lambda x: self._img_date(x))]
        for s in q:
            if len(s) == 1:
                self.append(self.pipe[proc_idx[2 * s[0]]], self.files[proc_idx[2 * s[0]]],
                            self.types[proc_idx[2 * s[0]]], self.out_p[proc_idx[2 * s[0]]], is_output=True)
                self.append(self.pipe[proc_idx[2 * s[0] + 1]], self.files[proc_idx[2 * s[0] + 1]],
                            self.types[proc_idx[2 * s[0] + 1]], self.out_p[proc_idx[2 * s[0] + 1]], is_output=True)
            else:
                for k in range(2):
                    fn = self._get_stitched_filename(self.files[proc_idx[2*s[0]+k]])
                    bm = otb.Registry.CreateApplication('BandMathX')
                    [bm.AddImageToParameterInputImageList('il', self.pipe[proc_idx[2*u+k]].GetParameterOutputImage(self.out_p[proc_idx[2*u+k]])) for u in s]
                    bm.SetParameterString('exp',
                                          'vmax({' + ';'.join(['im%db1' % (i + 1) for i in range(len(s))]) + '})')
                    bm.Execute()
                    self.append(bm, fn, self.types[proc_idx[2*s[0]+k]], 'out', is_output=True)

    def multitemp_speckle_filter(self, win_size=3, outcore_on_disk=True):
        """Filter the current outputs with a multi-temporal speckle filter.

        The current outputs are processed as two interleaved stacks (even
        positions, then odd positions). For each stack an OTB
        ``MultitempFilteringOutcore`` image is computed, then every image of
        the stack is combined with it through ``BandMath``. The filtered
        images replace the current outputs, re-interleaved in the original
        order; their file names get a ``_filt`` suffix and their pixel type is
        ``TMP_TYPE``.

        Args:
            win_size: Value given to the ``wr`` parameter of the outcore
                application and to the radius of the mean ``Smoothing``.
            outcore_on_disk: When true, the outcore image of each stack is
                written under ``self.temp_fld`` and read back from disk;
                otherwise it stays in the in-memory pipeline.
        """
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        tmp_pipe = []
        tmp_fns = []
        ty = self.TMP_TYPE
        # u = 0: even positions of proc_idx, u = 1: odd positions.
        for u in range(2):
            oc = otb.Registry.CreateApplication('MultitempFilteringOutcore')
            for t in proc_idx[u::2]:
                oc.AddImageToParameterInputImageList('inl', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            oc.SetParameterString('wr', str(win_size))
            if not outcore_on_disk:
                oc.Execute()
                # Registered (not as an output) so it can be fetched back by position.
                oc_idx = len(self.pipe)
                self.append(oc)
            else:
                if not os.path.exists(self.temp_fld):
                    os.makedirs(self.temp_fld)
                oc_fn = os.path.join(self.temp_fld, 'outcore_{}.tif'.format(['vh', 'vv'][u]))
                oc.SetParameterString('oc', oc_fn)
                oc.ExecuteAndWriteOutput()
            for t in proc_idx[u::2]:
                smooth = otb.Registry.CreateApplication('Smoothing')
                smooth.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
                smooth.SetParameterString('type', 'mean')
                smooth.SetParameterString('type.mean.radius', str(win_size))
                smooth.Execute()
                # Registered (not as an output); read back just below as self.pipe[-1].
                self.append(smooth)
                bm = otb.Registry.CreateApplication('BandMath')
                # Input order matters for the expression: im1 is the outcore
                # image, im2 is the smoothed image.
                if not outcore_on_disk:
                    bm.AddImageToParameterInputImageList('il', self.pipe[oc_idx].GetParameterOutputImage('oc'))
                else:
                    bm.SetParameterStringList('il', [oc_fn])
                bm.AddImageToParameterInputImageList('il', self.pipe[-1].GetParameterOutputImage('out'))
                # NOTE(review): the meaning of the two outcore bands is defined
                # by MultitempFilteringOutcore and is not visible here - confirm.
                bm.SetParameterString('exp', 'im2b1*im1b1/im1b2')
                bm.Execute()
                tmp_fns.append(self.files[t].replace('.tiff', '_filt.tiff'))
                tmp_pipe.append(bm)

        # tmp_pipe holds the whole first stack followed by the whole second
        # stack; register them alternately to restore the pair layout.
        N = int(len(tmp_pipe) / 2)
        for i in range(N):
            self.append(tmp_pipe[i], tmp_fns[i], ty, 'out', is_output=True)
            self.append(tmp_pipe[N + i], tmp_fns[N + i], ty, 'out', is_output=True)

    def compute_features(self, feat_list=['VH_db', 'VV_db']):
        proc_idx = self.out_idx.copy()
        self.out_idx = []
        for k in range(0,len(proc_idx),2):
            bm = otb.Registry.CreateApplication('BandMathX')
            bm.AddImageToParameterInputImageList('il', self.pipe[proc_idx[k]].GetParameterOutputImage(
                self.out_p[proc_idx[k]]))
            bm.AddImageToParameterInputImageList('il', self.pipe[proc_idx[k+1]].GetParameterOutputImage(
                self.out_p[proc_idx[k+1]]))
            expr = []
            for f in feat_list:
                expr.append(self.FEAT_exp[f])
            expr = '{' + ';'.join(expr) + '}'
            bm.SetParameterString('exp', expr)
            bm.Execute()
            fn = self.files[proc_idx[k]].replace('-vh', '-feat')
            if all(['db' in x for x in feat_list]):
                ty = self.VAL_TYPE
            else:
                ty = self.TMP_TYPE
            self.append(bm, fn, ty, 'out', is_output=True)


    def write_outputs(self, fld, update_pipe=False, compress=False):
        """Write every current output to disk, cropped to ``self.roi``.

        Each output goes through ``ExtractROI`` in ``fit`` mode and is written
        to ``fld`` under its registered file name and pixel type.

        Args:
            fld: Destination folder; sub-folders are created as needed.
            update_pipe: When true, the written files are re-loaded and become
                the current outputs of the pipeline.
            compress: When true, the files are written with the
                ``gdal:co:compress=deflate`` extended-filename option.

        Returns:
            List of the written file paths, in output order.
        """
        out = []
        proc_idx = self.out_idx.copy()
        if update_pipe:
            self.out_idx = []
        for t in proc_idx:
            er = otb.Registry.CreateApplication('ExtractROI')
            er.SetParameterInputImage('in', self.pipe[t].GetParameterOutputImage(self.out_p[t]))
            er.SetParameterString('mode', 'fit')
            er.SetParameterString('mode.fit.im', self.roi)
            out_file = os.path.join(fld, self.files[t])
            out_dir = os.path.dirname(out_file)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # The creation option is a writer-only suffix, not part of the file
            # name: keep it out of the returned path and of the re-read input.
            target = out_file + '?gdal:co:compress=deflate' if compress else out_file
            er.SetParameterString('out', target)
            er.SetParameterOutputImagePixelType('out', self.types[t])
            er.ExecuteAndWriteOutput()
            out.append(out_file)
            if update_pipe:
                self.append(to_otb_pipeline(out_file), self.files[t], self.types[t], 'out', is_output=True)
        return out