InitialConditionsDict.py 2.36 KiB
# InitialConditionsDict.py -- Pamhyr
# Copyright (C) 2023  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import types
import logging

from copy import copy
from tools import trace, timer

from Model.Tools.PamhyrDict import PamhyrModelDict
from Model.InitialConditions.InitialConditions import InitialConditions

# Module-level logger. No name is given, so this is the root logger.
logger = logging.getLogger(name=None)


class InitialConditionsDict(PamhyrModelDict):
    """Dictionary of :class:`InitialConditions`, keyed by river reach.

    Stored values may be generators. Such an entry is replaced in the
    dictionary by the first item it yields the first time it is read
    through :meth:`get` or saved through :meth:`_db_save`.
    """

    _sub_classes = [
        InitialConditions,
    ]

    @classmethod
    def _db_load(cls, execute, data=None):
        """Build a dictionary from the database, one entry per reach.

        :param execute: Query callable forwarded to
            ``InitialConditions._db_load``.
        :param data: Load context. ``data["status"]`` is given to the new
            dictionary and ``data["edges"]`` lists the reaches to load.
            ``data["reach"]`` is overwritten with each reach in turn and is
            left set to the last one.
        :return: The new dictionary. Reaches for which
            ``InitialConditions._db_load`` returns ``None`` get no entry.
        """
        new = cls(status=data["status"])

        for reach in data["edges"]:
            # Presumably read by ``InitialConditions._db_load`` to know
            # which reach to load -- verify against that loader.
            data["reach"] = reach

            ic = InitialConditions._db_load(
                execute,
                data=data
            )

            if ic is not None:
                new._dict[reach] = ic

        return new

    def _db_save(self, execute, data=None):
        """Save the initial conditions of every reach.

        For each reach, the ``initial_conditions`` rows of that reach are
        deleted before the entry's own ``_db_save`` is called.

        :param execute: Callable that runs an SQL statement.
        :param data: Save context forwarded to each entry. ``data["reach"]``
            is set to the reach being saved. A new dict is used when None.
        :return: The ``&``-combination of every entry's ``_db_save`` result
            (True when there are no entries).
        """
        ok = True
        if data is None:
            data = {}

        # Iterate over a snapshot of the keys: lazy (generator) entries are
        # replaced in the dictionary while saving.
        for reach in list(self._dict):
            data["reach"] = reach
            ic = self._resolve_lazy_entry(reach)

            # NOTE(review): ``reach.id`` is interpolated into the SQL text;
            # presumably an internal identifier, but a parameterized query
            # would be safer if ``execute`` supports one.
            execute(
                "DELETE FROM initial_conditions " +
                f"WHERE reach = '{reach.id}'"
            )

            ok &= ic._db_save(execute, data)

        return ok

    def _resolve_lazy_entry(self, key):
        """Return the value stored under *key*, materializing a generator.

        A generator entry is fully consumed and replaced in the dictionary
        by its first item (an empty generator raises IndexError).
        """
        value = self._dict[key]
        if isinstance(value, types.GeneratorType):
            value = list(value)[0]
            self._dict[key] = value
        return value

    def get(self, key):
        """Return the initial conditions of reach *key*.

        A missing entry is created and stored through :meth:`new`.
        """
        if key in self._dict:
            return self._resolve_lazy_entry(key)

        # ``new`` already stores the fresh entry with ``self.set``; storing
        # it a second time here (as was done before) is redundant.
        return self.new(key)

    def new(self, reach):
        """Create, store (via ``self.set``) and return new initial
        conditions for *reach*, sharing this dictionary's status."""
        new = InitialConditions(reach=reach, status=self._status)
        self.set(reach, new)
        return new