fix repeated cv bug

parent f34d22f3
......@@ -21,6 +21,8 @@ def split_data (radar_seq,opt_seq,indices_seq,yields,n_folds=3,n_random=10) :
indices = np.load(indices_seq)
idx = np.arange(y.shape[0])
idx_random = np.arange(y.shape[0])
n_samples = int(y.shape[0]/n_folds)
for j in range(n_random) :
......@@ -28,10 +30,11 @@ def split_data (radar_seq,opt_seq,indices_seq,yields,n_folds=3,n_random=10) :
random.seed(dt.microsecond)
for i in range(n_folds):
if i==0:
random.shuffle(idx)
random.shuffle(idx)
test_samples = idx[i*n_samples:(i+1)*n_samples]
random.shuffle(idx_random)
random.shuffle(idx_random)
test_samples = idx_random[i*n_samples:(i+1)*n_samples]
test_idx = np.where(np.isin(idx,test_samples))
train_idx = np.where(np.isin(idx,test_samples,invert=True))
......@@ -53,41 +56,41 @@ def split_data (radar_seq,opt_seq,indices_seq,yields,n_folds=3,n_random=10) :
if __name__ == "__main__":
# Niakhar 2017
radar_seq = "./data/niakhar_2017_rad_seq.npy"
opt_seq = "./data/niakhar_2017_opt_seq.npy"
indices_seq = "./data/niakhar_2017_indices_seq.npy"
radar_seq = "./data/niakhar_2017_rad.npy"
opt_seq = "./data/niakhar_2017_opt.npy"
indices_seq = "./data/niakhar_2017_indices.npy"
yields = "./data/niakhar_2017_yields.npy"
split_data (radar_seq,opt_seq,indices_seq,yields)
# Niakhar 2018
radar_seq = "./data/niakhar_2018_rad_seq.npy"
opt_seq = "./data/niakhar_2018_opt_seq.npy"
indices_seq = "./data/niakhar_2018_indices_seq.npy"
yields = "./data/niakhar_2018_yields.npy"
# # Niakhar 2018
# radar_seq = "./data/niakhar_2018_rad.npy"
# opt_seq = "./data/niakhar_2018_opt.npy"
# indices_seq = "./data/niakhar_2018_indices.npy"
# yields = "./data/niakhar_2018_yields.npy"
split_data (radar_seq,opt_seq,indices_seq,yields)
# split_data (radar_seq,opt_seq,indices_seq,yields)
# Niakhar 2018 # SIMCO
radar_seq = "./data/niakhar-simco_2018_rad_seq.npy"
opt_seq = "./data/niakhar-simco_2018_opt_seq.npy"
indices_seq = "./data/niakhar-simco_2018_indices_seq.npy"
radar_seq = "./data/niakhar-simco_2018_rad.npy"
opt_seq = "./data/niakhar-simco_2018_opt.npy"
indices_seq = "./data/niakhar-simco_2018_indices.npy"
yields = "./data/niakhar-simco_2018_yields.npy"
split_data (radar_seq,opt_seq,indices_seq,yields)
# Niakhar 2018 # SERENA
radar_seq = "./data/niakhar-serena_2018_rad_seq.npy"
opt_seq = "./data/niakhar-serena_2018_opt_seq.npy"
indices_seq = "./data/niakhar-serena_2018_indices_seq.npy"
# Niakhar 2018 # SERENA
radar_seq = "./data/niakhar-serena_2018_rad.npy"
opt_seq = "./data/niakhar-serena_2018_opt.npy"
indices_seq = "./data/niakhar-serena_2018_indices.npy"
yields = "./data/niakhar-serena_2018_yields.npy"
split_data (radar_seq,opt_seq,indices_seq,yields)
# Nioro 2018
radar_seq = "./data/nioro_2018_rad_seq.npy"
opt_seq = "./data/nioro_2018_opt_seq.npy"
indices_seq = "./data/nioro_2018_indices_seq.npy"
yields = "./data/nioro_2018_yields.npy"
# # Nioro 2018
# radar_seq = "./data/nioro_2018_rad.npy"
# opt_seq = "./data/nioro_2018_opt.npy"
# indices_seq = "./data/nioro_2018_indices.npy"
# yields = "./data/nioro_2018_yields.npy"
split_data (radar_seq,opt_seq,indices_seq,yields)
\ No newline at end of file
# split_data (radar_seq,opt_seq,indices_seq,yields)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment