From da9ee9268fd90c9b45a15170117629960a18cb1e Mon Sep 17 00:00:00 2001
From: "raffaele.gaetano" <raffaele.gaetano@cirad.fr>
Date: Tue, 26 May 2020 11:52:58 +0200
Subject: [PATCH] ENH: some vector processing functions added.

---
 mtdUtils.py | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 80 insertions(+), 3 deletions(-)

diff --git a/mtdUtils.py b/mtdUtils.py
index ef13f96..f53f728 100644
--- a/mtdUtils.py
+++ b/mtdUtils.py
@@ -294,7 +294,7 @@ def getRasterExtentAsShapefile(ras, shp):
     shp_ly.CreateFeature(shp_feat)
     return 0
 
-def cloneVectorDataStructure(ds_in, fname, ly = 0, epsg = None):
+def cloneVectorDataStructure(ds_in, fname, ly = 0, epsg = None, drv_name='ESRI Shapefile'):
     ds_in_ly = ds_in.GetLayer(ly)
 
     if epsg is None:
@@ -303,7 +303,7 @@ def cloneVectorDataStructure(ds_in, fname, ly = 0, epsg = None):
         srs = osr.SpatialReference()
         srs.ImportFromEPSG(epsg)
 
-    drv = ogr.GetDriverByName('ESRI Shapefile')
+    drv = ogr.GetDriverByName(drv_name)
     ds_out = drv.CreateDataSource(fname)
     ds_out_ly = ds_out.CreateLayer(os.path.splitext(os.path.basename(fname))[0],
                                    srs=srs,
@@ -943,4 +943,81 @@ def getBufferedGeographicExtent(img,buf = 0, toFile = None):
         ds = None
 
 
-    return ul.GetX(), ul.GetY(), lr.GetX(), lr.GetY()
\ No newline at end of file
+    return ul.GetX(), ul.GetY(), lr.GetX(), lr.GetY()
+
+
+def assign_class_id(fname,existing_class_fld,class_ids,class_fld='ID_CLS'):
+    ds = ogr.Open(fname,0)
+    ly = ds.GetLayer()
+    fc = []
+    for f in ly:
+        fc.append(f.GetField(existing_class_fld))
+
+    lyd = ly.GetLayerDefn()
+    ln = [lyd.GetFieldDefn(i).GetName() for i in range(lyd.GetFieldCount())]
+
+    ly = None
+    ds = None
+
+    clset = sorted(set(fc))
+    assert(len(clset)==len(class_ids))
+    cldict = {x[0]:x[1] for x in zip(clset,class_ids)}
+
+    ty = None
+    if isinstance(class_ids[0], int):
+        ty = ogr.OFTInteger
+    elif isinstance(class_ids[0], str):
+        ty = ogr.OFTString
+
+    assert(ty is not None)
+
+    ds = ogr.Open(fname,1)
+    ly = ds.GetLayer()
+
+    if class_fld not in ln:
+        ly.CreateField(ogr.FieldDefn(class_fld, ty))
+
+    for q in ly:
+        k = q.GetField(existing_class_fld)
+        n = cldict[k]
+        q.SetField(class_fld,n)
+        ly.SetFeature(q)
+
+    ds = None
+
+    return cldict
+
+def per_class_shapefiles(fname,class_fld):
+
+    def toalnum(s):
+        return ''.join(c if c.isalnum() else '_' for c in s)
+
+    ds = ogr.Open(fname,0)
+    ly = ds.GetLayer()
+    fc = []
+    for f in ly:
+        fc.append(f.GetField(class_fld))
+
+    clset = sorted(set(fc))
+
+    basename = os.path.splitext(os.path.basename(fname))[0]
+    fld = os.path.dirname(fname) + os.sep + basename + '_per_class'
+    flist = [fld + os.sep + basename + '_' + toalnum(x) + '.sqlite' for x in clset]
+
+    ds_out = []
+    if not os.path.exists(fld):
+        os.mkdir(fld)
+    for fn in flist:
+        ds_out.append(cloneVectorDataStructure(ds, fn, drv_name='SQLite'))
+
+    dsdict = {x[0]: x[1] for x in zip(clset, ds_out)}
+
+    ly.ResetReading()
+    for f in ly:
+        dsdict[f.GetField(class_fld)].GetLayer().CreateFeature(f)
+
+    ds = None
+    for xds in ds_out:
+        xds = None
+
+    return
\ No newline at end of file
-- 
GitLab