# Useful libraries
library(ggplot2)
library(scales)
library(qpdf)
library(gridExtra)
library(gridtext)
library(dplyr)
library(grid)
library(ggh4x)


# Multi-panel PDF: one page of time panels per station
# Wrap a per-panel argument into a list and recycle a single value (or a
# list of length 1) so that there is one element per panel.
.panel_as_list = function(x, n) {
    if (!inherits(x, 'list')) {
        x = list(x)
    }
    if (length(x) == 1) {
        x = rep(x, n)
    }
    x
}

# Join the non-empty output directory components; '.' when all are empty
# (joining empty strings with file.path() would point at the filesystem root).
.panel_outdir = function(figdir, filedir_opt) {
    parts = c(figdir, filedir_opt)
    parts = parts[!is.na(parts) & parts != '']
    if (length(parts) == 0) {
        return('.')
    }
    do.call(file.path, as.list(parts))
}

# Build the grid.arrange() layout: one full-width row per header (ids
# 1..nbh), then every row of layout_matrix repeated header_ratio times with
# its panel ids shifted by nbh. NA cells receive fresh ids larger than every
# panel id of layout_matrix so that no panel grob is assigned to them.
.panel_layout = function(layout_matrix, nbh, header_ratio) {
    layout_matrix = as.matrix(layout_matrix)

    idNA = which(is.na(layout_matrix))
    if (length(idNA) > 0) {
        # Exactly one new id per NA cell
        layout_matrix[idNA] = max(layout_matrix, na.rm=TRUE) + seq_along(idNA)
    }

    LMcol = ncol(layout_matrix)
    header_rows = matrix(rep(seq_len(nbh), each=LMcol),
                         ncol=LMcol, byrow=TRUE)
    body_rows = layout_matrix[rep(seq_len(nrow(layout_matrix)),
                                  each=header_ratio), , drop=FALSE] + nbh
    rbind(header_rows, body_rows)
}

#' Build a multi-panel PDF with one page per station
#'
#' For every station code in `df_meta$code`, the page stacks these panels
#' following `layout_matrix`:
#' - an optional text header (`text_panel()`),
#' - an optional discharge header (`time_panel()` on `time_header`),
#' - one `time_panel()` per element of `df_data`.
#' Each page is saved as an A4 PDF in a temporary directory. All pages are
#' then concatenated into `<outdir>/Panels[_filename_opt].pdf` with
#' `pdf_combine()`.
#'
#' @param df_data A data frame or a list of data frames, one per data panel.
#'   Each is passed to `time_panel()`.
#' @param df_meta Station metadata. Its `code` column lists the stations
#'   (one page each, in sorted factor-level order); `text_panel()` reads the
#'   other columns.
#' @param layout_matrix Matrix (or object accepted by `as.matrix()`) of
#'   panel indices; `NA` cells are left empty.
#' @param figdir,filedir_opt Output directory components; empty ones are
#'   ignored, and the current directory is used when both are empty.
#' @param filename_opt Optional suffix appended to the output file name.
#' @param variable Currently unused.
#' @param df_trend,p_threshold,unit2day,type,missRect Per-panel settings
#'   forwarded to `time_panel()`. Each is either a single value recycled to
#'   every panel or a list with one element per panel.
#' @param period Optional length-2 date range forwarded to the time header.
#' @param time_header Optional data frame drawn as a discharge header.
#' @param info_header Whether to draw the station text header.
#' @param header_ratio Number of layout rows given to each row of
#'   `layout_matrix`; each header gets one row.
#' @return The path of the combined PDF, invisibly (`NULL` when `df_meta`
#'   has no station code).
panel = function (df_data, df_meta, layout_matrix, figdir='', filedir_opt='', filename_opt='', variable='', df_trend=NULL, p_threshold=0.1, unit2day=365.25, type='', period=NULL, missRect=FALSE, time_header=NULL, info_header=TRUE, header_ratio=2) {

    # df_data defines how many data panels each page has
    if (!inherits(df_data, 'list')) {
        df_data = list(df_data)
    }
    nbp = length(df_data)

    # Every other per-panel setting gets one element per data panel
    df_trend = .panel_as_list(df_trend, nbp)
    p_threshold = .panel_as_list(p_threshold, nbp)
    unit2day = .panel_as_list(unit2day, nbp)
    type = .panel_as_list(type, nbp)
    missRect = .panel_as_list(missRect, nbp)

    # All different station codes (one page each)
    Code = levels(factor(df_meta$code))
    if (length(Code) == 0) {
        warning("No station code in df_meta: no PDF written.", call.=FALSE)
        return(invisible(NULL))
    }

    outfile = 'Panels'
    if (filename_opt != '') {
        outfile = paste0(outfile, '_', filename_opt)
    }
    outfile = paste0(outfile, '.pdf')

    # Create the output directory (and any missing parent) if needed
    outdir = .panel_outdir(figdir, filedir_opt)
    if (!dir.exists(outdir)) {
        dir.create(outdir, recursive=TRUE)
    }

    # Pages go to a fresh session temp directory, never to outdir/tmp, which
    # may already hold unrelated files. The directory is removed on exit,
    # including when plotting fails part-way through.
    outdirTmp = tempfile(pattern='panels_')
    dir.create(outdirTmp)
    on.exit(unlink(outdirTmp, recursive=TRUE), add=TRUE)
    pagefiles = paste0(Code, '.pdf')

    # Page layout does not depend on the station: compute it once
    nbh = as.numeric(info_header) + as.numeric(!is.null(time_header))
    LM = .panel_layout(layout_matrix, nbh, header_ratio)
    # Panels with an index in the last nbcol get a larger bottom margin
    nbcol = ncol(as.matrix(layout_matrix))

    for (k in seq_along(Code)) {
        code = Code[k]
        print(paste("Plotting for station :", code))

        P = vector(mode='list', length=nbp + nbh)

        if (info_header) {
            P[[1]] = text_panel(code, df_meta)
        }

        if (!is.null(time_header)) {
            # The time header always fills the last header slot, whether or
            # not the text header is drawn
            P[[nbh]] = time_panel(code, time_header, df_trend=NULL,
                                  period=period, missRect=TRUE,
                                  unit2day=365.25, type='Q')
        }

        for (i in seq_len(nbp)) {
            P[[i + nbh]] = time_panel(code, df_data[[i]], df_trend[[i]],
                                      missRect[[i]], p_threshold[[i]],
                                      unit2day[[i]], type[[i]],
                                      last=(i > nbp - nbcol))
        }

        plot = grid.arrange(grobs=P, layout_matrix=LM)

        # Saving
        ggsave(plot=plot,
               path=outdirTmp,
               filename=pagefiles[k],
               width=21, height=29.7, units='cm', dpi=100)
    }

    # Pages are combined in station-code order
    outpath = file.path(outdir, outfile)
    pdf_combine(input=file.path(outdirTmp, pagefiles), output=outpath)
    invisible(outpath)
}






#' Plot one station's series as a single time panel
#'
#' Keeps the rows of `df_data` whose `code` equals `code` and draws `Qm3s`
#' against `Date`. Optional layers:
#' - shaded spans of missing `Qm3s` values,
#' - greyed-out parts outside `period`,
#' - a linear trend line taken from `df_trend`.
#'
#' @param code Station code to plot.
#' @param df_data Data frame with `code`, `Date` (class Date) and `Qm3s`.
#' @param df_trend `NULL` or a data frame with `code`, `p`, `trend` and
#'   `intercept` columns. The trend line is drawn when the first row
#'   matching `code` has `p < p_threshold`.
#' @param missRect Whether to shade runs of missing `Qm3s` values.
#' @param p_threshold Threshold on `df_trend$p` for drawing the trend.
#' @param unit2day Days per trend time unit. Dates are converted with
#'   `as.numeric(Date) / unit2day` before applying `trend` and `intercept`.
#' @param type 'Q' or 'sqrt(Q)' draw a line (the latter on `sqrt(Qm3s)`).
#'   Any other value draws a line plus points. Also used in the title.
#' @param period Optional length-2 vector accepted by `as.Date()`. The
#'   parts of the series outside it are greyed out (Q types only).
#' @param norm If TRUE, every plotted value is divided by 10^power (power =
#'   number of integer digits of the maximum minus one), and the y break
#'   spacing is chosen from the normalised maximum.
#' @param last If TRUE, use a larger bottom margin (last row of a layout).
#' @return A ggplot object.
time_panel = function (code, df_data, df_trend, missRect, p_threshold, unit2day, type, period=NULL, norm=TRUE, last=FALSE) {

    df_data_code = df_data[df_data$code == code,]

    if (type == 'sqrt(Q)') {
        df_data_code$Qm3s = sqrt(df_data_code$Qm3s)
    }

    # Number of integer digits of the largest value, minus one
    maxQ = max(df_data_code$Qm3s, na.rm=TRUE)
    power = nchar(as.character(as.integer(maxQ))) - 1
    dbrk = 10^power

    # Factor every plotted y value is divided by: the data AND the trend
    # line, so that all layers match the y axis built from maxQ below
    yscale = 1
    if (norm) {
        yscale = 10^power
        df_data_code$Qm3s = df_data_code$Qm3s / yscale
        maxQ = max(df_data_code$Qm3s, na.rm=TRUE)
        if (maxQ >= 5) {
            dbrk = 1.0
        } else if (maxQ >= 3) {
            dbrk = 0.5
        } else if (maxQ >= 2) {
            dbrk = 0.4
        } else if (maxQ >= 1) {
            dbrk = 0.3
        } else {
            dbrk = 0.2
        }
    }

    dateMin = min(df_data_code$Date)
    dateMax = max(df_data_code$Date)
    dateFirst = df_data_code$Date[1]
    dateLast = df_data_code$Date[length(df_data_code$Date)]

    # Span between the first and last rows, in units of unit2day days
    dDate = as.numeric(dateLast - dateFirst) / unit2day

    if (dDate >= 100) {
        datebreak = 25
        dateminbreak = 5
    } else if (dDate >= 50) {
        datebreak = 10
        dateminbreak = 1
    } else {
        datebreak = 5
        dateminbreak = 1
    }

    p = ggplot() +
        ggtitle(bquote(.(type)~'['*m^{3}*'.'*s^{-1}*']  x'~10^{.(as.character(power))})) +
        theme(panel.background=element_rect(fill='white'),
              text=element_text(family='sans'),
              panel.border=element_blank(),

              panel.grid.major.y=element_line(color='grey85', size=0.3),
              panel.grid.major.x=element_blank(),

              axis.ticks.y=element_blank(),
              axis.ticks.x=element_line(color='grey75', size=0.3),

              axis.text.x=element_text(color='grey40'),
              axis.text.y=element_text(color='grey40'),

              ggh4x.axis.ticks.length.minor=rel(0.5),
              axis.ticks.length=unit(1.5, 'mm'),

              plot.title=element_text(size=10, vjust=-4,
                                      hjust=0, color='grey20'),
              axis.title.x=element_blank(),
              axis.title.y=element_blank(),
              axis.line.x=element_blank(),
              axis.line.y=element_blank())

    if (last) {
        p = p +
            theme(plot.margin=margin(1, 5, 5, 5, unit="mm"))
    } else {
        p = p +
            theme(plot.margin=margin(1, 5, 1, 5, unit="mm"))
    }

    isQ = type == 'sqrt(Q)' || type == 'Q'

    # Columns are referenced through the layer data rather than `$` in aes()
    if (isQ) {
        p = p +
            geom_line(data=df_data_code, aes(x=Date, y=Qm3s),
                      color='grey20')
    } else {
        p = p +
            geom_line(data=df_data_code, aes(x=Date, y=Qm3s),
                      color='grey65') +
            geom_point(data=df_data_code, aes(x=Date, y=Qm3s),
                       shape=1, color='grey20', size=1)
    }

    if (missRect) {
        NAdate = df_data_code$Date[is.na(df_data_code$Qm3s)]
        # Without any gap there is nothing to shade. Adding the layer anyway
        # would mix zero-length xmin/xmax with length-1 ymin/ymax, which
        # ggplot2 rejects when the plot is built.
        if (length(NAdate) > 0) {
            dNAdate = diff(NAdate)
            # A run of missing dates starts where the gap to the previous
            # missing date is not 1 day, and ends where the gap to the next
            # one is not 1 day
            NAdate_Down = NAdate[append(Inf, dNAdate) != 1]
            NAdate_Up = NAdate[append(dNAdate, Inf) != 1]

            p = p +
                geom_rect(aes(xmin=NAdate_Down,
                              ymin=0,
                              xmax=NAdate_Up,
                              ymax=maxQ*1.1),
                          linetype=0, fill='Wheat', alpha=0.3)
        }
    }

    if (isQ && !is.null(period)) {
        period = as.Date(period)
        p = p +
            geom_rect(aes(xmin=dateMin,
                          ymin=0,
                          xmax=period[1],
                          ymax=maxQ*1.1),
                      linetype=0, fill='grey85', alpha=0.3) +
            geom_rect(aes(xmin=period[2],
                          ymin=0,
                          xmax=dateMax,
                          ymax=maxQ*1.1),
                      linetype=0, fill='grey85', alpha=0.3)
    }

    if (!is.null(df_trend)) {
        trend_code = df_trend[df_trend$code == code,]
        # Stations absent from df_trend (or with a missing p) get no line
        if (nrow(trend_code) > 0 && isTRUE(trend_code$p[1] < p_threshold)) {
            xTrend = c(dateFirst, dateLast)
            xTrend_num = as.numeric(xTrend) / unit2day
            # NOTE(review): assumes the trend was fitted on the same values
            # as plotted here (after any sqrt), before normalisation
            yTrend = (xTrend_num * trend_code$trend[1] +
                      trend_code$intercept[1]) / yscale

            p = p +
                geom_line(aes(x=xTrend, y=yTrend),
                          color='cornflowerblue')
        }
    }

    p = p +
        scale_x_date(date_breaks=paste(as.character(datebreak),
                                       'year', sep=' '),
                     date_minor_breaks=paste(as.character(dateminbreak),
                                             'year', sep=' '),
                     guide='axis_minor',
                     date_labels="%Y",
                     limits=c(dateMin, dateMax),
                     expand=c(0, 0)) +
        scale_y_continuous(breaks=seq(0, maxQ*10, dbrk),
                           limits=c(0, maxQ*1.1),
                           expand=c(0, 0))

    return(p)
}


#' Build the station information header as a rich-text grob
#'
#' Shows the station code as a large bold title, followed by the `nom`,
#' `territoire`, `L93X`/`L93Y` and `surface_km2` columns of the matching
#' row(s) of `df_meta`. The text is top-left aligned.
#'
#' @param code Station code to look up in `df_meta$code`.
#' @param df_meta Station metadata data frame.
#' @return A `richtext_grob()` grob.
text_panel = function(code, df_meta) {
    meta = df_meta[df_meta$code == code,]

    # One markup fragment per line; pasted element-wise below
    title = paste0("<span style='font-size:18pt'> station <b>", code,
                   "</b></span><br>")
    name = paste0("nom : ", meta$nom, "<br>")
    area = paste0("territoire : ", meta$territoire, "<br>")
    position = paste0("position : (", meta$L93X, "; ", meta$L93Y, ")",
                      "<br>")
    surface = paste0("surface : ", meta$surface_km2, " km<sup>2</sup>")

    label = paste0(title, name, area, position, surface)

    richtext_grob(label,
                  x=0, y=1,
                  hjust=0, vjust=1,
                  margin=unit(rep(5, 4), "mm"),
                  gp=gpar(col="grey20", fontsize=12))
}


# Empty spacer plot: one invisible point and every background, grid,
# border, axis title, axis text, tick and axis line element blanked.
void = ggplot()
void = void + geom_blank(aes(x=1, y=1))
void = void + theme(plot.background=element_blank(),
                    panel.grid.major=element_blank(),
                    panel.grid.minor=element_blank(),
                    panel.border=element_blank(),
                    panel.background=element_blank(),
                    axis.title.x=element_blank(),
                    axis.title.y=element_blank(),
                    axis.text.x=element_blank(),
                    axis.text.y=element_blank(),
                    axis.ticks=element_blank(),
                    axis.line=element_blank())