# visualize.py (2.23 KiB)
# coding = utf-8

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import seaborn as sns
from plotly.offline import init_notebook_mode, iplot

from .pareto import is_pareto_front

# Import-time side effect: switch plotly's offline renderer into notebook mode
# (connected=True) so that ``iplot`` can display figures inline.
init_notebook_mode(connected=True)

# Shared eight-colour seaborn "hls" palette; the plotting functions below pick
# the colour for criterion number ``i`` from this list.
cm = sns.color_palette(palette="hls", n_colors=8)


def draw_pareto_static(df, x_label, criteria, x_ax_label="X", y_ax_label="Y", title="Titre"):
    """Draw a static swarm plot of each criterion with its Pareto frontier.

    For every column name in ``criteria``, the rows of ``df`` are drawn as a
    seaborn swarm plot of that column against ``x_label``. The rows flagged by
    ``is_pareto_front`` for the pair ``[x_label, criterion]`` are joined by a
    dashed line in the same colour. The resulting figure is shown with
    ``plt.show()``.

    :param df: pandas DataFrame containing ``x_label`` and every column in ``criteria``.
    :param x_label: name of the column plotted on the x axis.
    :param criteria: column names, each plotted against ``x_label``.
    :param x_ax_label: text of the x-axis label.
    :param y_ax_label: text of the y-axis label.
    :param title: figure title.
    :return: None.
    """
    _, ax = plt.subplots(figsize=(10, 5), ncols=1)
    for i, y_label in enumerate(criteria):
        # Cycle through the shared palette so more than len(cm) criteria still work.
        color = cm[i % len(cm)]
        # df.apply runs eagerly, so capturing the loop variable y_label is safe here.
        is_pareto = df.apply(lambda row: is_pareto_front(df, row, [x_label, y_label]), axis=1)
        # Boolean-mask row selection; DataFrame.ix was removed in pandas 1.0.
        df_pareto = df.loc[is_pareto].sort_values(by=x_label)

        sns.swarmplot(x=x_label, y=y_label, data=df, ax=ax, color=color)
        # NOTE(review): the frontier is drawn at the DataFrame *index labels*. These line
        # up with swarmplot's categorical x positions only if df has a default RangeIndex
        # whose order matches the x categories. Confirm against callers.
        ax.plot(df_pareto[x_label].index, df_pareto[y_label].values, '--', color=color,
                label='P. Frontier for {0}'.format(y_label))

    # The frontier lines carry labels; without a legend call they were never shown.
    ax.legend()
    plt.xlabel(x_ax_label)
    plt.ylabel(y_ax_label)
    plt.xticks(rotation=90)
    plt.title(title)
    plt.show()

def draw_pareto_dynamic(df, x_label, criteria, layout=None):
    """Render an interactive Plotly figure of each criterion and its Pareto frontier.

    The traces come from ``data_pareto``. ``layout`` is applied only when it
    is truthy; ``None`` or any other falsy value builds the figure without an
    explicit layout.

    :param df: pandas DataFrame passed through to ``data_pareto``.
    :param x_label: name of the column used on the x axis.
    :param criteria: column names, each plotted against ``x_label``.
    :param layout: optional Plotly layout for the figure.
    :return: the value returned by ``iplot`` for the built figure.
    """
    traces = data_pareto(df, x_label, criteria)
    figure = go.Figure(data=traces, layout=layout) if layout else go.Figure(data=traces)
    return iplot(figure)


def _to_plotly_rgb(color):
    """Format an RGB tuple with channels in [0, 1] as a Plotly ``rgb(r, g, b)`` string."""
    # Plotly/CSS rgb() expects 0-255 channels. Formatting the raw float tuple
    # (e.g. "rgb(0.86, 0.37, 0.34)") would render as near-black.
    r, g, b = (int(round(channel * 255)) for channel in color[:3])
    return "rgb({0}, {1}, {2})".format(r, g, b)


def data_pareto(df, x_label, criteria):
    """Build Plotly scatter traces for each criterion and its Pareto frontier.

    For every column name in ``criteria``, two traces are appended:

    * a marker-only scatter of all rows of ``df`` (``y_label`` against ``x_label``);
    * a line trace through the rows flagged by ``is_pareto_front`` for
      ``[x_label, y_label]``, sorted by ``x_label``.

    Both traces share one colour from the module palette ``cm``.

    :param df: pandas DataFrame containing ``x_label`` and every column in ``criteria``.
    :param x_label: name of the column used on the x axis.
    :param criteria: column names, each plotted against ``x_label``.
    :return: list of ``go.Scatter`` traces, two per criterion, in criterion order.
    """
    data = []
    for i, y_label in enumerate(criteria):
        # Cycle through the shared palette so more than len(cm) criteria still work.
        color = _to_plotly_rgb(cm[i % len(cm)])
        # df.apply runs eagerly, so capturing the loop variable y_label is safe here.
        is_pareto = df.apply(lambda row: is_pareto_front(df, row, [x_label, y_label]), axis=1)
        # Boolean-mask row selection; DataFrame.ix was removed in pandas 1.0.
        df_pareto = df.loc[is_pareto].sort_values(by=x_label)

        data.append(go.Scatter(
            x=df[x_label],
            y=df[y_label],
            mode="markers",
            marker=dict(color=color),
            name="{0} ".format(y_label),
        ))
        data.append(go.Scatter(
            x=df_pareto[x_label],
            y=df_pareto[y_label],
            name="{0} Pareto Frontier".format(y_label),
            line=dict(color=color, width=4),
        ))
    return data