"""
CPoulard  ; mai 2021
côté "métier" :
                appliquer la formule de calcul de proba du nombre de crues s en hydrologie (TD 1ere année ENTPE)

coté "Python" : mettre en oeuvre deux widgets de matplotlib (Button et Slider), faciles d'emploi mais au rendu pas très abouti ;
              définir les fonctions associées à des événements sur ces boutons
"""

from matplotlib import pyplot as plt
from matplotlib.widgets import Slider, Button   # widgets du module matplotlib !
from matplotlib import ticker
from scipy.special import comb

# CONSTANTE

# tracé par plot : on utilise "plot" qui est plus facile mais moins adapté à une fonction 'discrète"
# sinon on utilise un tracé avec "hlines et fillbetween"
# en attendant "stairs" (https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.stairs.html)
TRACE_PAR_PLOT = False # False

# FONCTIONS
#
# trois fonctions correspondant chacune à une manière de définir un échantillon de nb valeurs
def proba_crue(n, k, f):
    """  Probabilité d’avoir exactement p crues de fréquence f en n années : C(n,k)*f^k*(1-f)^(N-k)
    """
    return comb(n, k) * (f ** k) * ((1 - f) ** (n - k))


def afficher_probas(kmax, n, f):
    kmax = max(n, kmax)
    print(f" pour k variant de 0 à {kmax}"
        f"proba d'avoir exactement 'k crue(s)' supérieures à la crue de période de retour {1 / f:.0f} (= fréquence {f:.02f})  en {n} années ")

    for k in range(kmax+1):
        if k < 2:
            texte = "crue supérieure"
        else:
            texte = "crues supérieures"

        print(f"    exactement {k} {texte} :  {proba_crue(n, k, f):.2f)*100:.2f}%")

def calcul_courbes(n,f):
    cumul_au_plus = 0
    liste_probas_exactement=[]
    liste_probas_auplus = []
    k= 0
    while cumul_au_plus < 0.999:
        proba = proba_crue(n,k,f)
        liste_probas_exactement.append(proba)
        cumul_au_plus += proba
        liste_probas_auplus.append(cumul_au_plus)
        k+=1
    # on retourne k-1 car on a effectué le dernier calcul pour p-1
    return k-1, liste_probas_exactement, liste_probas_auplus

def stair_plot(x,y, couleur, label=None):
    x = [*x, x[-1]]  # duplicate last  to draw last line
    ax.hlines(y, xmin=x[:-1], xmax=x[1:], linewidth=3, color=couleur, label=label)
    # fill under line  :  y with duplicated last value to draw last step as well ; hide edge
    ax.fill_between(x, [*y, y[-1]],
                    step='post', facecolor=couleur, edgecolor=None, alpha=0.5)

def update_graphique_plot(val):
    global n, T
    n = int(slider_n.val)
    T = int(slider_T.val)

    k_dernier, liste_probas_exactement, liste_probas_auplus = calcul_courbes(n,1/T)
    fig.suptitle(f"Démo : probabilité d'avoir k crues > T={T}ans en {n} années")

    courbe_exactement.set_data(range(k_dernier + 1), liste_probas_exactement)
    courbe_cumul.set_data(range(k_dernier + 1), liste_probas_auplus)

    ax.set_xlim(0, k_dernier + 1)

    if k_dernier > 15:
        ax.xaxis.set_major_locator(plt.MultipleLocator(5))
        ax.xaxis.set_minor_locator(plt.MultipleLocator(1))
    else:
        ax.xaxis.set_major_locator(plt.MultipleLocator(1))
        ax.xaxis.set_minor_locator(plt.NullLocator())
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
    fig.canvas.draw_idle()

def update_graphique_stairs(val):
    global n, T
    n = int(slider_n.val)
    T = int(slider_T.val)

    k_dernier, liste_probas_exactement, liste_probas_auplus = calcul_courbes(n, 1 / T)
    fig.suptitle(f"Démo : probabilité d'avoir k crues > T={T}ans en {n} années")

    ax.clear()
    stair_plot(range(k_dernier + 1), liste_probas_exactement, couleur="blue", label="exactement k crues")
    stair_plot(range(k_dernier + 1), liste_probas_auplus, couleur="sienna", label="k crues ou moins de k")

    ax.axhline(y=0, color='grey', ls=':')
    ax.axhline(y=1, c='grey', ls=':')
    ax.set_xlim(0, k_dernier + 1)
    ax.set_xlabel("nombre de crues k")
    ax.set_ylabel("probabilité")

    if k_dernier > 15:
        ax.xaxis.set_major_locator(plt.FixedLocator([0.5 + i for i in range(0, k_dernier, 5)]))
        ax.xaxis.set_minor_locator(plt.FixedLocator([0.5 + i for i in range(0, k_dernier, 1)]))
    else:
        ax.xaxis.set_major_locator(plt.FixedLocator([0.5 + i for i in range(0, k_dernier, 1)]))
        ax.xaxis.set_minor_locator(plt.NullLocator())
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
    ax.legend(loc='center right')
    fig.canvas.draw_idle()


#CORPS DU PROGRAMME
n = 100
T = 100
k_dernier, liste_probas_exactement, liste_probas_auplus = calcul_courbes(n, 1 / T)


fig, (ax, ax_espace, ax_sn, ax_sT) = plt.subplots(nrows=4, ncols=1, gridspec_kw={'height_ratios':[6,1, 1,1]}, sharex=False)
plt.subplots_adjust(left=0.2, bottom=None, right=None, top=None, wspace=None, hspace=0.1)
fig.canvas.set_window_title("ScE - Hydrologie - Démo probas crue sur N années")
fig.suptitle(f"Démo : probabilité d'avoir k_dernier crues > T={T}ans en {n} années")

if TRACE_PAR_PLOT:
    courbe_exactement, = ax.plot(range(k_dernier + 1), liste_probas_exactement, label="exactement k crues",ls='None', marker='o', markersize=10)
    courbe_cumul, = ax.plot(range(k_dernier + 1), liste_probas_auplus, label="k crues ou moins de k", ls='None', marker='o', markersize=10)
    update_graphique = update_graphique_plot
else:
    stair_plot(range(k_dernier + 1),liste_probas_exactement, couleur="blue", label="exactement k crues")
    stair_plot(range(k_dernier + 1),liste_probas_auplus, couleur="sienna",  label="k crues ou moins de k")
    update_graphique = update_graphique_stairs

for ax_s in [ax_espace, ax_sn, ax_sT]:
    ax_s.xaxis.set_visible(False)
    ax_s.yaxis.set_visible(False)

ax_espace.patch.set_alpha(0.01)  #sinon cet axe cache le titre de l'axe des x au-dessus !

for pos in ['right', 'top', 'bottom', 'left']:
            ax_espace.spines[pos].set_visible(False)

ax.set_xlabel("nombre de crues k")
ax.set_ylabel("probabilité")

# nom connu même hors de la fonction pour éviter le GC ?
# nombre d'années d'observation
slider_n = Slider(
    ax_sn, "nombre d'années d'observation ", valmin=1, valmax = 1000, valfmt='%0.0f', valinit=n, color="green")

slider_n.on_changed(update_graphique)

slider_T = Slider(
    ax_sT, "période de retour T  ", 2, 1000, valinit=T, valstep=1, valfmt='%0.0f', color="blue")

slider_T.on_changed(update_graphique)

ax.legend(loc='center right')
ax.set_xlim(0,k_dernier+1)
ax.axhline(y=0,color='grey',ls=':')
ax.axhline(y=1, c='grey',ls=':')

if TRACE_PAR_PLOT :
    if k_dernier > 15:
        ax.xaxis.set_major_locator(plt.MultipleLocator(5))
        ax.xaxis.set_minor_locator(plt.MultipleLocator(1))
    else:
        ax.xaxis.set_major_locator(plt.MultipleLocator(1))
        ax.xaxis.set_minor_locator(plt.NullLocator())
else:
    if p > 15:
        ax.xaxis.set_major_locator(plt.FixedLocator([0.5+i for i in range(0,k_dernier,5)]))
        ax.xaxis.set_minor_locator(plt.FixedLocator([0.5+i for i in range(0,k_dernier,1)]))
    else:
        ax.xaxis.set_major_locator(plt.FixedLocator([0.5+i for i in range(0,k_dernier,1)]))
        ax.xaxis.set_minor_locator(plt.NullLocator())

ax.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
# plt.tight_layout() : cela raccourcit l'axe qui porte le graphique...
plt.show()


print("Premiers calculs, valeurs par défaut")