# Graph.py -- Pamhyr
# Copyright (C) 2023  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

from functools import reduce

from Model.Network.Node import Node
from Model.Network.Edge import Edge


class Graph(object):
    """A graph of nodes and edges that reports its own modifications.

    Nodes and edges are kept in insertion order in plain lists. New nodes
    and edges are built through ``self._node_ctor`` / ``self._edge_ctor``
    (``Node`` and ``Edge`` by default), so subclasses can substitute their
    own types. Every structural change calls ``modified()`` on the status
    object given at construction, when one was given.
    """

    def __init__(self, status=None):
        super(Graph, self).__init__()

        # Object notified (via .modified()) on every mutation; may be None.
        self._status = status

        # Factories used by the create/add helpers; subclasses may replace.
        self._node_ctor = Node
        self._edge_ctor = Edge

        self._nodes = []
        self._edges = []

    def __repr__(self):
        return f"Graph {{nodes: {self._nodes}, edges: {self._edges}}}"

    def _notify_modified(self):
        """Call ``modified()`` on the status object, if there is one.

        ``status`` defaults to None in ``__init__``; without this guard
        every mutator raised AttributeError on such a graph.
        """
        if self._status is not None:
            self._status.modified()

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def nodes(self):
        """Return the internal node list (not a copy)."""
        return self._nodes

    def nodes_names(self):
        """Return the ``name`` of every node, in insertion order."""
        return [n.name for n in self._nodes]

    def edges(self):
        """Return the internal edge list (not a copy)."""
        return self._edges

    def enable_edges(self):
        """Return a new list of the edges whose ``is_enable()`` is true."""
        return list(self._enable_edges())

    def _enable_edges(self):
        """Return a lazy generator over the edges whose ``is_enable()`` is true."""
        return (e for e in self._edges if e.is_enable())

    def edges_names(self):
        """Return the ``name`` of every edge, in insertion order."""
        return [e.name for e in self._edges]

    def nodes_counts(self):
        """Return the number of nodes."""
        return len(self._nodes)

    def edges_counts(self):
        """Return the number of edges, enabled or not."""
        return len(self._edges)

    def enable_edges_counts(self):
        """Return the number of edges whose ``is_enable()`` is true."""
        return sum(1 for e in self._edges if e.is_enable())

    def is_node_exists(self, node_name):
        """Return True if some node has ``name == node_name``."""
        return any(n.name == node_name for n in self._nodes)

    def is_edge_exists(self, edge_name):
        """Return True if some edge has ``name == edge_name``."""
        return any(e.name == edge_name for e in self._edges)

    def node(self, node_name: str):
        """Return the first node named *node_name*, or None if absent."""
        return next((n for n in self._nodes if n.name == node_name), None)

    def edge(self, edge_name: str):
        """Return the first edge named *edge_name*, or None if absent."""
        return next((e for e in self._edges if e.name == edge_name), None)

    # ------------------------------------------------------------------
    # Nodes
    # ------------------------------------------------------------------

    def _create_node(self, x: float, y: float):
        """Build (but do not register) a node at ``(x, y)``."""
        # id -1 and empty name: presumably the node class assigns the real
        # values itself — NOTE(review): confirm in the node constructor.
        return self._node_ctor(-1, "", x=x, y=y, status=self._status)

    def _add_node(self, node):
        """Append *node* to the graph, notify, and return it."""
        self._nodes.append(node)
        self._notify_modified()
        return node

    def add_node(self, x: float = 0.0, y: float = 0.0):
        """Create a node at ``(x, y)``, add it to the graph and return it."""
        node = self._create_node(x, y)
        return self._add_node(node)

    def insert_node(self, node):
        """Add an already-built *node* to the graph and return it."""
        return self._add_node(node)

    def remove_node(self, node_name: str):
        """Remove every node named *node_name* and every edge touching it."""
        self._nodes = [n for n in self._nodes if n.name != node_name]
        self._remove_associated_edge(node_name)
        self._notify_modified()

    def _remove_associated_edge(self, node_name: str):
        """Remove every edge whose ``node1`` or ``node2`` is named *node_name*."""
        # Snapshot the matches first: remove_edge() rebinds self._edges.
        edges = [
            e for e in self._edges
            if e.node1.name == node_name or e.node2.name == node_name
        ]
        # NOTE(review): remove_edge() removes by name, so any other edge
        # sharing one of these names is removed as well.
        for edge in edges:
            self.remove_edge(edge.name)

    def create_node(self, x: float = 0.0, y: float = 0.0):
        """Build a node at ``(x, y)`` without adding it to the graph."""
        return self._create_node(x, y)

    # ------------------------------------------------------------------
    # Edges
    # ------------------------------------------------------------------

    def _create_edge(self, n1: Node, n2: Node):
        """Build (but do not register) an edge from *n1* to *n2*."""
        return self._edge_ctor(-1, "", n1, n2, status=self._status)

    def _add_edge(self, edge):
        """Append *edge* unless one with the same (node1, node2) exists.

        Returns the added edge, or None when it was a duplicate. The pair
        is ordered: an edge n2 -> n1 is not a duplicate of n1 -> n2.
        """
        if any(e.node1 == edge.node1 and e.node2 == edge.node2
               for e in self._edges):
            return None

        self._edges.append(edge)
        self._notify_modified()
        return edge

    def add_edge(self, n1: Node, n2: Node):
        """Create and add an edge n1 -> n2; return it, or None if duplicate."""
        edge = self._create_edge(n1, n2)
        return self._add_edge(edge)

    def insert_edge(self, edge):
        """Add an already-built *edge*; return it, or None if duplicate."""
        return self._add_edge(edge)

    def create_edge(self, n1: Node, n2: Node):
        """Build an edge n1 -> n2 without adding it to the graph."""
        return self._create_edge(n1, n2)

    def remove_edge(self, edge_name: str):
        """Remove every edge named *edge_name*."""
        self._edges = [e for e in self._edges if e.name != edge_name]
        self._notify_modified()

    # ------------------------------------------------------------------
    # Topology
    # ------------------------------------------------------------------

    def is_upstream_node(self, node):
        """Return True if no enabled edge ends at *node* (as ``node2``)."""
        # Edges are pre-filtered by is_enable(), so no extra enable test.
        return all(e.node2 != node for e in self._enable_edges())

    def is_downstream_node(self, node):
        """Return True if no enabled edge starts at *node* (as ``node1``)."""
        # Edges are pre-filtered by is_enable(), so no extra enable test.
        return all(e.node1 != node for e in self._enable_edges())

    def is_enable_node(self, node):
        """Return True if *node* is an endpoint of at least one enabled edge."""
        return any(
            e.node1 == node or e.node2 == node
            for e in self._enable_edges()
        )

    def is_enable_edge(self, edge):
        """Return the edge's ``_enable`` attribute."""
        # NOTE(review): reads the private flag directly, unlike the other
        # methods which call is_enable(); confirm the two always agree.
        return edge._enable

    def get_edge_id(self, reach):
        """Return the index of *reach* within ``enable_edges()``.

        Edges are matched on their ``id`` attribute. Raises StopIteration
        when no enabled edge has ``reach.id``.
        """
        # NOTE(review): a bare StopIteration escaping a normal function is
        # silently swallowed by an enclosing iterator such as map(); consider
        # raising ValueError once callers have been checked.
        return next(
            i for i, e in enumerate(self.enable_edges())
            if e.id == reach.id
        )