File size: 3,420 Bytes
97ab62b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go


def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets as two line traces using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is drawn in
            steelblue under the legend entry 'Training Data'.
        df2 (pd.DataFrame): Test dataset; its first column is drawn in
            gold under the legend entry 'Test Data'.

    Returns:
        go.Figure: Plotly figure containing the two traces, using the
        'plotly_white' template.
    """

    # Create a Plotly figure
    fig = go.Figure()

    # Training series (first column) in steelblue.
    # `marker` is kept for parity with the original figure spec; with
    # mode='lines' Plotly does not draw markers.
    fig.add_trace(go.Scatter(
            x=df1.index,
            y=df1.iloc[:, 0],
            mode='lines',
            name='Training Data',
            line=dict(color='steelblue'),
            marker=dict(color='steelblue')
            ))

    # Test series (first column) in gold
    fig.add_trace(go.Scatter(
            x=df2.index,
            y=df2.iloc[:, 0],
            mode='lines',
            name='Test Data',
            line=dict(color='gold'),
            marker=dict(color='gold')
            ))

    # Customize the layout
    fig.update_layout(
            title='Univariate Time Series',
            xaxis=dict(title='Date'),
            yaxis=dict(title='Value'),
            showlegend=True,
            template='plotly_white'
            )
    return fig


def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
    """
    Plot the true values together with sampled forecast paths and their mean.

    For every forecast, each sample path is drawn as a faint line in that
    forecast's colour (green, blue, purple, cycling if there are more
    forecasts). The per-time-step mean over samples is drawn as a red dashed
    line labelled 'Mean Forecast'.

    Args:
        df (pd.DataFrame): True values; the first column is plotted against
            ``pd.to_datetime(df.index)``.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each item must expose ``.samples``, iterated row by
            row (one row per sample path) and averaged over axis 0. It must
            also expose ``.index`` with a ``to_timestamp()`` method (e.g. a
            ``pd.PeriodIndex``).
            NOTE(review): these look like sample-based forecast objects
            rather than DataFrames — confirm and tighten the annotation.

    Returns:
        go.Figure: Plotly figure object.
    """

    # Create a Plotly figure
    fig = go.Figure()

    # Add the true values trace
    fig.add_trace(go.Scatter(
            x=pd.to_datetime(df.index),
            y=df.iloc[:, 0],
            mode='lines',
            name='True values',
            line=dict(color='black')
            ))

    # Add the forecast traces
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        # Cycle through the palette so more than len(colors) forecasts
        # no longer raise IndexError.
        color = colors[i % len(colors)]
        # Convert the index once; every trace of this forecast shares it.
        x_values = forecast.index.to_timestamp()
        for sample in forecast.samples:
            fig.add_trace(go.Scatter(
                    x=x_values,
                    y=sample,
                    mode='lines',
                    opacity=0.15,  # Faint lines so overlapping sample paths read as a band
                    name=f'Forecast {i + 1}',
                    showlegend=False,  # Hide the individual sample paths from the legend
                    hoverinfo='none',  # Disable hover information for the sample paths
                    line=dict(color=color)
                    ))
        # Per-time-step mean across sample paths (rows of `samples`)
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(go.Scatter(
                x=x_values,
                y=mean_forecast,
                mode='lines',
                name='Mean Forecast',
                line=dict(color='red', dash='dash')
                ))

    # Customize the layout
    fig.update_layout(
            title='Passenger Forecast',
            xaxis=dict(title='Index'),
            yaxis=dict(title='Passenger Count'),
            showlegend=True,
            legend=dict(x=0, y=1, font=dict(size=16)),
            hovermode='x'  # Enable x-axis hover for better interactivity
            )

    # Return the figure
    return fig