# Useful libraries
library(ggplot2)
library(scales)
library(qpdf)
library(gridExtra)
library(gridtext)
library(dplyr)
library(grid)
library(ggh4x)
library(RColorBrewer)


# Builds the time-series panel of one station: the data series, an
# optional shading of the missing values and of the dates outside the
# analysed period, and the significant trend lines with their legend.
#
# Arguments:
#   df_data_code   data of the station, with the columns 'Date' and
#                  'Qm3s'.
#   df_trend_code  trend results of the station, with the columns
#                  'period_start', 'period_end', 'p', 'trend' and
#                  'intercept'; NULL disables the trend overlay.
#   type           name of the variable, shown in the title. 'Q' and
#                  'sqrt(Q)' are drawn as a line, any other type as
#                  points; 'sqrt(Q)' also square-roots 'Qm3s'.
#   p_threshold    a trend is drawn only when its 'p' is <= this value.
#   missRect       if TRUE, runs of NA values are shaded.
#   unit2day       divisor applied to the numeric dates before they are
#                  multiplied by the trend slope (and to the length of
#                  the series when choosing the date breaks).
#   period         iterated with as.list(); each element is read as
#                  c(start, end) and the shortest one delimits the
#                  non-shaded area (only for 'Q' and 'sqrt(Q)').
#   last, first    position of the panel in the page, used only to set
#                  the bottom and top margins.
#   color          colours of the trend lines, indexed by period. A
#                  NULL vector or an NA entry falls back to
#                  'cornflowerblue'.
#
# Returns the ggplot object of the panel.
time_panel = function (df_data_code, df_trend_code, type, p_threshold=0.1, missRect=FALSE, unit2day=365.25, period=NULL, last=FALSE, first=FALSE, color=NULL) {

    # 'Q' and 'sqrt(Q)' are drawn as a line, everything else as points
    is_flow = type == 'sqrt(Q)' | type == 'Q'

    if (type == 'sqrt(Q)') {
        df_data_code$Qm3s = sqrt(df_data_code$Qm3s)
    }
    
    # Spacing of the y breaks, chosen from the mantissa of the maximum
    maxQ = max(df_data_code$Qm3s, na.rm=TRUE)
    power = get_power(maxQ) 

    maxQtmp = maxQ/10^power
    if (maxQtmp >= 5) {
        dbrk = 1.0
    } else if (maxQtmp >= 3) {
        dbrk = 0.5
    } else if (maxQtmp >= 2) {
        dbrk = 0.4
    } else if (maxQtmp >= 1) {
        dbrk = 0.2
    } else {
        dbrk = 0.1
    }
    
    dbrk = dbrk * 10^power
    accuracy = NULL
    
    # Length of the series (first to last row), in units of 'unit2day'
    dDate = as.numeric(df_data_code$Date[length(df_data_code$Date)] -
                       df_data_code$Date[1]) / unit2day
    
    # Spacing of the major and minor date breaks, in years
    if (dDate >= 100) {
        datebreak = 25
        dateminbreak = 5
    } else if (dDate >= 50) {
        datebreak = 10
        dateminbreak = 1
    } else {
        datebreak = 5
        dateminbreak = 1
    }
    
    p = ggplot() + 
        
    theme(panel.background=element_rect(fill='white'),
          text=element_text(family='sans'),

          panel.border = element_rect(color="grey85",
                                    fill=NA,
                                    size=0.7),

          panel.grid.major.y=element_line(color='grey85', size=0.15),
          panel.grid.major.x=element_blank(),
          
          axis.ticks.y=element_line(color='grey75', size=0.3),
          axis.ticks.x=element_line(color='grey75', size=0.3),
          
          axis.text.x=element_text(color='grey40'),
          axis.text.y=element_text(color='grey40'),

          ggh4x.axis.ticks.length.minor=rel(0.5),
          axis.ticks.length=unit(1.5, 'mm'),

          plot.title=element_text(size=9, vjust=-2, 
                                  hjust=-1E-3, color='grey20'), 
          axis.title.x=element_blank(),
          axis.title.y=element_blank(),
          axis.line.x=element_blank(),
          axis.line.y=element_blank()
          )

    # Only the first panel gets a top margin and only the last one a
    # bottom margin
    margin_top = if (first) 5 else 0
    margin_bottom = if (last) 5 else 0
    p = p +
        theme(plot.margin=margin(margin_top, 5, margin_bottom, 5,
                                 unit="mm"))

    if (is_flow) {
        p = p +
            geom_line(aes(x=df_data_code$Date, y=df_data_code$Qm3s),
                      color='grey20',
                      size=0.3)
    } else {
        p = p +
            geom_point(aes(x=df_data_code$Date, y=df_data_code$Qm3s),
                       shape=1, color='grey20', size=1)
    }

    if (missRect) {
        # A run of NA starts (ends) where the gap to the previous
        # (next) NA date is not exactly 1
        NAdate = df_data_code$Date[is.na(df_data_code$Qm3s)]
        dNAdate = diff(NAdate)
        NAdate_Down = NAdate[append(Inf, dNAdate) != 1]
        NAdate_Up = NAdate[append(dNAdate, Inf) != 1]

        p = p +
            geom_rect(aes(xmin=NAdate_Down, 
                          ymin=0, 
                          xmax=NAdate_Up, 
                          ymax=maxQ*1.1),
                      linetype=0, fill='Wheat', alpha=0.4)
    }

    if (is_flow & !is.null(period)) {
        
        # Keep the shortest period.
        # NOTE(review): interval() is not provided by any package
        # attached in this file - presumably lubridate, attached by
        # the caller; confirm.
        period = as.list(period)
        Imin = 10^99
        for (per in period) {
            I = interval(per[1], per[2])
            if (I < Imin) {
                Imin = I
                period_min = as.Date(per)
            }
        }

        # Grey out what lies before and after that period
        p = p + 
            geom_rect(aes(xmin=min(df_data_code$Date),
                          ymin=0, 
                          xmax=period_min[1], 
                          ymax= maxQ*1.1),
                      linetype=0, fill='grey85', alpha=0.3) +
            
            geom_rect(aes(xmin=period_min[2],
                          ymin=0, 
                          xmax=max(df_data_code$Date), 
                          ymax= maxQ*1.1),
                      linetype=0, fill='grey85', alpha=0.3) 
    }


    if (!is.null(df_trend_code)) {

        Start = df_trend_code$period_start
        UStart = levels(factor(Start))
        End = df_trend_code$period_end
        UEnd = levels(factor(End))
        
        nPeriod = max(length(UStart), length(UEnd))
        
        # Line types of the trend lines and of their legend keys
        ltype = c('solid', 'dashed', 'dotted', 'twodash')
        lty = c('solid', '22')
        
        # Number of trends drawn so far: stacks the legend entries
        ii = 0
        for (i in seq_len(nPeriod)) {

            df_trend_code_per = 
                df_trend_code[df_trend_code$period_start == Start[i] 
                              & df_trend_code$period_end == End[i],]

            if (df_trend_code_per$p <= p_threshold) {

                ii = ii + 1

                # Data dates closest to the bounds of the period
                iStart = which.min(abs(df_data_code$Date - Start[i]))
                iEnd = which.min(abs(df_data_code$Date - End[i]))

                trend_date = c(df_data_code$Date[iStart],
                               df_data_code$Date[iEnd])
                
                trend_value =
                    as.numeric(trend_date) / unit2day *
                    df_trend_code_per$trend +
                    df_trend_code_per$intercept

                plot = tibble(abs=trend_date, ord=trend_value)

                # 'color' may be NULL (default) or hold NA for this
                # period: fall back to a default colour in both cases
                has_color = !is.null(color) && !is.na(color[i])

                # The data is always passed explicitly: an aes()
                # reading loop variables is only evaluated when the
                # plot is built, i.e. with the values of the last
                # iteration
                if (has_color) {
                    line_color = color[i]
                    p = p + 
                        geom_line(data=plot, aes(x=abs, y=ord), 
                                      color=line_color, 
                                      linetype=ltype[i], size=0.7)
                } else {
                    line_color = 'cornflowerblue'
                    p = p + 
                        geom_line(data=plot, aes(x=abs, y=ord), 
                                  color=line_color)
                }

                codeDate = df_data_code$Date
                codeQ = df_data_code$Qm3s
                
                # Legend key: a short segment in the upper left
                # corner, one row per drawn trend
                x = gpct(2, codeDate, shift=TRUE)
                xend = x + gpct(3, codeDate)
               
                dy = gpct(5, codeQ, ref=0)
                y = gpct(105, codeQ, ref=0) - (ii-1)*dy

                # Legend text: slope of the current period only
                xt = xend + gpct(1, codeDate)
                label = bquote(bold(.(format(df_trend_code_per$trend, scientific=TRUE, digits=3)))~'['*m^{3}*'.'*s^{-1}*'.'*an^{-1}*']')
    
                p = p +
                    annotate("segment",
                             x=x, xend=xend,
                             y=y, yend=y,
                             color=line_color,
                             lty=lty[i], lwd=1) +
                    
                    annotate("text", 
                             label=label, size=3,
                             x=xt, y=y, 
                             hjust=0, vjust=0.4,
                             color=line_color)
            }
        }
    }

    p = p +
        ggtitle(bquote(bold(.(type))~~'['*m^{3}*'.'*s^{-1}*']')) +

        scale_x_date(date_breaks=paste(as.character(datebreak), 
                                       'year', sep=' '),
                     date_minor_breaks=paste(as.character(dateminbreak), 
                                             'year', sep=' '),
                     guide='axis_minor',
                     date_labels="%Y",
                     limits=c(min(df_data_code$Date), 
                              max(df_data_code$Date)),
                     expand=c(0, 0))

    p = p +
        scale_y_continuous(breaks=seq(0, maxQ*10, dbrk),
                           limits=c(0, maxQ*1.1),
                           expand=c(0, 0),
                           labels=label_number(accuracy=accuracy))

    return(p)
}


# Builds the text block describing the station 'code' from its row(s)
# in 'df_meta': a title line (code, name, hydrological region), the
# manager, then three columns with surface / X, altitude / Y and the
# data sources. Values are shown as "IN (BH)".
#
# Arguments:
#   code     station code looked up in df_meta$code.
#   df_meta  metadata table; the columns read are 'code', 'nom',
#            'region_hydro', 'gestionnaire', 'surface_km2_IN/_BH',
#            'L93X_m_IN/_BH', 'L93Y_m_IN/_BH' and 'altitude_m_IN/_BH'.
#
# Returns the arrangement produced by grid.arrange(), which also draws
# it on the current graphics device.
text_panel = function(code, df_meta) {
    station = df_meta[df_meta$code == code,]

    # All the grobs are left-aligned and anchored by their top; only
    # the vertical position, the top/bottom margins and the font differ
    make_grob = function (txt, y, top, bottom, col, fontsize) {
        richtext_grob(txt,
                      x=0, y=y,
                      margin=unit(c(t=top, r=5, b=bottom, l=5), "mm"),
                      hjust=0, vjust=1,
                      gp=gpar(col=col, fontsize=fontsize))
    }

    # Title: code, name and hydrological region
    title = paste0("<b>", code, '</b>  -  ', station$nom, ' &#40;',
                   station$region_hydro, '&#41;')

    # Manager of the station
    manager = paste0("<b>",
                     "Gestionnaire : ", station$gestionnaire, "<br>",
                     "</b>")

    # Left column: surface and X coordinate
    surface_x = paste0("<b>",
                       "Superficie : ", station$surface_km2_IN,
                       ' (', station$surface_km2_BH, ')',
                       "  [km<sup>2</sup>] <br>",
                       "X = ", station$L93X_m_IN,
                       ' (', station$L93X_m_BH, ')',
                       "  [m ; Lambert 93]",
                       "</b>")

    # Middle column: altitude and Y coordinate
    altitude_y = paste0("<b>",
                        "Altitude : ", station$altitude_m_IN,
                        ' (', station$altitude_m_BH, ')',
                        "  [m]<br>",
                        "Y = ", station$L93Y_m_IN,
                        ' (', station$L93Y_m_BH, ')',
                        "  [m ; Lambert 93]",
                        "</b>")

    # Right column: origin of the two values of each line
    sources = paste0("<b>",
                     "INRAE (Banque Hydro)<br>",
                     "INRAE (Banque Hydro)",
                     "</b>")

    grobs = list(
        make_grob(title, y=1, top=5, bottom=0,
                  col="#00A3A8", fontsize=14),
        make_grob(manager, y=0.55, top=0, bottom=0,
                  col="grey20", fontsize=8),
        make_grob(surface_x, y=1, top=0, bottom=5,
                  col="grey20", fontsize=9),
        make_grob(altitude_y, y=1, top=0, bottom=5,
                  col="grey20", fontsize=9),
        make_grob(sources, y=1, top=0, bottom=5,
                  col="grey20", fontsize=9))

    # Title and manager span the full width, the last row holds the
    # three columns
    layout = rbind(c(1, 1, 1),
                   c(2, 2, 2),
                   c(3, 4, 5))

    gtext_merge = grid.arrange(grobs=grobs, layout_matrix=layout)

    return(gtext_merge)
}



# Builds the summary matrix of the trends: one row per station, one
# column per analysed variable, one circle per (station, variable)
# coloured by the trend when it is significant and left white
# otherwise, with the trend value written inside.
#
# Arguments:
#   list_df2plot  list with one element per variable; each element
#                 provides '$trend' (table with the columns 'code',
#                 'trend' and 'p'), '$p_threshold' and '$type'.
#   df_meta       metadata table; only its 'code' column is read.
#
# Returns the ggplot object of the matrix.
matrice_panel = function (list_df2plot, df_meta) {
    
    nbp = length(list_df2plot)

    # Range of the significant trends of each variable, used to scale
    # the colours
    minTrend = c()
    maxTrend = c()

    for (i in seq_len(nbp)) {
        
        df_trend = list_df2plot[[i]]$trend
        p_threshold = list_df2plot[[i]]$p_threshold
        
        okTrend = df_trend$trend[df_trend$p <= p_threshold]

        minTrend[i] = min(okTrend, na.rm=TRUE)
        maxTrend[i] = max(okTrend, na.rm=TRUE)
    }

    # Get all different stations code
    Code = levels(factor(df_meta$code))

    X = c()
    Code_mat = c()
    Trend_mat = c()
    Fill_mat = c()
    Color_mat = c()

    for (code in Code) {
        
        for (i in seq_len(nbp)) {
            df_trend = list_df2plot[[i]]$trend
            p_threshold = list_df2plot[[i]]$p_threshold
            
            # The column of a variable is its position in
            # 'list_df2plot', exactly like the column headers added
            # below. Ranking the type names alphabetically instead
            # would put the circles under the wrong header whenever
            # the types are not already sorted.
            X = append(X, i)
            Code_mat = append(Code_mat, code)

            df_trend_code = df_trend[df_trend$code == code,]

            if (df_trend_code$p <= p_threshold){
                color_res = get_color(df_trend_code$trend, 
                                      minTrend[i],
                                      maxTrend[i],
                                      palette_name='perso',
                                      reverse=FALSE)

                trend = df_trend_code$trend
                fill = color_res$color
                color = 'white'

            } else { 
                trend = NA
                fill = 'white'
                color = 'white'
            }

            Trend_mat = append(Trend_mat, trend)
            Fill_mat = append(Fill_mat, fill)
            Color_mat = append(Color_mat, color)
        }
    }

    # Row of each cell: rank of its station code
    Y = as.integer(factor(Code_mat))

    # NOTE(review): this changes global options and passes whole
    # vectors - confirm it is still wanted.
    options(repr.plot.width=X, repr.plot.height=Y)
    
    mat = ggplot() +
        
        theme(
              panel.background=element_rect(fill='white'),
              text=element_text(family='sans'),
              panel.border=element_blank(),

              panel.grid.major.y=element_blank(),
              panel.grid.major.x=element_blank(),
              
              axis.text.x=element_blank(),
              axis.text.y=element_blank(),
              
              axis.ticks.y=element_blank(),
              axis.ticks.x=element_blank(),

              ggh4x.axis.ticks.length.minor=rel(0.5),
              axis.ticks.length=unit(1.5, 'mm'),
          
              plot.title=element_text(size=9, vjust=-3, 
                                  hjust=-1E-3, color='grey20'), 

              axis.title.x=element_blank(),
              axis.title.y=element_blank(),

              axis.line.x=element_blank(),
              axis.line.y=element_blank(),
              
              plot.margin=margin(5, 5, 5, 5, unit="mm")
              )

    # One circle per cell
    for (i in seq_along(X)) {
        mat = mat +
            gg_circle(r=0.5, xc=X[i], yc=Y[i], fill=Fill_mat[i], color=Color_mat[i])
    }

    # Extra room on the left for the station codes and on top for the
    # variable names
    mat = mat +

    coord_fixed() +

    scale_x_continuous(limits=c(min(X) - 1.5, 
                                max(X) + 0.5),
                       expand=c(0, 0)) + 
        
    scale_y_continuous(limits=c(min(Y) - 0.5, 
                                max(Y) + 1),
                       expand=c(0, 0))
    
    # Row labels: station codes
    for (i in seq_along(Code)) {
        mat = mat +
            annotate('text', x=-0.5, y=i,
                     label=Code[i],
                     hjust=0, vjust=0.5, 
                     size=3.5, color='grey40')       
    }

    # Column headers: variable names
    for (i in seq_len(nbp)) {
        type = list_df2plot[[i]]$type
        mat = mat +
            annotate('text', x=i, y=max(Y) + 0.6,
                     label=bquote(.(type)),
                     hjust=0.5, vjust=0, 
                     size=3.5, color='grey40')       
    }
    
    # Trend value written in each circle as mantissa (upper line) and
    # power of ten (lower line); nothing for non significant trends
    for (i in seq_along(Trend_mat)) {
        trend = Trend_mat[i]
        if (!is.na(trend)) {
            power = get_power(trend)
            dbrk = 10^power
            trendN = round(trend / dbrk, 2)
            trendC1 = as.character(trendN)
            trendC2 = bquote('x '*10^{.(as.character(power))})
        } else {
            trendC1 = ''
            trendC2 = ''
        }
        mat = mat +
            annotate('text', x=X[i], y=Y[i],
                     label=trendC1,
                     hjust=0.5, vjust=0, 
                     size=3, color='white') +
            annotate('text', x=X[i], y=Y[i],
                     label=trendC2,
                     hjust=0.5, vjust=1.3,
                     size=2, color='white')
        
    }
    
    return (mat)
}





# Maps a value to a colour of a diverging palette: negative values use
# the first half of the palette (scaled from 'min' to 0), the others
# the second half (scaled from 0 to 'max').
#
# Arguments:
#   value         single number to convert.
#   min, max      bounds of the scale; values outside [min, max] get
#                 the colour of the nearest bound.
#   ncolor        number of colours interpolated in the palette.
#   palette_name  'perso' for the built-in five-colour gradient,
#                 otherwise a name passed to brewer.pal().
#   reverse       if TRUE the palette is reversed.
#
# Returns list(color=<colour of value>, palette=<full palette>).
get_color = function (value, min, max, ncolor=256, palette_name='perso', reverse=FALSE) {
    
    if (palette_name == 'perso') {
        palette = colorRampPalette(c(
            '#1a4157',
            '#00af9d',
            '#fbdd7e',
            '#fdb147',
            '#fd4659'
        ))(ncolor)
        
    } else {
        palette = colorRampPalette(brewer.pal(11, palette_name))(ncolor)
    }

    if (reverse) {
        palette = rev(palette)
    }
    
    half = as.integer(ncolor/2)
    palette_cold = palette[1:half]
    palette_hot = palette[(half+1):ncolor]

    # Position of the value inside its half of the scale
    if (value < 0) {
        palette_side = palette_cold
        idNorm = (value - min) / (0 - min)
    } else {
        palette_side = palette_hot
        idNorm = (value - 0) / (max - 0)
    }

    # 0/0 happens when the bound of the half is itself 0: use the
    # start of that half
    if (is.nan(idNorm)) {
        idNorm = 0
    }
    # Keep the index inside the palette for values outside [min, max]
    # (pmin/pmax because 'min' and 'max' are arguments here)
    idNorm = pmin(pmax(idNorm, 0), 1)

    id = round(idNorm*(length(palette_side) - 1) + 1, 0)
    color = palette_side[id]
    
    return(list(color=color, palette=palette))
}

# Empty plot: a blank layer at (1, 1) with every background, grid,
# border and axis element of the theme removed.
void = ggplot() +
    geom_blank(aes(1, 1)) +
    do.call(theme,
            setNames(rep(list(element_blank()), 11),
                     c('plot.background',
                       'panel.grid.major',
                       'panel.grid.minor',
                       'panel.border',
                       'panel.background',
                       'axis.title.x',
                       'axis.title.y',
                       'axis.text.x',
                       'axis.text.y',
                       'axis.ticks',
                       'axis.line')))



# Draws the five-colour gradient as a thick line of 300 coloured
# points and saves it as 'palette_test.pdf' in the '/figures'
# directory.
palette_tester = function () {

    n_shade = 300

    # Abscissa and (flat) ordinate of the line; the names are kept
    # because aes() uses them as the axis titles
    X = seq_len(n_shade)
    Y = numeric(n_shade)

    gradient = colorRampPalette(c('#1a4157',
                                  '#00af9d',
                                  '#fbdd7e',
                                  '#fdb147',
                                  '#fd4659'))
    shade = gradient(n_shade)

    p = ggplot()
    p = p + geom_line(aes(x=X, y=Y), color=shade[X], size=10)
    p = p + scale_y_continuous(expand=c(0, 0))

    ggsave(plot=p,
           path='/figures',
           filename=paste0('palette_test', '.pdf'),
           width=10, height=10, units='cm', dpi=100)
}

# palette_tester()


# Returns the order of magnitude of a number: the integer 'power' such
# that abs(value) / 10^power lies in [1, 10).
#
# The power is computed numerically rather than by counting the
# characters of the printed number, so that exact powers of ten (1 ->
# 0), negative numbers (-50 -> 1), numbers printed in scientific
# notation (1e-05 -> -5) and numbers beyond the integer range are all
# handled.
#
# Arguments:
#   value  number (or numeric vector) to analyse.
#
# Returns the power(s) of ten as a numeric; 0 is returned for a value
# of 0, which has no order of magnitude.
get_power = function (value) {
    
    magnitude = abs(value)
    power = floor(log10(magnitude))

    # log10(0) is -Inf
    power[which(magnitude == 0)] = 0
    
    return(power)
}


# Returns a ggplot2 annotation drawing a circle of radius 'r' centred
# on (xc, yc), built as a ribbon between the lower and the upper half
# circle.
#
# Arguments:
#   r       radius of the circle.
#   xc, yc  coordinates of the centre.
#   color   colour of the outline.
#   fill    colour of the inside.
#   ...     other arguments passed to annotate().
gg_circle = function(r, xc, yc, color="black", fill=NA, ...) {
    # Angles from 0 to pi for the upper half, from 0 to -pi for the
    # lower half; both halves share the same abscissas
    angle_up = seq(0, pi, length.out=100)
    angle_down = seq(0, -pi, length.out=100)

    annotate("ribbon",
             x=xc + r*cos(angle_up),
             ymin=yc + r*sin(angle_down),
             ymax=yc + r*sin(angle_up),
             color=color, fill=fill, ...)
}



# Returns the length corresponding to 'pct' percent of the span of 'L'.
#
# Arguments:
#   pct    percentage of the span to return.
#   L      values defining the span (NA are ignored).
#   ref    lower bound of the span; the minimum of 'L' when NULL.
#   shift  if TRUE the lower bound is added to the result, which then
#          is a position inside the span instead of a length.
gpct = function (pct, L, ref=NULL, shift=FALSE) {

    lower = if (is.null(ref)) min(L, na.rm=TRUE) else ref
    span = as.numeric(max(L, na.rm=TRUE) - lower)

    part = pct/100 * span

    if (!shift) {
        return (part)
    }
    return (part + lower)
}