From 63bd402b97a7b46c51c943b1afc3b5e8c3047335 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@inrae.fr>
Date: Thu, 23 Mar 2023 15:39:16 +0100
Subject: [PATCH] DOC: update + ENH: improve spot tests speed

---
 scenes/core.py              | 24 +++++++++++++
 scenes/spot.py              | 67 +++++++++++++++++--------------------
 test/spot67_imagery_test.py | 43 ++++++++++++++++++------
 3 files changed, 88 insertions(+), 46 deletions(-)

diff --git a/scenes/core.py b/scenes/core.py
index 8ae96d1..4dbdfca 100644
--- a/scenes/core.py
+++ b/scenes/core.py
@@ -284,6 +284,30 @@ class CommonImagerySource(Source):
         })
         return self.new_source(ref_img, superimpose)
 
+    def subset(
+            self,
+            startx: int,
+            starty: int,
+            sizex: int,
+            sizey: int
+    ) -> CommonImagerySource:
+        """
+        Return a subset
+
+        Args:
+            startx: start x
+            starty: start y
+            sizex: size x
+            sizey: size y
+
+        Returns:
+            subset
+
+        """
+        return self.new_source(
+            self[startx:startx+sizex, starty:starty+sizey, :]
+        )
+
     def clip_over_img(
             self,
             ref_img: Union[str, pyotb.core.otbObject]
diff --git a/scenes/spot.py b/scenes/spot.py
index 465c46b..6d2792a 100644
--- a/scenes/spot.py
+++ b/scenes/spot.py
@@ -139,15 +139,15 @@ rgb_nice.write("image.tif", pixel_type="uint8")
 
 """
 from __future__ import annotations
-from typing import List, Dict, Type, Union
 
-from datetime import datetime
 import os
 import re
 import xml.etree.ElementTree as ET
+from datetime import datetime
 from functools import partial
-import requests
+from typing import List, Dict, Type, Union
 from tqdm.autonotebook import tqdm
+import requests
 import pyotb
 
 from scenes import dates
@@ -436,39 +436,34 @@ def spot67_metadata_parser(xml_path: str) -> Dict[str, str]:
     else:
         root = ET.parse(xml_path).getroot()
 
-    metadata = {}
-
-    # Acquisition angles
-    c_nodes = root.find(
-        "Geometric_Data/Use_Area/Located_Geometric_Values/Acquisition_Angles")
-    scalars_mapping = {
-        "AZIMUTH_ANGLE": "azimuth_angle",
-        "VIEWING_ANGLE_ACROSS_TRACK": "viewing_angle_across",
-        "VIEWING_ANGLE_ALONG_TRACK": "viewing_angle_along",
-        "VIEWING_ANGLE": "viewing_angle",
-        "INCIDENCE_ANGLE": "incidence_angle",
+    scalars_mappings = {
+        "Geometric_Data/Use_Area/Located_Geometric_Values/Acquisition_Angles":
+            {
+                "AZIMUTH_ANGLE": "azimuth_angle",
+                "VIEWING_ANGLE_ACROSS_TRACK": "viewing_angle_across",
+                "VIEWING_ANGLE_ALONG_TRACK": "viewing_angle_along",
+                "VIEWING_ANGLE": "viewing_angle",
+                "INCIDENCE_ANGLE": "incidence_angle",
+            },
+        "Geometric_Data/Use_Area/Located_Geometric_Values":
+            {
+                "TIME": "acquisition_date"
+            },
+        "Geometric_Data/Use_Area/Located_Geometric_Values/Solar_Incidences":
+            {
+                "SUN_AZIMUTH": "sun_azimuth",
+                "SUN_ELEVATION": "sun_elevation"
+            }
     }
-    for node in c_nodes:
-        key = node.tag
-        if key in scalars_mapping:
-            new_key = scalars_mapping[key]
-            metadata.update({new_key: float(node.text)})
-
-    # Acquisition date
-    c_nodes = root.find("Geometric_Data/Use_Area/Located_Geometric_Values")
-    for node in c_nodes:
-        if node.tag == "TIME":
-            metadata.update({"acquisition_date": node.text[0:10]})
-            break
-
-    # Sun angles
-    c_nodes = root.find(
-        "Geometric_Data/Use_Area/Located_Geometric_Values/Solar_Incidences")
-    for node in c_nodes:
-        if node.tag == "SUN_AZIMUTH":
-            metadata.update({"sun_azimuth": float(node.text)})
-        elif node.tag == "SUN_ELEVATION":
-            metadata.update({"sun_elevation": float(node.text)})
+
+    metadata = {}
+    for section, scalars_mapping in scalars_mappings.items():
+        for node in root.find(section):
+            key = node.tag
+            if key in scalars_mapping:
+                new_key = scalars_mapping[key]
+                text = node.text
+                metadata[new_key] = float(text) if text.isdigit() else text
 
     return metadata
 
@@ -556,7 +551,7 @@ class Spot67DRSScene(Spot67Scene):
         assert "dimap_xs" in assets_paths, "XS DIMAP XML document is missing"
         additional_md = spot67_metadata_parser(assets_paths["dimap_xs"])
         acquisition_date = datetime.strptime(
-            additional_md["acquisition_date"], "%Y-%m-%d"
+            additional_md["acquisition_date"], "%Y-%m-%dT%H:%M:%S.%fZ"
         )
 
         # Call parent constructor, before accessing to self.assets_paths
diff --git a/test/spot67_imagery_test.py b/test/spot67_imagery_test.py
index e6828fc..6d1afbf 100644
--- a/test/spot67_imagery_test.py
+++ b/test/spot67_imagery_test.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from functools import singledispatch
+from typing import Any
 import pyotb
 
 import scenes
@@ -26,6 +28,28 @@ pan = sc.get_pan()
 pxs = sc.get_pxs()
 
 
+@singledispatch
+def subset(inp: Any) -> Any:
+    raise TypeError(
+        "Invalid type, must be one of: pyotb.otbObject, Source"
+    )
+
+
+@subset.register
+def subset_source(inp: scenes.core.Source):
+    return inp.subset(
+        startx=roi[0],
+        sizex=roi[0] + roi[2],
+        starty=roi[1],
+        sizey=roi[1] + roi[3]
+    )
+
+
+@subset.register
+def subset_pyotb(inp: pyotb.otbObject):
+    return inp[roi[0]:roi[0] + roi[2], roi[1]:roi[1] + roi[3], :]
+
+
 class Spot67ImageryTest(ScenesTestBase):
 
     def test_instanciate_sc(self):
@@ -73,9 +97,6 @@ class Spot67ImageryTest(ScenesTestBase):
             roi=roi
         )
 
-    def slice(self, inp):
-        return inp[2048:512, 2048:512, :]
-
     def test_pxs_dn_msk_drilled_cached(self):
         """
         Dummy test since no cloud in the input scene.
@@ -84,9 +105,8 @@ class Spot67ImageryTest(ScenesTestBase):
         """
         for _ in range(2):
             self.compare_images(
-                image=pxs.cld_msk_drilled().cached(),
-                reference=pxs_dn_ref,
-                roi=roi
+                image=subset(pxs).cld_msk_drilled().cached(),
+                reference=subset(pxs_dn_ref)
             )
 
     def test_pxs_toa_msk_drilled_cached(self):
@@ -97,9 +117,8 @@ class Spot67ImageryTest(ScenesTestBase):
         """
         for _ in range(2):
             self.compare_images(
-                image=pxs.reflectance().cld_msk_drilled().cached(),
-                reference=pxs_toa_ref,
-                roi=roi
+                image=subset(pxs).reflectance().cld_msk_drilled().cached(),
+                reference=subset(pxs_toa_ref),
             )
 
     def test_print(self):
@@ -108,6 +127,10 @@ class Spot67ImageryTest(ScenesTestBase):
         print(pan)
         print(pxs)
 
+
 if __name__ == '__main__':
-    #unittest.main()
+    # unittest.main()
     Spot67ImageryTest().test_pxs_toa_msk_drilled_cached()
+    Spot67ImageryTest().test_pxs_dn_msk_drilled_cached()
+    Spot67ImageryTest().test_pxs_toa_msk_drilled()
+    Spot67ImageryTest().test_pxs_dn_msk_drilled()
-- 
GitLab