single_scalar_plot.py 2.89 KiB
from extreme_estimator.robustness_plot.abstract_robustness_plot import AbstractPlot
import matplotlib.pyplot as plt
from itertools import product


class SingleScalarPlot(AbstractPlot):
    """
    Robustness plot where every plotted point is a single scalar.

    For every combination of the parameters of interest (grid row, grid column,
    plot row and plot label), `single_scalar_from_all_params` is evaluated once
    and the returned scalar becomes one point of a curve. Subclasses must
    implement `single_scalar_from_all_params`.
    """

    def single_scalar_from_all_params(self, **kwargs_single_point) -> float:
        """Return the scalar to plot for one fully specified parameter combination.

        :param kwargs_single_point: all plot parameters, including the current
            grid row/column, plot row and plot label values.
        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError(
            '{} must implement single_scalar_from_all_params'.format(type(self).__name__))

    def robustness_grid_plot(self, **kwargs):
        """Draw one subplot per (grid row value, grid column value) pair, then show the figure.

        :param kwargs: parameters passed to the items' `values_from_kwargs` and,
            augmented with the current grid values, to `robustness_single_plot`.
        """
        # Extract grid row and column values
        grid_row_values = self.grid_row_item.values_from_kwargs(**kwargs)
        grid_column_values = self.grid_column_item.values_from_kwargs(**kwargs)
        nb_grid_rows, nb_grid_columns = len(grid_row_values), len(grid_column_values)
        # Start the overall plot
        fig = plt.figure()
        fig.subplots_adjust(hspace=0.4, wspace=0.4)
        # product() iterates row-major, matching matplotlib's 1-based subplot
        # numbering (left to right, then top to bottom).
        for i, (grid_row_value, grid_column_value) in enumerate(product(grid_row_values, grid_column_values), 1):
            print('Grid plot: {}={} {}={}'.format(self.grid_row_item.dislay_name, grid_row_value,
                                                  self.grid_column_item.dislay_name, grid_column_value))
            ax = fig.add_subplot(nb_grid_rows, nb_grid_columns, i)
            # Adapt the kwargs for the single plot (the caller's kwargs are not mutated)
            kwargs_single_plot = {**kwargs,
                                  self.grid_row_item.argument_name: grid_row_value,
                                  self.grid_column_item.argument_name: grid_column_value}
            self.robustness_single_plot(ax, **kwargs_single_plot)
        plt.show()

    def robustness_single_plot(self, ax, **kwargs_single_plot):
        """Draw one curve per plot label value on `ax`.

        The x axis holds the plot row values; the y values come from
        `single_scalar_from_all_params`.

        :param ax: matplotlib axes to draw on.
        :param kwargs_single_plot: parameters for this subplot (grid values already set).
        """
        # Materialize once: the values are iterated for computing and again for plotting.
        plot_row_values = list(self.plot_row_item.values_from_kwargs(**kwargs_single_plot))
        plot_label_values = self.plot_label_item.values_from_kwargs(**kwargs_single_plot)
        for j, plot_label_value in enumerate(plot_label_values):
            # todo: do some parallelization here (do the parallelization in the Abstract class if possible)
            plot_column_values = [
                self.single_scalar_from_all_params(**{**kwargs_single_plot,
                                                      self.plot_row_item.argument_name: plot_row_value,
                                                      self.plot_label_item.argument_name: plot_label_value})
                for plot_row_value in plot_row_values
            ]
            # Label each curve by its parameter value (not its index) so the legend is readable.
            ax.plot(plot_row_values, plot_column_values, color=self.COLORS[j % len(self.COLORS)],
                    label='{}={}'.format(self.plot_label_item.dislay_name, plot_label_value))
        ax.legend()
        ax.set_xlabel(self.plot_row_item.dislay_name)
        # NOTE(review): y label assumes the scalar is an absolute error — confirm for all subclasses.
        ax.set_ylabel('Absolute error')
        ax.set_title('Title (display all the other parameters)')