Commit bf7cb042 authored by Cresson Remi's avatar Cresson Remi
Browse files

Merge branch 'fix-tfrecords-dtype' into 'modifs'

Preserve image data type when converting to TFRecords

See merge request !31
parents 2076b5b6 60851101
3 merge requests!39Release 3.2,!31Preserve image data type when converting to TFRecords,!26TFRecord, to_tfrecords(), refac python modules, new CI
Pipeline #35863 passed with stages
in 127 minutes and 43 seconds
Showing with 9 additions and 3 deletions
+9 -3
......@@ -118,7 +118,7 @@ sr4rs:
- wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/kDms9JrRMQE2Q5z/download
- unzip -o sr4rs_data.zip
- rm -rf sr4rs
- git clone https://github.com/remicres/sr4rs.git
- git clone -b 44-cast_float_input https://github.com/remicres/sr4rs.git
- export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs
- python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py
......
......@@ -38,12 +38,13 @@ def gdal_open(filename):
return gdal_ds
def read_as_np_arr(gdal_ds, as_patches=True):
def read_as_np_arr(gdal_ds, as_patches=True, dtype=None):
"""
Read a GDAL raster as numpy array
:param gdal_ds: a GDAL dataset instance
:param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If
False, the shape is (1, psz_y, psz_x, nb_channels)
:param dtype: if not None array dtype will be cast to given numpy data type (np.float32, np.uint16...)
:return: Numpy array of dim 4
"""
buffer = gdal_ds.ReadAsArray()
......@@ -56,4 +57,9 @@ def read_as_np_arr(gdal_ds, as_patches=True):
else:
n_elems = 1
size_y = gdal_ds.RasterYSize
return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
buffer = buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
if dtype is not None:
buffer = buffer.astype(dtype)
return buffer
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment