# covidSIR / app.py
# SnoopKilla's picture
# Update app.py
# 3a46db5
import os
import pickle
import gradio as gr
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from src.sampler import mcmc_sampler
from src.data_parser import Parser
# Global font settings applied to every figure created by this module.
font = dict(size=30)
matplotlib.rc("font", **font)
# Render with the Agg backend; figures are returned to the UI rather than
# shown in a desktop window.
matplotlib.use("Agg")
def sample(country, d, n_iterations, burnin):
    """Run the MCMC sampler for one country and build the UI outputs.

    Reads the module-level ``parser`` object (created in the ``__main__``
    block), so it can only be called once that object exists.

    Args:
        country: Country name passed to ``parser.parse_population`` and
            ``parser.parse_data``.
        d: Model dimension forwarded to ``mcmc_sampler``; also the length of
            the ``alpha`` and ``beta`` vectors built here. Coerced to ``int``
            because ``np.repeat`` rejects a non-integer repeat count.
        n_iterations: Forwarded unchanged to ``mcmc_sampler``.
        burnin: Forwarded unchanged to ``mcmc_sampler``.

    Returns:
        A 4-tuple ``(fig, lam_string, t_string, p_string)``:
        the matplotlib figure with two stacked panels; one
        ``"Component k: value"`` line per averaged ``lam`` entry (4 decimals);
        one ``"Breakpoint k: value"`` line per averaged ``t`` entry (rounded
        to the nearest integer); and the averaged ``p`` (4 decimals).
    """
    # Defensive: UI widgets may deliver whole numbers as floats.
    d = int(d)

    P = parser.parse_population(country)
    start_date = "2020-03-01"
    end_date = "2020-06-15"
    i, r = parser.parse_data(start_date, end_date, country)
    i, r = i.values, r.values
    # s is the population minus i and r at every time step (plotted below as
    # "Susceptible").
    s = np.repeat(P, i.shape[0]) - i - r

    # The last two return values are not used here.
    p, lam, t, _lam_ar, _t_ar = mcmc_sampler(s, i, d, P, n_iterations, burnin,
                                             M=3, sigma=0.01,
                                             alpha=np.repeat(2, d),
                                             beta=np.repeat(0.1, d),
                                             a=1, b=1, phi=0.995)

    # Point estimates: average the returned samples.
    lam_estimated = np.average(lam, axis=1)
    t_estimated = np.average(t, axis=1)
    p_estimated = np.average(p)

    # Plot the series.
    fig, axs = plt.subplots(nrows=2)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    ax1_left = axs[0]
    ax2_left = axs[1]
    # Each panel gets a second y-axis so both series stay readable.
    ax1_right = ax1_left.twinx()
    ax2_right = ax2_left.twinx()

    ax1_left.plot(s, color='red', label="Susceptible")
    ax1_right.plot(i, color='blue', label="Infected")
    ax1_left.legend(loc=2)
    ax1_right.legend(loc=1)

    # Step-to-step decrease of s, shown as the newly infected count.
    delta_i = -np.diff(s)
    ax2_left.plot(delta_i, color="blue", label="Newly Infected Individuals")
    ax2_right.plot(i, color='blue', linestyle='dashed', label="Infected")
    ax2_left.legend(loc=2)
    ax2_right.legend(loc=1)

    # Display obtained breakpoints on plot.
    for t_component in t_estimated:
        ax1_right.axvline(t_component, color="green")
        ax2_right.axvline(t_component, color="green")

    # pyplot keeps a reference to every figure it creates until it is closed;
    # drop that reference so repeated requests do not accumulate figures.
    # The returned Figure object itself stays usable by the caller.
    plt.close(fig)

    # Get output strings
    lam_string = "\n".join(
        f"Component {j}: {round(lam_component, 4)}"
        for j, lam_component in enumerate(lam_estimated, start=1)
    )
    t_string = "\n".join(
        f"Breakpoint {j}: {int(round(t_component, 0))}"
        for j, t_component in enumerate(t_estimated, start=1)
    )
    p_string = f"{round(p_estimated, 4)}"

    return fig, lam_string, t_string, p_string
if __name__ == "__main__":
    # Data files are looked up in ./data relative to the working directory.
    confirmed_path = "confirmed.csv"
    deaths_path = "deaths.csv"
    recovered_path = "recovered.csv"
    population_path = "population.csv"
    data_path = os.path.join(os.getcwd(), "data")

    # `parser` must stay a module-level name: `sample` reads it as a global.
    parser = Parser(os.path.join(data_path, confirmed_path),
                    os.path.join(data_path, deaths_path),
                    os.path.join(data_path, recovered_path),
                    os.path.join(data_path, population_path))
    countries = parser.countries

    # Inputs
    dropdown = gr.Dropdown(choices=countries, value="Germany",
                           label="Select the Country")
    # NOTE(review): the label says "Breakpoints", but the initial outputs
    # below show 4 components and 3 breakpoints for the default value 4 -
    # confirm the intended wording.
    slider = gr.Slider(minimum=1, maximum=5, value=4, step=1,
                       label="Select the Number of Breakpoints")
    n_iterations = gr.Number(value=10000, precision=0,
                             label="Select the Number of iterations")
    burnin = gr.Number(value=1000, precision=0,
                       label="Select the Number of Burn-In Iterations",
                       info="Such iterations will be discarded.")

    # Outputs, pre-populated so the page is not empty before the first run
    # (presumably a stored result for the default inputs - verify).
    # SECURITY: unpickling executes code embedded in the file. This is only
    # acceptable because the pickle ships with the app; never load
    # user-supplied data this way.
    with open(os.path.join(data_path, "germany_estimate.pkl"), "rb") as pickle_file:
        initial_plot = pickle.load(pickle_file)
    plot = gr.Plot(value=initial_plot, label="Results")

    initial_lam = ("Component 1: 0.3114\n"
                   "Component 2: 0.1499\n"
                   "Component 3: 0.0822\n"
                   "Component 4: 0.0419")
    lam = gr.Text(value=initial_lam, label="Estimated Lambda")

    # Same "Breakpoint k: value" layout that `sample` produces.
    initial_t = "Breakpoint 1: 19\nBreakpoint 2: 29\nBreakpoint 3: 38"
    t = gr.Text(value=initial_t, label="Estimated Breakpoints")

    initial_p = "0.0653"
    p = gr.Text(value=initial_p, label="Estimated Recovery Probability")

    interface = gr.Interface(sample,
                             inputs=[dropdown, slider, n_iterations, burnin],
                             outputs=[plot, lam, t, p])
    interface.launch()