diff --git a/classificationWorkflow.py b/classificationWorkflow.py
index bd74d10731a580facabaa4111736c2c738851147..8e41e1a454660028d1531f84104173a2235e6cc4 100644
--- a/classificationWorkflow.py
+++ b/classificationWorkflow.py
@@ -164,18 +164,38 @@ def baseDeepTraining(shp,code,flds, params, model_file, out_scaler, csv_classes,
         model.add(tf.keras.layers.Dense(nb_class, activation="softmax"))
 
         model.compile(loss="sparse_categorical_crossentropy",
-                      optimizer="sgd",
+                      optimizer="Adam",
                       metrics=["accuracy"])
 
-        history = model.fit(cfeats, ctargets, epochs=epochs)
-        model.save(model_file)
-        joblib.dump(scaler,out_scaler)
+        history = model.fit(cfeats, ctargets, epochs=epochs)
+        model.save(model_file)
+        joblib.dump(scaler,out_scaler)
+
+    # elif params[1].lower() == 'm3fusion':
+    #     from tensorflow.keras.callbacks import CSVLogger
+    #     weights_path = model_file.replace('.','_weights.')
+    #     log_path = model_file.replace('.h5','.log')
+    #     log = CSVLogger(log_path, separator=',', append=False)
+    #     n_timestamps,ts_size = TFmodels.retrieveTSinfo(os.path.join(os.path.dirname(shp),'field_names.csv'))
+    #     patch_size = 25
+    #     n_bands = 4
+
+    #     model = TFmodels.M3Fusion(nb_class,n_timestamps,ts_size,patch_size,n_bands)
+    #     model.compile(optimizer="Adam",
+    #                   loss="sparse_categorical_crossentropy",
+    #                   loss_weights=[1,0.3,0.3],
+    #                   metrics=['accuracy']
+    #                   )
+
+    #     model.fit([x_train_rnn,x_train_cnn],[labels,labels,labels], epochs=200, batch_size=128, validation_data=([x_valid_rnn,x_valid_cnn],[labels_val,labels_val,labels_val]), callbacks=[log])
+    #     model.save(model_file)
+    #     model.save_weights(weights_path)
+
     with open(csv_classes,'w') as file:
         writer = csv.writer(file)
         for i,c in enumerate(classes):
             writer.writerow([i,c])
-
 def baseDeepClassify(shp, model_file, code,flds, out_file, compute_confidence=False):
     import geopandas
     import tensorflow as tf
@@ -209,7 +229,7 @@ def baseDeepClassify(shp, model_file, code,flds, out_file, compute_confidence=Fa
         out_file = shp
     ds.to_file(out_file)
 
-def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
+def deepTraining(data,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
     import geopandas
     import tensorflow as tf
     from sklearn.preprocessing import StandardScaler
@@ -218,26 +238,29 @@ def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
 
     if params[1] == 'standard':
         ext = "tfDense"
+    elif params[1].lower() == 'm3fusion':
+        ext = "M3F"
+        if len(data) not in [3,6]:
+            sys.exit("Wrong number of arguments for train/valid dataset, must be 3 or 6")
 
     model_file = os.path.join(model_fld, ext + '_' + code + '.h5')
     scaler_file = os.path.join(model_fld, ext + '_scaler_' + code + '.joblib')
+    scaler_files = []
     csv_classes = os.path.join(model_fld, ext + '_class_'+ code + '.csv')
 
-    #Read training shp with geopandas
-    ds = geopandas.read_file(shp)
-    #Extract feats and
-    feats = ds.filter(feat,axis=1)
-    targets = ds[code]
-    nb_class = len(targets.unique())
-    classes = np.array(sorted(targets.unique()))
-    nb_class = len(targets.unique())
-    ctargets = np.array([np.where(classes==i)[0][0] for i in targets])
-
-    scaler = StandardScaler()
-    scaler.fit(feats)
-    cfeats = scaler.transform(feats)
-
     if params[1] == 'standard':
+        #Read training shp with geopandas
+        ds = geopandas.read_file(data)
+        #Extract feats and targets
+        feats = ds.filter(feat,axis=1)
+        targets = ds[code]
+        nb_class = len(targets.unique())
+        classes = np.array(sorted(targets.unique()))
+        ctargets = np.array([np.where(classes==i)[0][0] for i in targets])
+
+        scaler = StandardScaler()
+        scaler.fit(feats)
+        cfeats = scaler.transform(feats)
 
         # Model init
         model = tf.keras.models.Sequential()
@@ -251,15 +274,70 @@ def deepTraining(shp,code,model_fld,params,feat,feat_mode = 'list', epochs=20):
                       optimizer="sgd",
                       metrics=["accuracy"])
 
-        history = model.fit(cfeats, ctargets, epochs=epochs)
-        model.save(model_file)
-        joblib.dump(scaler,scaler_file)
+        history = model.fit(cfeats, ctargets, epochs=epochs)
+        model.save(model_file)
+        joblib.dump(scaler,scaler_file)
+        scaler_files.append(scaler_file)
+    elif params[1].lower() == 'm3fusion':
+        from tensorflow.keras.callbacks import CSVLogger
+        import TFmodels
+        # weights_path = model_file.replace('.','_weights.')
+        log_path = model_file.replace('.h5','.log')
+        log = CSVLogger(log_path, separator=',', append=False)
+        ts_size, n_timestamps = TFmodels.retrieveTSinfo(feat)
+        print(ts_size, n_timestamps)
+        patch_size = 25
+        n_bands = 4
+
+        x_train_rnn = np.load(data[0])
+        scaler_rnn = StandardScaler()
+        scaler_rnn.fit(x_train_rnn)
+        x_train_rnn = scaler_rnn.transform(x_train_rnn)
+        x_train_rnn = np.reshape(x_train_rnn, (x_train_rnn.shape[0], n_timestamps, -1))
+
+        x_train_cnn = np.load(data[1])
+        scaler_cnn = StandardScaler()
+        scaler_cnn.fit(x_train_cnn.reshape(-1,patch_size*patch_size*n_bands))
+        x_train_cnn = scaler_cnn.transform(x_train_cnn.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+
+        labels = np.load(data[2])
+        nb_class = len(np.unique(labels))
+        classes = np.array(sorted(np.unique(labels)))
+        clabels = np.array([np.where(classes==i)[0][0] for i in labels])
+
+        model = TFmodels.M3FusionModel(nb_class,n_timestamps,ts_size,patch_size,n_bands)
+        opt = tf.keras.optimizers.Adam(learning_rate=0.0002)
+        model.compile(optimizer=opt,
+                      loss="sparse_categorical_crossentropy",
+                      loss_weights=[1,0.3,0.3],
+                      metrics=['accuracy']
+                      )
+
+        if len(data) == 3:
+            model.fit([x_train_rnn,x_train_cnn],[clabels,clabels,clabels], epochs=epochs, batch_size=128, callbacks=[log])
+        elif len(data) == 6:
+            x_valid_rnn = np.load(data[3])
+            x_valid_rnn = scaler_rnn.transform(x_valid_rnn)
+            x_valid_rnn = np.reshape(x_valid_rnn, (x_valid_rnn.shape[0], n_timestamps, -1))
+            x_valid_cnn = np.load(data[4])
+            x_valid_cnn = scaler_cnn.transform(x_valid_cnn.reshape(-1,patch_size*patch_size*n_bands)).reshape(-1,patch_size,patch_size,n_bands)
+            labels_val = np.load(data[5])
+            if nb_class != len(np.unique(labels_val)):
+                sys.exit("Missing classes in validation dataset: {}".format(data[5]))
+            clabels_val = np.array([np.where(classes==i)[0][0] for i in labels_val])
+            model.fit(x=[x_train_rnn,x_train_cnn],y=[clabels,clabels,clabels], epochs=epochs, batch_size=128, validation_data=([x_valid_rnn,x_valid_cnn],[clabels_val,clabels_val,clabels_val]), callbacks=[log])
+        # model.save(model_file, save_format="tf")
+        model.save_weights(model_file)
+        joblib.dump(scaler_rnn,scaler_file.replace(".joblib","_rnn.joblib"))
+        joblib.dump(scaler_cnn,scaler_file.replace(".joblib","_cnn.joblib"))
+        scaler_files.append(scaler_file.replace(".joblib","_rnn.joblib"))
+        scaler_files.append(scaler_file.replace(".joblib","_cnn.joblib"))
 
     with open(csv_classes,'w') as file :
         writer = csv.writer(file)
         for i,c in enumerate(classes):
             writer.writerow([i,c])
 
-    return model_file, scaler_file, csv_classes
+    return model_file, scaler_files, csv_classes
 
 def deepClassify(shp_list,code,scaler,model_file, csv_classes,out_fld,out_ext,feat,feat_mode = 'list',Nproc=1,compute_confidence=False):
     import geopandas
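Note on the m3fusion branch above: TFmodels.M3FusionModel is not part of this diff. From the way it is used — two inputs [x_train_rnn, x_train_cnn], three identical target arrays, and loss_weights=[1,0.3,0.3] — it must be a two-branch network with three softmax heads (the fused prediction plus one auxiliary head per branch, as in the M3Fusion architecture). A minimal sketch of that contract; the layer choices here are illustrative assumptions, not the TFmodels implementation:

    import tensorflow as tf

    def m3fusion_sketch(nb_class, n_timestamps, ts_size, patch_size, n_bands):
        # time-series branch (RNN) and patch branch (CNN)
        rnn_in = tf.keras.Input(shape=(n_timestamps, ts_size))
        cnn_in = tf.keras.Input(shape=(patch_size, patch_size, n_bands))
        rnn_feat = tf.keras.layers.GRU(64)(rnn_in)
        cnn_feat = tf.keras.layers.GlobalAveragePooling2D()(
            tf.keras.layers.Conv2D(64, 3, activation="relu")(cnn_in))
        fused = tf.keras.layers.concatenate([rnn_feat, cnn_feat])
        # three heads, matching y=[clabels,clabels,clabels] and loss_weights=[1,0.3,0.3]
        out_fused = tf.keras.layers.Dense(nb_class, activation="softmax")(fused)
        out_rnn = tf.keras.layers.Dense(nb_class, activation="softmax")(rnn_feat)
        out_cnn = tf.keras.layers.Dense(nb_class, activation="softmax")(cnn_feat)
        return tf.keras.Model([rnn_in, cnn_in], [out_fused, out_rnn, out_cnn])

The auxiliary heads only shape training; at inference the fused output carries the prediction.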
diff --git a/launchChain.py b/launchChain.py
index a026b833c12e69928b69b168ce2ec2c112bb0017..ed8a0cddb365afba392baf95f2e990da4d38824b 100644
--- a/launchChain.py
+++ b/launchChain.py
@@ -555,103 +555,106 @@ def main(argv):
     feat_groups_str = '::'.join([','.join(x) for x in feat_groups])
 
     # generating GT samples and merging tiles
-    if os.path.exists(samples_fld + '/GT_samples.shp'):
-        warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
-    else:
-        gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
-        ogttl = []
-        ext_flds = getFieldNames(gttl[0])
+    params = config.get('TRAINING CONFIGURATION', 'parameters').split(' ')
+    if '-classifier' in params or ('-tensorflow' in params and 'standard' in params) :
+        if os.path.exists(samples_fld + '/GT_samples.shp'):
+            warnings.warn('File ' + samples_fld + '/GT_samples.shp' + ' already exists, skipping zonal stats on reference.')
+        else:
+            gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
+            ogttl = []
 
-        # BEGIN ** GT_SAMPLES **
-        cmd_list = []
+            ext_flds = getFieldNames(gttl[0])
 
-        for tl in gttl:
-            tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-            for fn in tocp:
-                shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
-            otl = samples_fld + '/' + os.path.basename(tl)
-            ogttl.append(otl)
-            cmd_list.append(
-                ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                 '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+            # BEGIN ** GT_SAMPLES **
+            cmd_list = []
 
-        queuedProcess(cmd_list, N_processes=1, shell=sh)
-        # END ** GT_SAMPLES **
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, samples_fld + '/' + os.path.basename(fn))
+                otl = samples_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-        var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
+            queuedProcess(cmd_list, N_processes=1, shell=sh)
+            # END ** GT_SAMPLES **
 
-        with open(samples_fld + '/field_names.csv', 'w') as varfile:
-            vw = csv.writer(varfile)
-            for x in zip(var_keys, feat_to_vars):
-                vw.writerow(x)
+            var_keys = [x for x in getFieldNames(ogttl[0]) if x not in ext_flds]
 
-        cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+            with open(samples_fld + '/field_names.csv', 'w') as varfile:
+                vw = csv.writer(varfile)
+                for x in zip(var_keys, feat_to_vars):
+                    vw.writerow(x)
 
-        cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
-        subprocess.call(cmd, shell=sh)
+            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-        for k in range(1, len(ogttl)):
-            cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
+            cmd = cmd_base + [samples_fld + '/GT_samples.shp', ogttl[0]]
             subprocess.call(cmd, shell=sh)
 
-        for otl in ogttl:
-            shpd.DeleteDataSource(otl)
-            os.remove(otl.replace('.shp','.tif'))
+            for k in range(1, len(ogttl)):
+                cmd = cmd_base + ['-update', '-append', samples_fld + '/GT_samples.shp', ogttl[k], '-nln', 'GT_samples']
+                subprocess.call(cmd, shell=sh)
 
-    if ch_mode <= 0 or ch_mode == 2:
-        # generating VAL samples and merging tiles
-        gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
-        ogttl = []
+            for otl in ogttl:
+                shpd.DeleteDataSource(otl)
+                os.remove(otl.replace('.shp','.tif'))
 
-        # BEGIN ** VAL_SAMPLES **
-        cmd_list = []
+        if ch_mode <= 0 or ch_mode == 2:
+            # generating VAL samples and merging tiles
+            gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
+            ogttl = []
 
-        for tl in gttl:
-            tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-            for fn in tocp:
-                shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
-            otl = val_fld + '/' + os.path.basename(tl)
-            ogttl.append(otl)
-            cmd_list.append(
-                ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                 '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+            # BEGIN ** VAL_SAMPLES **
+            cmd_list = []
 
-        queuedProcess(cmd_list, N_processes=1, shell=sh)
-        # END ** VAL_SAMPLES **
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, val_fld + '/' + os.path.basename(fn))
+                otl = val_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
 
-        cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
+            queuedProcess(cmd_list, N_processes=1, shell=sh)
+            # END ** VAL_SAMPLES **
 
-        cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
-        subprocess.call(cmd, shell=sh)
+            cmd_base = ['ogr2ogr', '-f', 'ESRI Shapefile']
 
-        for k in range(1, len(ogttl)):
-            cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
+            cmd = cmd_base + [val_fld + '/VAL_samples.shp', ogttl[0]]
             subprocess.call(cmd, shell=sh)
 
-        for otl in ogttl:
-            shpd.DeleteDataSource(otl)
-            os.remove(otl.replace('.shp', '.tif'))
-
-    if ch_mode > 0 or force_sample_generation is True:
-        # generating map samples (TBD)
-        gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
-        ogttl = []
-
-        # BEGIN ** SAMPLES **
-        cmd_list = []
-
-        for tl in gttl:
-            tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
-            for fn in tocp:
-                shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
-            otl = test_fld + '/' + os.path.basename(tl)
-            ogttl.append(otl)
-            cmd_list.append(
-                ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
-                 '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
-
-        queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
+            for k in range(1, len(ogttl)):
+                cmd = cmd_base + ['-update', '-append', val_fld + '/VAL_samples.shp', ogttl[k], '-nln', 'VAL_samples']
+                subprocess.call(cmd, shell=sh)
+
+            for otl in ogttl:
+                shpd.DeleteDataSource(otl)
+                os.remove(otl.replace('.shp', '.tif'))
+
+        if ch_mode > 0 or force_sample_generation is True:
+            # generating map samples (TBD)
+            gttl = sorted(glob.glob(seg_fld + '/segmentation_*.shp'))
+            ogttl = []
+
+            # BEGIN ** SAMPLES **
+            cmd_list = []
+
+            for tl in gttl:
+                tocp = glob.glob(os.path.splitext(tl)[0] + '.*')
+                for fn in tocp:
+                    shutil.copyfile(fn, test_fld + '/' + os.path.basename(fn))
+                otl = test_fld + '/' + os.path.basename(tl)
+                ogttl.append(otl)
+                cmd_list.append(
+                    ['python3', 'mrzonalstats.py', '--series-prefix', 'T', '--raster-prefix', 'F', '--band-prefix', 'D',
+                     '--fix-nodata', '--overwrite', feat_groups_str, otl, "Segment_ID", 'mean'])
+
+            queuedProcess(cmd_list, N_processes=N_proc, shell=sh)
     # END ** SAMPLES **
 
     if single_step:
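For reference, the gate added above reads the usual Moringa parameter string from the configuration file; the same list is re-read further down to dispatch training. An illustrative (not repo-verified) config value and the resulting list:

    # [TRAINING CONFIGURATION]
    # parameters = -tensorflow m3fusion 200
    params = config.get('TRAINING CONFIGURATION', 'parameters').split(' ')
    # params[0] -> '-classifier' or '-tensorflow'
    # params[1] -> model type, e.g. 'standard' or 'm3fusion'
    # params[2] -> epoch count for m3fusion, read later with int(params[2])

Zonal-stats sample generation is only needed for the shapefile-based paths ('-classifier', or '-tensorflow standard'); m3fusion builds its own .npy datasets via Moringa2DL.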
@@ -722,6 +725,8 @@ def main(argv):
                     val_out_list = deepClassify(val_list, 'p' + cfield, scaler, mfn, csv_classes, mfld, '_' + cfield, var_list)
                     os.rename(val_out_list[0], val_out)
 
+                else :
+                    sys.exit('\n This classification method is not yet available for kfold treatment')
                 txt_out = mfld + '/' + os.path.basename(kfold_test_samples[i]).replace('.shp','_' + cfield + '_report.txt')
                 classes, cm, acc, kappa, prf = surfaceValidation(test_folds[i], val_out, val_out_check, cfield)
                 formatValidationTxt(classes, cm, acc, kappa, prf, txt_out)
@@ -739,8 +744,53 @@ def main(argv):
         if params[0] == "-classifier":
             training(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
         elif params[0] == '-tensorflow':
-            deepTraining(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
-
+            if params[1] == 'standard':
+                deepTraining(samples_fld + '/GT_samples.shp',cfield,model_fld,params,var_list)
+            elif params[1].lower() == 'm3fusion':
+                import Moringa2DL
+                epochs=int(params[2])
+                tf_data = Moringa2DL.generateAll(feat_fld+'/*',reference,seg_src,cfield,ofld=samples_fld)
+                # tf_data=['/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_rnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_cnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/GT_SAMPLES/train_code_1_label.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_rnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_cnn.npy','/media/peillet/DATA/Mada2_venus/Mada_S2_new_GT/VAL_SAMPLES/valid_code_1_label.npy']
+                # gttl = sorted(glob.glob(seg_fld + '/GT_segmentation_*.shp'))
+                # gt_rnn_feat = None
+                # gt_rnn_label = np.zeros(0)
+                # gt_cnn_feat = None
+                # gt_cnn_label = np.zeros(0)
+                # for gt in gttl :
+                #     rnn_feat, rnn_label = Moringa2DL.generateAnnotatedDataset(feat_fld, gt, cfield,sname=os.path.splitext(os.path.basename(gt))[0]+'_rnn_'+cfield,ofld=samples_fld)
+                #     if gt_rnn_feat is None :
+                #         gt_rnn_feat = np.zeros((0,rnn_feat.shape[1]))
+                #     np.concatenate(gt_rnn_feat,rnn_feat)
+                #     np.concatenate(gt_rnn_label,rnn_label)
+                #     cnn_feat, cnn_label = Moringa2DL.CNNAnnotedDataset(seg_src, gt, cfield, 25, sname=os.path.splitext(os.path.basename(gt))[0]+'_cnn_'+cfield, ofld=samples_fld)
+                #     if gt_cnn_feat is None :
+                #         gt_cnn_feat = np.zeros((0,cnn_feat.shape[1],cnn_feat.shape[2],cnn_feat.shape[3]))
+                #     np.concatenate(gt_cnn_feat,cnn_feat)
+                #     np.concatenate(gt_cnn_label,cnn_label)
+
+                # tf_data.append(gt_rnn_feat)
+                # tf_data.append(gt_rnn_label)
+                # tf_data.append(gt_cnn_feat)
+                # tf_data.append(gt_rnn_feat)
+
+                val_rnn_feat = None
+                val_rnn_label = None
+                if ch_mode == 0 or ch_mode == 2 :
+                    tf_valid_data = Moringa2DL.generateAll(feat_fld+'/*',validation,seg_src,cfield,ofld=val_fld,dstype='valid')
+                    tf_data += tf_valid_data
+                    # gttl = sorted(glob.glob(seg_fld + '/VAL_segmentation_*.shp'))
+                    # val_rnn_label = np.zeros(0)
+                    # for gt in gttl :
+                    #     rnn_feat, rnn_label = Moringa2DL.generateAnnotatedDataset(feat_fld, gt, cfield,sname=os.path.splitext(os.path.basename(gt))[0]+'_'+cfield,ofld=val_fld)
+                    #     if val_rnn_feat is None :
+                    #         val_rnn_feat = np.zeros((0,rnn_feat.shape[1]))
+                    #     np.concatenate(val_rnn_feat,rnn_feat)
+                    #     np.concatenate(val_rnn_label,rnn_label)
+                    # tf_data.append(val_rnn_feat)
+                    # tf_data.append(val_rnn_label)
+                deepTraining(tf_data,cfield,model_fld,params,var_list,epochs=epochs)
+            else:
+                sys.exit("Tensorflow parameter {} is not available".format(params[1]))
     if rfimp == True:
         for cfield in cfieldlist:
             getVariableImportance(samples_fld + '/GT_samples.shp',var_list,cfield,samples_fld + '/var_importance_' + cfield + '.csv', nbtrees=rfimpntrees, nodesize=rfimpnodesize, mtry=rfimpmaxfeat, nruns=rfimpnruns, field_names=samples_fld + '/field_names.csv')
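Putting the two files together, the deepTraining contract is now: 'standard' takes a shapefile path, 'm3fusion' takes a list of 3 (train only) or 6 (train + validation) .npy paths in rnn/cnn/label order, and the second return value is a list of scaler files rather than a single path. A hedged usage sketch with made-up file names (the real paths come from Moringa2DL.generateAll):

    tf_data = ['train_code_1_rnn.npy', 'train_code_1_cnn.npy', 'train_code_1_label.npy',
               'valid_code_1_rnn.npy', 'valid_code_1_cnn.npy', 'valid_code_1_label.npy']
    model_file, scaler_files, csv_classes = deepTraining(
        tf_data, 'code_1', model_fld, ['-tensorflow', 'm3fusion', '200'],
        var_list, epochs=200)
    # scaler_files -> [model_fld/M3F_scaler_code_1_rnn.joblib,
    #                  model_fld/M3F_scaler_code_1_cnn.joblib]
    # model_file now holds weights only (save_weights), not a full .h5 model

One loose end worth double-checking: in the m3fusion branch deepTraining forwards its feat argument straight to TFmodels.retrieveTSinfo, while the earlier commented-out draft read the time-series info from field_names.csv; launchChain still passes var_list, so retrieveTSinfo must accept that form or the call sites need reconciling.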