# Useful libraries
library(ggplot2)
library(scales)
library(qpdf)
library(gridExtra)
library(gridtext)
library(dplyr)
library(grid)
library(ggh4x)
library(RColorBrewer)


# Sourcing R file
source('plotting/panel.R', encoding='latin1')


# Helpers for panels_layout ----------------------------------------------------

# Return `x` unchanged when it is already a plain list; otherwise wrap it
# (data.frame, scalar, NULL, ...) and repeat it `nbp` times so that there is
# one entry per plotted variable.
.panels_recycle = function (x, nbp) {
    if (inherits(x, 'list')) {
        return (x)
    }
    rep(list(x), nbp)
}

# Start/end dates of the trend periods of every station, taken in row order
# from `df_trend_ref`, plus the largest number of distinct periods found for
# any station.
.panels_periods = function (df_trend_ref, Code) {
    Start_code = vector(mode='list', length=length(Code))
    End_code = vector(mode='list', length=length(Code))
    nPeriod_max = 0
    for (k in seq_along(Code)) {
        rows = which(df_trend_ref$code == Code[k])
        df_trend_code = df_trend_ref[rows,]
        Start_code[[k]] = df_trend_code$period_start
        End_code[[k]] = df_trend_code$period_end
        nPeriod = max(length(levels(factor(Start_code[[k]]))),
                      length(levels(factor(End_code[[k]]))))
        nPeriod_max = max(nPeriod_max, nPeriod)
    }
    list(start=Start_code, end=End_code, n=nPeriod_max)
}

# Relative trend (trend divided by the mean of `Qm3s` over the same period)
# for every period j, variable i and station k, as an array of dim
# (periods, variables, stations). Cells are NA when the station has no such
# period, no matching trend row, or a trend whose p-value is above that
# variable's p_threshold.
.panels_norm_trends = function (list_df2plot, Code, periods) {
    nbp = length(list_df2plot)
    nPeriod_max = periods$n
    TrendMean_code = array(NA_real_, dim=c(nPeriod_max, nbp, length(Code)))

    for (k in seq_along(Code)) {
        code = Code[k]
        for (i in seq_len(nbp)) {
            df_data = list_df2plot[[i]]$data
            df_trend = list_df2plot[[i]]$trend
            p_threshold = list_df2plot[[i]]$p_threshold

            df_data_code = df_data[which(df_data$code == code),]
            df_trend_code = df_trend[which(df_trend$code == code),]

            for (j in seq_len(nPeriod_max)) {
                Start = periods$start[[k]][j]
                End = periods$end[[k]][j]

                df_trend_code_per = df_trend_code[
                    which(df_trend_code$period_start == Start
                          & df_trend_code$period_end == End),]
                # No trend for this station/period: leave the cell NA
                # instead of crashing on an empty or NA condition.
                if (NROW(df_trend_code_per) == 0) {
                    next
                }
                # Several rows may describe the same period: keep the first.
                if (!isTRUE(df_trend_code_per$p[1] <= p_threshold)) {
                    next
                }

                df_data_code_per = df_data_code[
                    which(df_data_code$Date >= Start
                          & df_data_code$Date <= End),]
                dataMean = mean(df_data_code_per$Qm3s, na.rm=TRUE)
                TrendMean_code[j, i, k] = df_trend_code_per$trend[1] / dataMean
            }
        }
    }
    TrendMean_code
}

# min()/max() that return NA for an all-NA input instead of +/-Inf plus a
# warning.
.panels_range = function (x, agg) {
    if (all(is.na(x))) {
        return (NA_real_)
    }
    agg(x, na.rm=TRUE)
}

# Colour of every trend period of one station/variable. A significant trend
# gets a palette colour scaled on the cross-station range of its period and
# variable. A non-significant one gets a grey that darkens for each successive
# non-significant period.
.panels_colors = function (trendMean, minTrendMean, maxTrendMean) {
    color = c()
    grey = 85
    for (j in seq_along(trendMean)) {
        if (!is.na(trendMean[j])) {
            colortmp = get_color(trendMean[j],
                                 minTrendMean[j],
                                 maxTrendMean[j],
                                 palette_name='perso',
                                 reverse=TRUE)
        } else {
            colortmp = paste('grey', grey, sep='')
            grey = grey - 10
        }
        color = c(color, colortmp)
    }
    color
}

# Layout matrix for arrangeGrob(). Header rows come first (info header, then
# time header), each repeated according to its height ratio. They are followed
# by `layout_matrix`, whose panel indices are shifted past the headers and
# whose rows are repeated `var_ratio` times. NA cells of `layout_matrix` get
# indices beyond the last grob so they stay empty.
.panels_layout_matrix = function (layout_matrix, info_header, has_time,
                                  info_ratio, time_ratio, var_ratio) {
    layout_matrix = as.matrix(layout_matrix)
    nbh = as.numeric(info_header) + as.numeric(has_time)

    idNA = which(is.na(layout_matrix))
    if (length(idNA) > 0) {
        layout_matrix[idNA] = max(layout_matrix, na.rm=TRUE) + seq_along(idNA)
    }
    layout_matrix_H = layout_matrix + nbh

    # Turn the possibly fractional ratios into integer row counts by scaling
    # them all by the same power of ten.
    ratios = c(info_ratio, time_ratio, var_ratio)
    ndec = max(vapply(ratios, function (r) {
        if (r == round(r)) {
            return (0)
        }
        as.numeric(nchar(gsub('^[0-9]+\\.', '', as.character(r))))
    }, numeric(1)))
    # round() guards against e.g. 0.29 * 100 == 28.999...
    nrows = round(ratios * 10^ndec)

    LMcol = ncol(layout_matrix_H)
    LM = NULL
    h = 0
    if (info_header) {
        h = h + 1
        LM = rbind(LM, matrix(h, nrow=nrows[1], ncol=LMcol))
    }
    if (has_time) {
        h = h + 1
        LM = rbind(LM, matrix(h, nrow=nrows[2], ncol=LMcol))
    }
    body = layout_matrix_H[rep(seq_len(nrow(layout_matrix_H)), each=nrows[3]),
                           , drop=FALSE]
    unname(rbind(LM, body))
}

# Build the per-station "datasheet" pages and/or the matrix and map summary
# pages, then merge everything into a single PDF.
#
# Every page is first written to `<figdir>/<filedir_opt>/tmp`, which is wiped
# at the start of each call. The files found there are then combined, in
# list.files() order, into `<figdir>/<filedir_opt>/Panels[_<filename_opt>].pdf`.
#
# Arguments:
#   df_data       data.frame, or list of data.frames (one per plotted
#                 variable), with at least the columns `code`, `Date`, `Qm3s`.
#   df_meta       station metadata; its `code` column lists the stations.
#   layout_matrix matrix of variable-panel indices arranging the datasheet
#                 page; NA cells are left empty.
#   isplot        outputs to produce: any of 'datasheet', 'matrix', 'map'.
#   figdir, filedir_opt
#                 output directory components (empty ones are ignored; the
#                 directory is created recursively if missing).
#   filename_opt  optional suffix of the combined PDF name.
#   variable      currently unused.
#   df_trend      data.frame or list of data.frames (one per variable) with
#                 columns `code`, `period_start`, `period_end`, `trend`, `p`.
#                 The trend periods of each station are read from the first
#                 one. NOTE(review): assumes the periods appear in the same
#                 row order for every variable — confirm.
#   p_threshold, unit2day, type, missRect
#                 a value applied to every variable, or a list with one value
#                 per variable.
#   trend_period, mean_period, axis_xlim
#                 forwarded to time_panel(). axis_xlim is replaced, per
#                 station, by the date range of `time_header` when given.
#   time_header   optional data.frame (`code`, `Date`, ...) drawn as a time
#                 header above the variable panels.
#   info_header   draw the station information header?
#   info_ratio, time_ratio, var_ratio
#                 relative heights of the info header, the time header and
#                 each row of variable panels.
#   df_shapefile  forwarded to info_panel() and map_panel().
#
# Returns the value of pdf_combine(), or invisible NULL (with a warning) when
# no page was produced.
panels_layout = function (df_data, df_meta, layout_matrix, isplot=c('datasheet', 'matrix', 'map'), figdir='', filedir_opt='', filename_opt='', variable='', df_trend=NULL, p_threshold=0.1, unit2day=365.25, type='', trend_period=NULL, mean_period=NULL, axis_xlim=NULL, missRect=FALSE, time_header=NULL, info_header=TRUE, info_ratio=1, time_ratio=2, var_ratio=3, df_shapefile=NULL) {

    outfile = "Panels"
    if (filename_opt != '') {
        outfile = paste(outfile, '_', filename_opt, sep='')
    }
    outfile = paste(outfile, '.pdf', sep='')

    # Output directory. Empty components are dropped so that the defaults do
    # not resolve to the filesystem root.
    parts = c(figdir, filedir_opt)
    parts = parts[nzchar(parts)]
    outdir = if (length(parts) > 0) do.call(file.path, as.list(parts)) else '.'
    if (!dir.exists(outdir)) {
        dir.create(outdir, recursive=TRUE)
    }

    # Fresh temporary directory collecting one PDF per page.
    outdirTmp = file.path(outdir, 'tmp')
    if (dir.exists(outdirTmp)) {
        unlink(outdirTmp, recursive=TRUE)
    }
    dir.create(outdirTmp)

    # One entry per variable: wrap a single data.frame BEFORE counting.
    df_data = .panels_recycle(df_data, 1)
    nbp = length(df_data)

    df_trend = .panels_recycle(df_trend, nbp)
    p_threshold = .panels_recycle(p_threshold, nbp)
    unit2day = .panels_recycle(unit2day, nbp)
    type = .panels_recycle(type, nbp)
    missRect = .panels_recycle(missRect, nbp)

    list_df2plot = lapply(seq_len(nbp), function (i) {
        list(data=df_data[[i]],
             trend=df_trend[[i]],
             p_threshold=p_threshold[[i]],
             unit2day=unit2day[[i]],
             type=type[[i]],
             missRect=missRect[[i]])
    })

    if ('datasheet' %in% isplot) {

        # All different station codes
        Code = levels(factor(df_meta$code))

        periods = .panels_periods(df_trend[[1]], Code)
        TrendMean_code = .panels_norm_trends(list_df2plot, Code, periods)
        # Colour range of each (period, variable) across all stations
        minTrendMean = apply(TrendMean_code, c(1, 2), .panels_range, agg=min)
        maxTrendMean = apply(TrendMean_code, c(1, 2), .panels_range, agg=max)

        # The page layout does not depend on the station: build it once.
        has_time = !is.null(time_header)
        nbh = as.numeric(info_header) + as.numeric(has_time)
        LM = .panels_layout_matrix(layout_matrix, info_header, has_time,
                                   info_ratio, time_ratio, var_ratio)
        nbcol = ncol(as.matrix(layout_matrix))

        for (k in seq_along(Code)) {
            code = Code[k]

            # Print code of the station for the current plotting
            print(paste("Plotting for station :", code))

            P = vector(mode='list', length=nbp + nbh)
            axis_xlim_code = axis_xlim

            time_header_code = NULL
            if (has_time) {
                time_header_code = time_header[time_header$code == code,]
            }

            if (info_header) {
                P[[1]] = info_panel(list_df2plot,
                                    df_meta,
                                    df_shapefile=df_shapefile,
                                    codeLight=code,
                                    df_data_code=time_header_code)
            }

            if (has_time) {
                axis_xlim_code = c(min(time_header_code$Date),
                                   max(time_header_code$Date))
                # The time header is always the last header slot.
                P[[nbh]] = time_panel(time_header_code, df_trend_code=NULL,
                                      trend_period=trend_period, missRect=TRUE,
                                      unit2day=365.25, type='Q', first=FALSE)
            }

            for (i in seq_len(nbp)) {
                df2plot = list_df2plot[[i]]
                df_data_code = df2plot$data[df2plot$data$code == code,]
                df_trend_code = df2plot$trend[df2plot$trend$code == code,]

                # Same normalisation as the colour range above.
                color = .panels_colors(TrendMean_code[, i, k],
                                       minTrendMean[, i],
                                       maxTrendMean[, i])

                P[[i + nbh]] = time_panel(df_data_code, df_trend_code,
                                          type=df2plot$type,
                                          p_threshold=df2plot$p_threshold,
                                          missRect=df2plot$missRect,
                                          trend_period=trend_period,
                                          mean_period=mean_period,
                                          axis_xlim=axis_xlim_code,
                                          unit2day=df2plot$unit2day,
                                          last=(i > nbp - nbcol),
                                          color=color)
            }

            # arrangeGrob builds the same gtable as grid.arrange without
            # drawing it on the current device first.
            plot = arrangeGrob(grobs=P, layout_matrix=LM)

            # Saving
            ggsave(plot=plot,
                   path=outdirTmp,
                   filename=paste(as.character(code), '.pdf', sep=''),
                   width=21, height=29.7, units='cm', dpi=100)
        }
    }

    if ('matrix' %in% isplot) {
        matrice_panel(list_df2plot, df_meta, trend_period, mean_period,
                      slice=12, outdirTmp=outdirTmp, A3=TRUE)
    }
    if ('map' %in% isplot) {
        map_panel(list_df2plot,
                  df_meta,
                  idPer=length(trend_period),
                  df_shapefile=df_shapefile,
                  outdirTmp=outdirTmp,
                  margin=margin(t=5, r=0, b=5, l=5, unit="mm"))
    }

    # PDF combine
    pdf_files = list.files(outdirTmp, full.names=TRUE)
    if (length(pdf_files) == 0) {
        warning('panels_layout: no page was produced, nothing to combine',
                call.=FALSE)
        return (invisible(NULL))
    }
    pdf_combine(input=pdf_files, output=file.path(outdir, outfile))
}