File size: 3,602 Bytes
40e38d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List
from .data_processing import prepare_for_non_grouped_plotting, prepare_for_group_plotting
from .utils import set_alpha

def plot_scatter(
        data: Dict[str, Dict[float, float]],
        metric_name: str,
        log_scale_x: bool,
        log_scale_y: bool,
        normalization: bool,
        rounding: int,
        cumsum: bool,
        perc: bool,
        progress: gr.Progress,
):
    """Build a figure with one line trace per named histogram.

    Series are plotted in sorted-name order, and each one is coloured by
    cycling through Plotly's qualitative palette at 50% alpha.

    Args:
        data: Mapping from series name to its histogram (x value -> y value).
        metric_name: Used in the figure title and as the x-axis title.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``;
            also selects the y-axis title ("Frequency" vs. "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: If true, replace the y values with their running sum.
        perc: If true, multiply the y values by 100 (applied after ``cumsum``).
        progress: Progress tracker whose ``tqdm`` wraps the per-series loop.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly

    # Bind x/y up front: the axis-type checks below read them after the loop,
    # and with empty `data` the loop body never runs. Without this, requesting
    # a log scale on empty input raised UnboundLocalError.
    x, y = [], []

    data = dict(sorted(data.items()))
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"

    # NOTE: the log-scale decision looks only at the last series plotted
    # (or the empty defaults); a log axis is used only if it has > 1 point.
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig

def plot_bars(
        data: Dict[str, List[Dict[str, float]]],
        metric_name: str,
        top_k: int,
        direction: str,
        regex: str | None,
        rounding: int,
        log_scale_x: bool,
        log_scale_y: bool,
        progress: gr.Progress,
):
    """Build a figure with one bar trace (with error bars) per named series.

    Each series is passed through ``prepare_for_group_plotting`` together with
    ``top_k``, ``direction``, ``regex`` and ``rounding``; the three values it
    returns are used as the bar x positions, bar heights and error-bar sizes.
    Series are plotted in the iteration order of ``data``, with colours cycled
    from Plotly's qualitative palette at 50% alpha.

    Args:
        data: Mapping from series name to the raw grouped data for that series.
        metric_name: Used in the figure title and as the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        progress: Progress tracker whose ``tqdm`` wraps the per-series loop.

    Returns:
        The populated ``go.Figure``.
    """
    palette = px.colors.qualitative.Plotly
    figure = go.Figure()

    # Values of the most recently plotted series; they stay empty when `data`
    # has no entries and drive the log-axis decision after the loop.
    last_x, last_y = [], []

    series_iter = progress.tqdm(data.items(), total=len(data), desc="Plotting...")
    for idx, (name, histogram) in enumerate(series_iter):
        last_x, last_y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)

        bar_color = set_alpha(palette[idx % len(palette)], 0.5)
        trace = go.Bar(
            x=last_x,
            y=last_y,
            name=f"{name} Mean",
            marker=dict(color=bar_color),
            error_y=dict(type='data', array=stds, visible=True),
        )
        figure.add_trace(trace)

    # A log axis is only used when requested and the last series has > 1 point.
    x_type = "log" if log_scale_x and len(last_x) > 1 else None
    y_type = "log" if log_scale_y and len(last_y) > 1 else None

    figure.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type=x_type,
        yaxis_type=y_type,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )

    return figure

def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
              cumsum, perc, progress=gr.Progress()):
    """Dispatch to the line or bar plotter depending on ``grouping``.

    Returns ``None`` without plotting when ``rounding`` or ``top_k`` is
    ``None``. Otherwise ``grouping == "histogram"`` selects ``plot_scatter``
    (which receives ``normalization``, ``rounding``, ``cumsum`` and ``perc``)
    and any other value selects ``plot_bars`` (which receives ``top_k``,
    ``direction``, ``regex`` and ``rounding``). Both receive ``data``,
    ``metric_name``, ``progress`` and the two log-scale flags.

    Returns:
        The figure built by the selected plotter, or ``None``.
    """
    # Guard: nothing is drawn while either numeric setting is unset.
    if rounding is None or top_k is None:
        return None

    # Arguments common to both plotters.
    shared_kwargs = dict(
        data=data,
        metric_name=metric_name,
        progress=progress,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
    )

    if grouping == "histogram":
        return plot_scatter(
            normalization=normalization,
            rounding=rounding,
            cumsum=cumsum,
            perc=perc,
            **shared_kwargs,
        )

    return plot_bars(
        top_k=top_k,
        direction=direction,
        regex=regex,
        rounding=rounding,
        **shared_kwargs,
    )