Commit b1ef5e51 authored by Cresson Remi
Browse files

Merge branch '6_deepnets' into 'develop'

Add deepnets

Closes #6

See merge request !23
Showing with 241 additions and 55 deletions
+241 -55
...@@ -222,7 +222,6 @@ class Scene(ABC): ...@@ -222,7 +222,6 @@ class Scene(ABC):
self.bbox_wgs84 = bbox_wgs84 self.bbox_wgs84 = bbox_wgs84
assert isinstance(epsg, int), "epsg must be an int" assert isinstance(epsg, int), "epsg must be an int"
self.epsg = epsg self.epsg = epsg
self.metadata = self.get_metadata()
@abstractmethod @abstractmethod
def get_imagery(self, **kwargs): def get_imagery(self, **kwargs):
...@@ -251,7 +250,7 @@ class Scene(ABC): ...@@ -251,7 +250,7 @@ class Scene(ABC):
def get_serializable_metadata(self): def get_serializable_metadata(self):
"""Enable one instance to be used with print()""" """Enable one instance to be used with print()"""
return {k: str(v) for k, v in self.metadata.items()} return {k: str(v) for k, v in self.get_metadata().items()}
def __repr__(self): def __repr__(self):
""" """
......
# -*- coding: utf-8 -*-
"""
This module aims to deal with dates
"""
import datetime
def str2datetime(datestr):
    """
    Parse a date given as a string into a datetime instance.

    The supported layouts are tried one after the other, in the order listed
    below; the first one that matches is used.

    Args:
        datestr: date (str) in the format "YYYY-MM-DD" or "DD/MM/YYYY" or "DD-MM-YYYY"

    Returns:
        A datetime.datetime instance

    Raises:
        ValueError: when `datestr` matches none of the supported formats

    """
    # source (with a few enhancements):
    # https://stackoverflow.com/questions/23581128/how-to-format-date-string-via-multiple-formats-in-python
    assert isinstance(datestr, str), "Input must be a str!"
    formats = ('%Y-%m-%d', '%d/%m/%Y', '%d-%m-%Y')
    parsed = None
    for candidate in formats:
        try:
            parsed = datetime.datetime.strptime(datestr, candidate)
        except ValueError:
            # This layout does not fit the input: try the next one
            continue
        break
    if parsed is None:
        raise ValueError(f'No valid date format found. Accepted formats are {formats}. Input was: {datestr}')
    return parsed
def any2datetime(str_or_datetime):
    """
    Return the input as a datetime.datetime instance.

    A datetime.datetime is handed back untouched (same object); a str is
    parsed with `str2datetime()`.

    Args:
        str_or_datetime: a str (see `str2datetime()` for supported dates formats) or a datetime.datetime

    Returns:
        A datetime.datetime instance

    """
    if not isinstance(str_or_datetime, datetime.datetime):
        # Anything that is not already a datetime has to be a parsable string
        assert isinstance(str_or_datetime, str), "Date must be a str, or a datetime.datetime instance!"
        return str2datetime(str_or_datetime)
    return str_or_datetime
"""
This module provides tools to easily interact with deep learning models.
It is based on OTBTF.
"""
import os
import zipfile
import pyotb
from scenes import download
SR4RS_MODEL_URL = "https://nextcloud.inrae.fr/s/boabW9yCjdpLPGX/download/" \
"sr4rs_sentinel2_bands4328_france2020_savedmodel.zip"
def inference(dic):
    """
    Run a deep net through the `pyotb.TensorflowModelServe` application.

    If creating the application raises an ImportError, a message is printed
    (OTBTF is required to use deep nets) and None is returned instead.

    Args:
        dic: otb parameters dict

    Returns:
        pyotb App instance. When OTBTF is not found, None is returned.

    """
    try:
        return pyotb.TensorflowModelServe(dic)
    except ImportError as e:
        print("OTBTF has not been found in the system! It is mandatory to use deepnets. "
              f"Original error: {e}")
    return None
def sr4rs(input_image, model_url=SR4RS_MODEL_URL, tmp_dir="/tmp"):
    """
    Apply the SR4RS super-resolution model to an image.

    The zipped model is downloaded into `tmp_dir` and extracted there, unless
    the expected extracted path (the archive path without its extension)
    already exists.
    See https://github.com/remicres/sr4rs for implementation details.

    Args:
        input_image: pyotb Input
        model_url: SR4RS pre-trained model URL. Must point to an online .zip file.
        tmp_dir: directory for temporary files.

    Returns:
        pyotb Output

    """
    # Model geometry: the receptive field is derived from the expression field,
    # the generator margin and the spacing ratio
    ratio = 0.25
    gen_fcn = 64
    efield = 512
    rfield = int((efield + 2 * gen_fcn) * ratio)

    # Fetch and unpack the model only when it is not already available locally
    archive_path = os.path.join(tmp_dir, os.path.basename(model_url))
    model_dir, _ = os.path.splitext(archive_path)
    if not os.path.exists(model_dir):
        download.curl_url(model_url, postdata=None, out_file=archive_path)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            print("Unzipping model...")
            archive.extractall(tmp_dir)

    otb_params = {
        "source1.il": input_image,
        "source1.rfieldx": rfield,
        "source1.rfieldy": rfield,
        "source1.placeholder": "lr_input",
        "model.dir": model_dir,
        "model.fullyconv": "on",
        "output.names": f"output_{gen_fcn}",
        "output.efieldx": efield,
        "output.efieldy": efield,
        "output.spcscale": ratio,
    }
    return inference(otb_params)
...@@ -10,19 +10,7 @@ import json ...@@ -10,19 +10,7 @@ import json
from urllib.parse import urlencode from urllib.parse import urlencode
import pycurl import pycurl
from tqdm.autonotebook import tqdm from tqdm.autonotebook import tqdm
from scenes import dates
def bbox2str(bbox):
    """Format a bounding box as a comma-separated string.

    The values are written in the order ymin, xmin, ymax, xmax.

    Args:
        bbox: the bounding box (BoundingBox instance)

    Returns:
        a string

    """
    return f"{bbox.ymin},{bbox.xmin},{bbox.ymax},{bbox.xmax}"
def compute_md5(fname): def compute_md5(fname):
...@@ -55,14 +43,14 @@ def is_file_complete(filename, md5sum): ...@@ -55,14 +43,14 @@ def is_file_complete(filename, md5sum):
return md5sum == compute_md5(filename) return md5sum == compute_md5(filename)
def curl_url(url, postdata, verbose=False, fp=None, header=None): def curl_url(url, postdata, verbose=False, out_file=None, header=None):
"""Use PyCurl to make some requests """Use PyCurl to make some requests
Args: Args:
url: url url: url
postdata: POST data postdata: POST data
verbose: boolean (Default value = False) verbose: boolean (Default value = False)
fp: file handle (Default value = None) out_file: output file (Default value = None)
header: header. If None is kept, ['Accept:application/json'] is used (Default value = None) header: header. If None is kept, ['Accept:application/json'] is used (Default value = None)
Returns: Returns:
...@@ -84,34 +72,36 @@ def curl_url(url, postdata, verbose=False, fp=None, header=None): ...@@ -84,34 +72,36 @@ def curl_url(url, postdata, verbose=False, fp=None, header=None):
storage = io.BytesIO() storage = io.BytesIO()
if verbose: if verbose:
c.setopt(pycurl.VERBOSE, 1) c.setopt(pycurl.VERBOSE, 1)
if fp is not None: if out_file is not None:
progress_bar = None with open(out_file, "wb") as fp:
last_download_d = 0 progress_bar = None
print("Downloading", flush=True) last_download_d = 0
print("Downloading", flush=True)
def _status(download_t, download_d, *_):
"""Callback function for c.XFERINFOFUNCTION def _status(download_t, download_d, *_):
https://stackoverflow.com/questions/19724222/pycurl-attachments-and-progress-functions """Callback function for c.XFERINFOFUNCTION
https://stackoverflow.com/questions/19724222/pycurl-attachments-and-progress-functions
Args:
download_t: total Args:
download_d: already downloaded download_t: total
*_: any additional param (won't be used) download_d: already downloaded
*_: any additional param (won't be used)
"""
if download_d > 0: """
nonlocal progress_bar, last_download_d if download_d > 0:
if not progress_bar: nonlocal progress_bar, last_download_d
progress_bar = tqdm(total=download_t, unit='iB', unit_scale=True) if not progress_bar:
progress_bar.update(download_d - last_download_d) progress_bar = tqdm(total=download_t, unit='iB', unit_scale=True)
last_download_d = download_d progress_bar.update(download_d - last_download_d)
last_download_d = download_d
c.setopt(c.NOPROGRESS, False)
c.setopt(c.XFERINFOFUNCTION, _status) c.setopt(c.NOPROGRESS, False)
c.setopt(pycurl.WRITEDATA, fp) c.setopt(c.XFERINFOFUNCTION, _status)
c.setopt(pycurl.WRITEDATA, fp)
c.perform()
else: else:
c.setopt(pycurl.WRITEFUNCTION, storage.write) c.setopt(pycurl.WRITEFUNCTION, storage.write)
c.perform() c.perform()
c.close() c.close()
content = storage.getvalue() content = storage.getvalue()
return content.decode(encoding="utf-8", errors="strict") return content.decode(encoding="utf-8", errors="strict")
...@@ -161,6 +151,19 @@ class TheiaDownloader: ...@@ -161,6 +151,19 @@ class TheiaDownloader:
# Maximum number of records # Maximum number of records
self.max_records = max_records self.max_records = max_records
@staticmethod
def _bbox2str(bbox):
    """Format a bounding box as a comma-separated string.

    The values are written in the order ymin, xmin, ymax, xmax.

    Args:
        bbox: the bounding box (BoundingBox instance)

    Returns:
        a string

    """
    return f"{bbox.ymin},{bbox.xmin},{bbox.ymax},{bbox.xmax}"
def _get_token(self): def _get_token(self):
"""Get the THEIA token""" """Get the THEIA token"""
postdata_token = {"ident": self.config["login_theia"], "pass": self.config["password_theia"]} postdata_token = {"ident": self.config["login_theia"], "pass": self.config["password_theia"]}
...@@ -244,9 +247,7 @@ class TheiaDownloader: ...@@ -244,9 +247,7 @@ class TheiaDownloader:
# Check if the destination file exist and is correct # Check if the destination file exist and is correct
if not is_file_complete(filename, description["checksum"]): if not is_file_complete(filename, description["checksum"]):
print("Downloading {}".format(acq_date)) print("Downloading {}".format(acq_date))
file_handle = open(filename, "wb") curl_url(url, postdata=None, out_file=filename, header=header)
curl_url(url, postdata=None, fp=file_handle, header=header)
file_handle.close()
else: else:
print("{} already downloaded. Skipping.".format(acq_date)) print("{} already downloaded. Skipping.".format(acq_date))
description["local_file"] = filename description["local_file"] = filename
...@@ -258,7 +259,7 @@ class TheiaDownloader: ...@@ -258,7 +259,7 @@ class TheiaDownloader:
Args: Args:
bbox_wgs84: bounding box (WGS84) bbox_wgs84: bounding box (WGS84)
dates_range: a tuple of datetime.datetime instances (start_date, end_date) dates_range: a tuple of datetime.datetime or str instances (start_date, end_date)
download_dir: download directory (Default value = None) download_dir: download directory (Default value = None)
level: LEVEL2A, LEVEL3A, ... (Default value = "LEVEL3A") level: LEVEL2A, LEVEL3A, ... (Default value = "LEVEL3A")
...@@ -267,10 +268,11 @@ class TheiaDownloader: ...@@ -267,10 +268,11 @@ class TheiaDownloader:
""" """
start_date, end_date = dates_range start_date, end_date = dates_range
dict_query = { dict_query = {
"box": bbox2str(bbox_wgs84), # lonmin, latmin, lonmax, latmax "box": self._bbox2str(bbox_wgs84), # lonmin, latmin, lonmax, latmax
"startDate": start_date.strftime("%Y-%m-%d"), "startDate": dates.any2datetime(start_date).strftime("%Y-%m-%d"),
"completionDate": end_date.strftime("%Y-%m-%d"), "completionDate": dates.any2datetime(end_date).strftime("%Y-%m-%d"),
"maxRecords": self.max_records, "maxRecords": self.max_records,
"processingLevel": level "processingLevel": level
} }
...@@ -289,7 +291,7 @@ class TheiaDownloader: ...@@ -289,7 +291,7 @@ class TheiaDownloader:
Args: Args:
bbox_wgs84: bounding box (WGS84) bbox_wgs84: bounding box (WGS84)
acq_date: acquisition date to look around acq_date: acquisition date to look around (datetime.datetime or str)
download_dir: download directory (Default value = None) download_dir: download directory (Default value = None)
level: LEVEL2A, LEVEL3A, ... (Default value = "LEVEL3A") level: LEVEL2A, LEVEL3A, ... (Default value = "LEVEL3A")
...@@ -302,9 +304,10 @@ class TheiaDownloader: ...@@ -302,9 +304,10 @@ class TheiaDownloader:
ndays_seek = datetime.timedelta(days=17) # temporal range to check for monthly synthesis ndays_seek = datetime.timedelta(days=17) # temporal range to check for monthly synthesis
# Query products # Query products
dict_query = {'box': bbox2str(bbox_wgs84)} # lonmin, latmin, lonmax, latmax actual_date = dates.any2datetime(acq_date)
start_date = acq_date - ndays_seek dict_query = {'box': self._bbox2str(bbox_wgs84)} # lonmin, latmin, lonmax, latmax
end_date = acq_date + ndays_seek start_date = actual_date - ndays_seek
end_date = actual_date + ndays_seek
dict_query['startDate'] = start_date.strftime("%Y-%m-%d") dict_query['startDate'] = start_date.strftime("%Y-%m-%d")
dict_query['completionDate'] = end_date.strftime("%Y-%m-%d") dict_query['completionDate'] = end_date.strftime("%Y-%m-%d")
...@@ -320,7 +323,7 @@ class TheiaDownloader: ...@@ -320,7 +323,7 @@ class TheiaDownloader:
for description_date in search_results[tile_name]: for description_date in search_results[tile_name]:
print("\t" + description_date) print("\t" + description_date)
dt = datetime.datetime.strptime(description_date, "%Y-%m-%d") dt = datetime.datetime.strptime(description_date, "%Y-%m-%d")
delta = acq_date - dt delta = actual_date - dt
delta = delta.days delta = delta.days
search_results[tile_name][description_date]["delta"] = delta search_results[tile_name][description_date]["delta"] = delta
......
...@@ -313,3 +313,67 @@ class Sentinel23AScene(Sentinel2SceneBase): ...@@ -313,3 +313,67 @@ class Sentinel23AScene(Sentinel2SceneBase):
"FLG R2": self.flg_r2_msk_file, "FLG R2": self.flg_r2_msk_file,
}) })
return metadata return metadata
def get_scene(archive):
    """
    Return the right sentinel scene instance from the given archive (L2A or L3A)

    The product level is read from the third "_"-separated field of the
    archive base name, which must be made of exactly five such fields.
    A warning is printed when the archive is not recognized.

    Args:
        archive: L2A or L3A archive

    Returns:
        a Sentinel23AScene or Sentinel22AScene instance, or None when the
        archive is not a valid Sentinel-2 product

    """
    fields = utils.basename(archive).split("_")
    level = fields[2] if len(fields) == 5 else None
    if level == "L3A":
        return Sentinel23AScene(archive)
    if level == "L2A":
        return Sentinel22AScene(archive)
    print(f"Warning: file {archive} is not a valid Sentinel-2 product")
    return None
def get_local_scenes(root_dir, tile=None):
    """
    Retrieve the sentinel scenes in the directory

    All "*.zip" files found (case-insensitively) under `root_dir` and its
    sub-directories are tried with `get_scene()`; the ones that are not valid
    Sentinel-2 products are skipped.

    Args:
        root_dir: directory
        tile: tile name (optional) e.g. 31TEJ. When given, only the scenes whose
            archive file name carries exactly this value as its fourth
            "_"-separated field are kept.
            NOTE(review): confirm whether that field includes a leading "T"
            (e.g. "T31TEJ"), since the comparison is verbatim.

    Returns:
        a list of sentinel scenes instances

    """
    scenes_list = []
    archives = utils.find_files_in_all_subdirs(pth=root_dir, pattern="*.zip", case_sensitive=False)
    for archive in archives:
        candidate = get_scene(archive)
        if not candidate:
            continue
        # Split the base name only (as get_scene() does): splitting the whole
        # path would shift the field index whenever a directory name contains "_"
        tile_name = utils.basename(archive).split("_")[3]
        if not tile or tile_name == tile:
            scenes_list.append(candidate)
    return scenes_list
def get_downloaded_scenes(download_results):
    """
    Retrieve the sentinel scenes from the download results from the TheiaDownloader

    The results are walked as a two-level dict whose innermost values hold the
    path of the downloaded archive under the 'local_file' key. Archives that
    `get_scene()` does not recognize (it returns None for them) are left out,
    so that the returned list only contains scene instances.

    Args:
        download_results: dict as generated from the TheiaDownloader

    Returns:
        a list of sentinel scenes instances

    """
    scenes_list = []
    for products in download_results.values():
        for dic in products.values():
            scene = get_scene(dic['local_file'])
            # get_scene() returns None for invalid products: do not store it
            if scene is not None:
                scenes_list.append(scene)
    return scenes_list
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment