import os import pickle import gradio as gr import numpy as np import matplotlib import matplotlib.pyplot as plt from src.sampler import mcmc_sampler from src.data_parser import Parser matplotlib.use('Agg') font = {'size': 30} matplotlib.rc('font', **font) def sample(country, d, n_iterations, burnin): P = parser.parse_population(country) start_date = "2020-03-01" end_date = "2020-06-15" i, r = parser.parse_data(start_date, end_date, country) i, r = i.values, r.values s = np.repeat(P, i.shape[0]) - i - r p, lam, t, lam_ar, t_ar = mcmc_sampler(s, i, d, P, n_iterations, burnin, M=3, sigma=0.01, alpha=np.repeat(2, d), beta=np.repeat(0.1, d), a=1, b=1, phi=0.995) lam_estimated = np.average(lam, axis=1) t_estimated = np.average(t, axis=1) p_estimated = np.average(p) # Plot the series. fig, axs = plt.subplots(nrows=2) fig.set_figheight(30) fig.set_figwidth(30) ax1_left = axs[0] ax2_left = axs[1] ax1_right = ax1_left.twinx() ax2_right = ax2_left.twinx() ax1_left.plot(s, color='red', label="Susceptible") ax1_right.plot(i, color='blue', label="Infected") ax1_left.legend(loc=2) ax1_right.legend(loc=1) delta_i = -np.diff(s) ax2_left.plot(delta_i, color="blue", label="Newly Infected Individuals") ax2_right.plot(i, color='blue', linestyle='dashed', label="Infected") ax2_left.legend(loc=2) ax2_right.legend(loc=1) # Display obtained breakpoints on plot. for breakpoint in np.average(t, axis=1): ax1_right.axvline(breakpoint, color="green") ax2_right.axvline(breakpoint, color="green") # Get output strings lam_string = "" for j, lam_component in enumerate(lam_estimated): lam_string += f"Component {j+1}: {round(lam_component, 4)}\n" lam_string = lam_string.rstrip() t_string = "" for j, t_component in enumerate(t_estimated): t_string += f"Breakpoint {j+1}: {int(round(t_component, 0))}\n" t_string = t_string.rstrip() p_string = f"{round(p_estimated, 4)}" return fig, lam_string, t_string, p_string if __name__ == "__main__": confirmed_path = "confirmed.csv" deaths_path = "deaths.csv" recovered_path = "recovered.csv" population_path = "population.csv" data_path = os.path.join(os.getcwd(), "data") parser = Parser(os.path.join(data_path, confirmed_path), os.path.join(data_path, deaths_path), os.path.join(data_path, recovered_path), os.path.join(data_path, population_path)) countries = parser.countries # Inputs dropdown = gr.Dropdown(choices=countries, value="Germany", label="Select the Country") slider = gr.Slider(minimum=1, maximum=5, value=4, step=1, label="Select the Number of Breakpoints") n_iterations = gr.Number(value=10000, precision=0, label="Select the Number of iterations") burnin = gr.Number(value=1000, precision=0, label="Select the Number of Burn-In Iterations", info="Such iterations will be discarded.") # Outputs initial_plot = pickle.load(open("data/germany_estimate.pkl", "rb")) plot = gr.Plot(value=initial_plot, label="Results") initial_lam = "Component 1: 0.3114\nComponent 2: 0.1499\nComponent 3: 0.0822\nComponent 4: 0.0419" lam = gr.Text(value=initial_lam, label="Estimated Lambda") initial_t = "Breakpoint 1: 19\nBreakpoint 2: 29\nBreakpoint3: 38" t = gr.Text(value=initial_t, label="Estimated Breakpoints") initial_p = "0.0653" p = gr.Text(value=initial_p, label="Estimated Recovery Probability") interface = gr.Interface(sample, inputs=[dropdown, slider, n_iterations, burnin], outputs=[plot, lam, t, p]) interface.launch()