File size: 4,073 Bytes
406ac25
3c54378
406ac25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c54378
406ac25
 
 
 
 
 
 
 
3c54378
 
 
 
 
 
 
3a46db5
406ac25
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import pickle
import gradio as gr
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from src.sampler import mcmc_sampler
from src.data_parser import Parser
# Select the non-interactive Agg backend so figures render off-screen.
matplotlib.use("Agg")
# Global font settings applied to every figure created by this module.
font = dict(size=30)
matplotlib.rc("font", **font)


def sample(country, d, n_iterations, burnin):
    """Run the MCMC sampler for one country and summarise the estimates.

    Relies on the module-level ``parser`` object, which is only created in
    the ``if __name__ == "__main__"`` section of this file; calling this
    function before ``parser`` exists raises ``NameError``.

    The data window is fixed to 2020-03-01 .. 2020-06-15.

    Args:
        country: Country name handed to ``parser.parse_population`` and
            ``parser.parse_data``.
        d: Forwarded to ``mcmc_sampler`` and used as the length of the
            ``alpha`` and ``beta`` vectors. Coerced to ``int`` because
            ``np.repeat`` rejects a float repeat count.
        n_iterations: Forwarded unchanged to ``mcmc_sampler``.
        burnin: Forwarded unchanged to ``mcmc_sampler``.

    Returns:
        A 4-tuple ``(fig, lam_string, t_string, p_string)``:
        the matplotlib figure with both panels, one ``"Component k: value"``
        line per averaged lambda row, one ``"Breakpoint k: value"`` line per
        averaged breakpoint row, and the averaged ``p`` rounded to 4 digits.
    """
    # UI widgets may deliver the count as a float (e.g. 4.0).
    d = int(d)

    P = parser.parse_population(country)
    start_date = "2020-03-01"
    end_date = "2020-06-15"
    i, r = parser.parse_data(start_date, end_date, country)
    i, r = i.values, r.values
    # Susceptible series: population minus infected minus recovered.
    s = np.repeat(P, i.shape[0]) - i - r

    # The last two return values are not used here.
    p, lam, t, _lam_ar, _t_ar = mcmc_sampler(s, i, d, P, n_iterations, burnin,
                                             M=3, sigma=0.01,
                                             alpha=np.repeat(2, d),
                                             beta=np.repeat(0.1, d),
                                             a=1, b=1, phi=0.995)

    # Point estimates: average each row of lam / t, and all of p.
    lam_estimated = np.average(lam, axis=1)
    t_estimated = np.average(t, axis=1)
    p_estimated = np.average(p)

    # Plot the series.
    fig, axs = plt.subplots(nrows=2)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    ax1_left = axs[0]
    ax2_left = axs[1]
    ax1_right = ax1_left.twinx()
    ax2_right = ax2_left.twinx()
    ax1_left.plot(s, color='red', label="Susceptible")
    ax1_right.plot(i, color='blue', label="Infected")
    ax1_left.legend(loc=2)
    ax1_right.legend(loc=1)
    # Day-to-day drop in the susceptible series.
    delta_i = -np.diff(s)
    ax2_left.plot(delta_i, color="blue", label="Newly Infected Individuals")
    ax2_right.plot(i, color='blue', linestyle='dashed', label="Infected")
    ax2_left.legend(loc=2)
    ax2_right.legend(loc=1)
    # Display obtained breakpoints on plot (reuse the averages from above).
    for t_breakpoint in t_estimated:
        ax1_right.axvline(t_breakpoint, color="green")
        ax2_right.axvline(t_breakpoint, color="green")

    # Get output strings
    lam_string = "\n".join(
        f"Component {j}: {round(lam_component, 4)}"
        for j, lam_component in enumerate(lam_estimated, start=1))
    t_string = "\n".join(
        f"Breakpoint {j}: {int(round(t_component, 0))}"
        for j, t_component in enumerate(t_estimated, start=1))
    p_string = f"{round(p_estimated, 4)}"

    # Drop pyplot's global reference to the figure; otherwise every call
    # keeps one more figure alive. The Figure object itself stays usable
    # by the caller.
    plt.close(fig)

    return fig, lam_string, t_string, p_string


if __name__ == "__main__":
    # Input tables, all expected in ./data relative to the working directory.
    confirmed_path = "confirmed.csv"
    deaths_path = "deaths.csv"
    recovered_path = "recovered.csv"
    population_path = "population.csv"
    data_path = os.path.join(os.getcwd(), "data")
    # `parser` is read as a module-level global by `sample`.
    parser = Parser(os.path.join(data_path, confirmed_path),
                    os.path.join(data_path, deaths_path),
                    os.path.join(data_path, recovered_path),
                    os.path.join(data_path, population_path))
    countries = parser.countries

    # Inputs
    dropdown = gr.Dropdown(choices=countries, value="Germany",
                           label="Select the Country")
    # NOTE(review): the placeholder results below show 4 lambda components
    # and 3 breakpoints for the default value 4 - confirm that the label
    # "Number of Breakpoints" matches what `d` means in the sampler.
    slider = gr.Slider(minimum=1, maximum=5, value=4, step=1,
                       label="Select the Number of Breakpoints")
    n_iterations = gr.Number(value=10000, precision=0,
                             label="Select the Number of iterations")
    burnin = gr.Number(value=1000, precision=0,
                       label="Select the Number of Burn-In Iterations",
                       info="Such iterations will be discarded.")

    # Outputs
    # Precomputed placeholder results shown before the first run.
    # NOTE: pickle can execute arbitrary code on load; only ship a trusted
    # germany_estimate.pkl with the application.
    with open(os.path.join(data_path, "germany_estimate.pkl"), "rb") as f:
        initial_plot = pickle.load(f)
    plot = gr.Plot(value=initial_plot, label="Results")
    initial_lam = "Component 1: 0.3114\nComponent 2: 0.1499\nComponent 3: 0.0822\nComponent 4: 0.0419"
    lam = gr.Text(value=initial_lam, label="Estimated Lambda")
    # Same "Breakpoint k: value" layout that `sample` produces.
    initial_t = "Breakpoint 1: 19\nBreakpoint 2: 29\nBreakpoint 3: 38"
    t = gr.Text(value=initial_t, label="Estimated Breakpoints")
    initial_p = "0.0653"
    p = gr.Text(value=initial_p, label="Estimated Recovery Probability")
    interface = gr.Interface(sample,
                             inputs=[dropdown, slider, n_iterations, burnin],
                             outputs=[plot, lam, t, p])
    interface.launch()