stac.py 8.74 KiB
"""
This module enables various operations from STAC catalogs.

# Spot 6/7 from DINAMIS

```py
from scenes import BoundingBox
from scenes.stac import dinamis_spot67drs

# Get `Spot67DRSScene` instances from STAC search
bbox = BoundingBox(xmin=3.4989739222368432, ymin=49.495592616979884, xmax=4.31841698848754, ymax=50.044096523322686)
scs = dinamis_spot67drs(date_min="04/10/1986", date_max="21/01/2056", bbox_wgs84=bbox)

# Do stuff
for sc in scs:
    print(sc)
scs[0].get_xs().cld_msk_drilled().write("/tmp/test.tif")
```
"""
from __future__ import annotations
import os
import tempfile
import datetime
import threading
from abc import abstractmethod
import requests
from tqdm.autonotebook import tqdm
import pystac
from pystac_client import Client
from scenes.core import Scene
from scenes.spatial import BoundingBox
from scenes.dates import any2datetime
from scenes.spot import Spot67DRSScene
from scenes.sentinel import Sentinel2MPCScene
from scenes.auth import OAuth2KeepAlive


class ProviderBase:
    """
    Base class for STAC providers

    A provider wraps one STAC endpoint: it runs searches against it and converts the resulting
    items into `Scene` objects. Subclasses must implement `stac_item_to_scene()`.
    """
    def __init__(self, url, default_collection):
        """

        Args:
            url: STAC endpoint
            default_collection: collection searched when `stac_search()` is called without `collections`
        """
        assert url
        self.url = url
        self.default_collection = default_collection
        # Media types whose remote assets get the "/vsicurl/" prefix in get_asset_path()
        self.vsicurl_media_types = [pystac.MediaType.GEOPACKAGE,
                                    pystac.MediaType.GEOJSON,
                                    pystac.MediaType.COG,
                                    pystac.MediaType.GEOTIFF,
                                    pystac.MediaType.TIFF,
                                    pystac.MediaType.JPEG2000]

    def stac_search(self, bbox_wgs84: BoundingBox, collections: list[str] | None = None,
                    date_min: datetime.datetime | str | None = None,
                    date_max: datetime.datetime | str | None = None,
                    filt: dict | None = None, query: dict | None = None):
        """
        Search an item in a STAC catalog.
        see https://pystac-client.readthedocs.io/en/latest/api.html#pystac_client.Client.search

        Args:
            bbox_wgs84: The bounding box in WGS84 (BoundingBox instance)
            collections: names of the collections to search (defaults to the provider default collection)
            date_min: date min (datetime.datetime or str). If omitted, the range is open-ended on this side
            date_max: date max (datetime.datetime or str). If omitted, the range is open-ended on this side
            filt: JSON of query parameters as per the STAC API filter extension
            query: List or JSON of query parameters as per the STAC API query extension
        Returns:
            Search result

        """
        # Always hand pystac-client a [start, end] pair, using None for a missing bound.
        # A one-element list would be interpreted as a single datetime instead of an
        # open-ended range.
        dt = None
        if date_min or date_max:
            dt = [any2datetime(date) if date else None for date in (date_min, date_max)]
        api = Client.open(self.url)
        results = api.search(max_items=None, bbox=bbox_wgs84.to_list(), datetime=dt, filter=filt, query=query,
                             collections=collections if collections else [self.default_collection])
        return results.items()

    def scenes_search(self, *args, as_generator=False, **kwargs) -> list[Scene]:
        """
        Perform a STAC search then converts the resulting items into `Scene` objects

        Args:
            *args: same args as stac_search()
            as_generator: return scenes as generator, or not
            **kwargs: same kwargs as stac_search()

        Returns: a list of `Scenes`, or a generator yielding them when `as_generator` is True

        """
        items = self.stac_search(*args, **kwargs)
        # Lazy conversion: nothing is converted until the generator is consumed
        gen = (self.stac_item_to_scene(item) for item in tqdm(items))
        if as_generator:
            return gen
        return list(gen)

    def get_asset_path(self, asset: pystac.Asset) -> str:
        """
        Return the URI suited for GDAL if the asset is some geospatial data.
        If the asset.href starts with "http://", "https://", etc. then the returned URI is prefixed with "/vsicurl".

        Args:
            asset: STAC asset

        Returns:
            URI, with or without "/vsicurl/" prefix

        """
        url = asset.href
        remote_prefixes = ("http://", "https://", "ftp://", "sftp://")
        # The scheme is compared case-insensitively; the href itself is returned unchanged
        if asset.media_type in self.vsicurl_media_types and url.lower().startswith(remote_prefixes):
            return f"/vsicurl/{url}"
        return url

    @abstractmethod
    def stac_item_to_scene(self, item: pystac.Item) -> Scene:
        """
        Convert a STAC item into a `Scene`

        Args:
            item: STAC item

        Returns: scene

        Raises:
            NotImplementedError: always, in this base class

        """
        raise NotImplementedError("")


class DinamisSpot67Provider(ProviderBase):
    """
    Provider for Spot-6/7 DRS products from DINAMIS
    """
    def __init__(self, auth=None):
        """

        Args:
            auth: currently unused, the OAuth2 keep-alive wiring below is commented out
        """
        super().__init__(url="https://stacapi.147.100.200.143.nip.io",
                         default_collection="spot-6-7-drs")
        self.tmp_dir = tempfile.TemporaryDirectory()
        # Path of the header file referenced in GDAL URIs. While None, URIs are left
        # untouched and update_headers() only refreshes the in-memory headers.
        self.headers_file = None
        self.__auth_headers = None
        self.__auth_headers_lock = threading.Lock()
        # if not auth:
        #     self.headers_file = os.path.join(self.tmp_dir.name, 'headers.txt')
        #     self.auth = OAuth2KeepAlive(
        #         keycloak_server_url="https://stacapi.147.100.200.143.nip.io/auth/",
        #         keycloak_client_id="device-client",
        #         keycloak_realm="dinamis",
        #         refresh_callback=self.update_headers
        #     )

    def set_auth_headers(self, token):
        """
        Set authorization headers

        Args:
            token: mapping holding an 'access_token' entry

        """
        assert 'access_token' in token
        access_token = token['access_token']
        with self.__auth_headers_lock:
            self.__auth_headers = {"Authorization": f"Bearer {access_token}"}

    def get_auth_headers(self) -> dict[str, str] | None:
        """
        Get authorization headers

        Returns:
            the authorization headers, or None if `set_auth_headers()` has not been called yet

        """
        with self.__auth_headers_lock:
            return self.__auth_headers

    def get_asset_path(self, asset: pystac.Asset) -> str:
        """
        Overrides parent method

        When a header file is configured, the "/vsicurl/" prefix is turned into
        "/vsicurl?header_file=<file>&url=".

        Args:
            asset: STAC asset

        Returns:
            URI

        """
        url = super().get_asset_path(asset=asset)
        vsicurl_prefix = "/vsicurl/"
        # Rewrite only the leading prefix, so an href that happens to contain
        # "/vsicurl/" further along is left intact.
        if self.headers_file and url.startswith(vsicurl_prefix):
            url = f"/vsicurl?header_file={self.headers_file}&url={url[len(vsicurl_prefix):]}"
        return url

    def update_headers(self, token):
        """
        Update the authorization header file that GDAL reads

        The in-memory headers are always refreshed. The header file is rewritten only when
        `headers_file` is set.

        Args:
            token: token

        """
        self.set_auth_headers(token=token)
        if not self.headers_file:
            # No header file configured: there is nothing to write on disk
            return
        tmp_headers_file = f"{self.headers_file}.tmp"
        old_headers_file = f"{self.headers_file}.old"
        auth_header_str = "\n".join(f'{key}: {value}' for key, value in self.get_auth_headers().items())
        # Write to a temporary file first, then move it into place, so that the header
        # file never holds partially written content.
        with open(tmp_headers_file, "w", encoding='UTF-8') as text_file:
            text_file.write(auth_header_str)
        # os.replace() overwrites an existing destination on every platform, whereas
        # os.rename() fails on Windows once the ".old" file exists.
        if os.path.exists(self.headers_file):
            os.replace(self.headers_file, old_headers_file)
        os.replace(tmp_headers_file, self.headers_file)

    def stac_item_to_scene(self, item: pystac.Item) -> Scene:
        """
        Convert a STAC item into a `Spot67DRSScene`

        Args:
            item: STAC item

        Returns: scene

        """
        return Spot67DRSScene(assets_paths={key: self.get_asset_path(asset) for key, asset in item.assets.items()},
                              assets_headers=self.get_auth_headers())


class MPCProvider(ProviderBase):
    """
    Provider for Microsoft Planetary Computer
    """
    def __init__(self, asset_signing_duration=60):
        """

        Args:
            asset_signing_duration: value sent as the `duration` parameter of the signing
                request (presumably minutes, to be confirmed against the Planetary Computer SAS API)
        """
        super().__init__(url="https://planetarycomputer.microsoft.com/api/stac/v1",
                         default_collection="sentinel-2-l2a")
        self.asset_signing_duration = asset_signing_duration

    def sign_url(self, asset_path: str) -> str:
        """
        Sign PC URL

        Args:
            asset_path: asset path, with or without the "/vsicurl/" prefix

        Returns:
            signed URL, carrying the "/vsicurl/" prefix if `asset_path` had it

        """
        vsicurl_prefix = "/vsicurl/"
        is_vsi = asset_path.startswith(vsicurl_prefix)
        # Strip only the leading prefix; the rest of the href is kept verbatim
        href = asset_path[len(vsicurl_prefix):] if is_vsi else asset_path
        # Let requests build the query string so that the href is URL-encoded: characters
        # such as "&", "?", "#" or "%" in the href would otherwise corrupt the request.
        ret = requests.get("https://planetarycomputer.microsoft.com/api/sas/v1/sign",
                           params={"href": href, "duration": self.asset_signing_duration},
                           headers={"accept": "application/json"},
                           timeout=10)
        assert ret.status_code == 200, f"Request returned: {ret.text}"
        data = ret.json()
        assert "href" in data
        new_url = data["href"]
        return f"{vsicurl_prefix}{new_url}" if is_vsi else new_url

    def stac_item_to_scene(self, item: pystac.Item) -> Scene:
        """
        Convert a STAC item into a `Sentinel2MPCScene`

        Args:
            item: STAC item

        Returns: scene

        """
        assets_paths = {key: self.get_asset_path(asset) for key, asset in item.assets.items()}
        # Every asset uses the same signing callable, handed to the scene with the unsigned paths
        assets_paths_func = dict.fromkeys(assets_paths, self.sign_url)
        return Sentinel2MPCScene(assets_paths=assets_paths, assets_paths_func=assets_paths_func)