diff --git a/ymraster/classification.py b/ymraster/classification.py index 00a73a0b18151aa328302668b5d46be14284c5b9..bb2d99c209034648f9a62ed95459e152b7be2392 100644 --- a/ymraster/classification.py +++ b/ymraster/classification.py @@ -79,8 +79,8 @@ def get_samples_from_roi(in_rst_label,in_rst_roi,in_rst_stat ): col = LABEL.shape[1] #Get the index of each sample in the non-flattened original array indices = [np.empty(nb_samp),np.empty(nb_samp)] - indices[0] = [l_ind // col]#the rows - indices[1] = [l_ind % col]#the columns + indices[0] = [l_ind // col] #the rows + indices[1] = [l_ind % col] #the columns ##set the Y array, ie taking the classes values of each sample Y = ROI[indices].reshape((nb_samp,1)) @@ -237,7 +237,7 @@ def decision_tree(X_train, Y_train, X_test, X_img, reverse_array, raster, meta['count'] = None write_file(out_filename, overwrite=True, array=classif, **meta) - return y_predict + return y_predict, clf def pred_error_metrics(Y_predict, Y_test, target_names = None): """This function calcul the main classification metrics and compute and