From e7544f154473526ddc7dd26dc267d9d28d50e58e Mon Sep 17 00:00:00 2001
From: David <david.dorchies@inrae.fr>
Date: Fri, 29 Mar 2024 14:03:47 +0100
Subject: [PATCH] feat(CreateInputsModel) add arg Qrelease

Refs #146
---
 R/CreateInputsModel.GRiwrm.R            | 35 +++++----------
 R/utils.CreateInputsModel.R             | 59 +++++++++++++++++++++++++
 tests/testthat/test-CreateInputsModel.R |  9 ++++
 3 files changed, 78 insertions(+), 25 deletions(-)
 create mode 100644 R/utils.CreateInputsModel.R

diff --git a/R/CreateInputsModel.GRiwrm.R b/R/CreateInputsModel.GRiwrm.R
index bf5c3b2..875dfc9 100644
--- a/R/CreateInputsModel.GRiwrm.R
+++ b/R/CreateInputsModel.GRiwrm.R
@@ -16,6 +16,9 @@
 #' @param Qmin (optional) [matrix] or [data.frame] of [numeric] containing
 #'        minimum flows to let downstream of a node with a Diversion \[m3 per
 #'        time step\]. Default is zero. Column names correspond to node IDs
+#' @param Qrelease (optional) [matrix] or [data.frame] of [numeric] containing
+#'        release flows by nodes using the model `RunModel_Reservoir` \[m3 per
+#'        time step\]
 #' @param PrecipScale (optional) named [vector] of [logical] indicating if the
 #'        mean of the precipitation interpolated on the elevation layers must be
 #'        kept or not, required to create CemaNeige module inputs, default `TRUE`
@@ -69,6 +72,7 @@ CreateInputsModel.GRiwrm <- function(x, DatesR,
                                      PotEvap = NULL,
                                      Qobs = NULL,
                                      Qmin = NULL,
+                                     Qrelease = NULL,
                                      PrecipScale = TRUE,
                                      TempMean = NULL, TempMin = NULL,
                                      TempMax = NULL, ZInputs = NULL,
@@ -120,31 +124,12 @@ CreateInputsModel.GRiwrm <- function(x, DatesR,
     }
   })
 
-  directFlowIds <- x$id[is.na(x$model) | x$model == "Diversion" | x$model == "RunModel_Reservoir"]
-  if (length(directFlowIds) > 0) {
-    err <- FALSE
-    if (is.null(Qobs)) {
-      err <- TRUE
-    } else {
-      Qobs <- as.matrix(Qobs)
-      if (is.null(colnames(Qobs))) {
-        err <- TRUE
-      } else if (!all(directFlowIds %in% colnames(Qobs))) {
-        err <- TRUE
-      }
-    }
-    if (err) stop(sprintf("'Qobs' column names must at least contain %s", paste(directFlowIds, collapse = ", ")))
-  }
-  if (!all(colnames(Qobs) %in% directFlowIds)) {
-    warning(
-      "The following columns in 'Qobs' are ignored since they don't match with ",
-      "Direction Injection (model=`NA`), ",
-      "Reservoir (model=\"RunModelReservoir\"), ",
-      "or Diversion nodes (model=\"Diversion\"): ",
-      paste(setdiff(colnames(Qobs), directFlowIds), collapse = ", ")
-    )
-    Qobs <- Qobs[, directFlowIds]
-  }
+  l <- updateQObsQrelease(g = x, Qobs = Qobs, Qrelease = Qrelease)
+  Qobs <- l$Qobs
+  Qrelease <- l$Qrelease
+  checkQobsQrelease(x, "Qobs", Qobs)
+  checkQobsQrelease(x, "Qrelease", Qrelease)
+
   diversionRows <- getDiversionRows(x)
   if (length(diversionRows) > 0) {
     warn <- FALSE
diff --git a/R/utils.CreateInputsModel.R b/R/utils.CreateInputsModel.R
new file mode 100644
index 0000000..f59d12e
--- /dev/null
+++ b/R/utils.CreateInputsModel.R
@@ -0,0 +1,59 @@
+updateQObsQrelease <- function(g, Qobs, Qrelease) {
+  reservoirIds <- g$id[!is.na(g$model) & g$model == "RunModel_Reservoir"]
+  # Fill Qrelease with Qobs
+  warn_ids <- NULL
+  for(id in reservoirIds) {
+    if (!id %in% names(Qrelease)) {
+      if (id %in% names(Qobs)) {
+        if (!any(g$id == id & (!is.na(g$model) & g$model == "Diversion"))) {
+          if (is.null(Qrelease)) {
+            Qrelease = Qobs[, id, drop = FALSE]
+          } else {
+            Qrelease = cbind(Qrelease, Qobs[, id, drop = FALSE])
+          }
+          Qobs <- Qobs[, names(Qobs) != id]
+          warn_ids = c(warn_ids, id)
+        }
+      }
+    }
+  }
+  if (!is.null(warn_ids)) {
+    warning("Use of the `Qobs` parameter for reservoir releases is depracated\n",
+            "`Qobs` for nodes ", paste(warn_ids, collapse = ", "), " are used as `Qrelease`")
+  }
+  return(list(Qobs = Qobs, Qrelease = Qrelease))
+}
+
+checkQobsQrelease <- function(g, varname, Q) {
+  if (varname == "Qobs") {
+    directFlowIds <- g$id[is.na(g$model) | g$model == "Diversion"]
+  } else {
+    directFlowIds <- g$id[!is.na(g$model) & g$model == "RunModel_Reservoir"]
+  }
+  if (length(directFlowIds) > 0) {
+    err <- FALSE
+    if (is.null(Q)) {
+      err <- TRUE
+    } else {
+      Q <- as.matrix(Q)
+      if (is.null(colnames(Q))) {
+        err <- TRUE
+      } else if (!all(directFlowIds %in% colnames(Q))) {
+        err <- TRUE
+      }
+    }
+    if (err) stop(sprintf("'%s' column names must at least contain %s", varname, paste(directFlowIds, collapse = ", ")))
+  }
+  if (!all(colnames(Q) %in% directFlowIds)) {
+    warning(
+      sprintf("The following columns in '%s' are ignored since they don't match with ", varname),
+      ifelse(varname == "Qobs",
+             c("Direction Injection (model=`NA`), ",
+               "or Diversion nodes (model=\"Diversion\"): "),
+             "Reservoir nodes (model=\"RunModelReservoir\"): "),
+      paste(setdiff(colnames(Q), directFlowIds), collapse = ", ")
+    )
+    Q <- Q[, directFlowIds]
+  }
+  return(Q)
+}
diff --git a/tests/testthat/test-CreateInputsModel.R b/tests/testthat/test-CreateInputsModel.R
index f8695a7..fa6acf9 100644
--- a/tests/testthat/test-CreateInputsModel.R
+++ b/tests/testthat/test-CreateInputsModel.R
@@ -272,3 +272,12 @@ test_that("Node with upstream nodes having area = NA should return correct Basin
   expect_equal(sum(InputsModel$`54001`$BasinAreas),
                g$area[g$id == "54001"])
 })
+
+test_that("Use of Qobs for Qrelease should raise a warning",  {
+  g <- CreateGRiwrm(n_rsrvr)
+  e <- setupRunModel(griwrm = g, runInputsModel = FALSE)
+  for(x in ls(e)) assign(x, get(x, e))
+  expect_warning(CreateInputsModel(griwrm, DatesR, Precip, PotEvap,
+                                   TempMean = TempMean,
+                                   Qobs = Qobs_rsrvr))
+})
-- 
GitLab