Spaces:
Sleeping
Sleeping
Commit
·
406ac25
1
Parent(s):
f55aa25
Gradio APP
Browse files- app.py +98 -0
- data/confirmed.csv +0 -0
- data/deaths.csv +0 -0
- data/population.csv +266 -0
- data/recovered.csv +0 -0
- src/data_parser.py +63 -0
- src/sampler.py +160 -0
- src/utility_functions.py +134 -0
app.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from src.sampler import mcmc_sampler
|
7 |
+
from src.data_parser import Parser
|
8 |
+
# Select the non-interactive Agg backend so figures are rendered off-screen.
matplotlib.use('Agg')
# Global font settings applied to every figure created by this module.
font = dict(size=30)
matplotlib.rc('font', **font)
|
11 |
+
|
12 |
+
|
13 |
+
def sample(country, d, n_iterations, burnin):
    """Run the MCMC sampler for one country and build the Gradio outputs.

    Relies on the module-level ``parser`` object, which is only created in
    the ``if __name__ == "__main__"`` section of this file.

    Args:
        country: country name, forwarded to ``parser``.
        d: value forwarded as ``d`` to ``mcmc_sampler``; it is also the
            length of the ``alpha`` / ``beta`` prior arrays built here.
        n_iterations: number of sampler iterations.
        burnin: number of initial iterations discarded by the sampler.

    Returns:
        A tuple ``(fig, lam_string, t_string, p_string)``: the matplotlib
        figure and three display strings with the averaged lambda
        components, the averaged breakpoints and the averaged p.
    """
    # Defensive coercion: NumPy needs integers for sizes and slices, and the
    # UI widgets may hand over floats (presumably not with step=1 /
    # precision=0, but the cast is harmless).
    d, n_iterations, burnin = int(d), int(n_iterations), int(burnin)

    P = parser.parse_population(country)
    start_date = "2020-03-01"
    end_date = "2020-06-15"
    i, r = parser.parse_data(start_date, end_date, country)
    i, r = i.values, r.values
    # Susceptible = population - infected - removed, for every time instant.
    s = np.repeat(P, i.shape[0]) - i - r

    # The two acceptance rates returned by the sampler are not displayed.
    p, lam, t, _lam_ar, _t_ar = mcmc_sampler(s, i, d, P, n_iterations, burnin,
                                             M=3, sigma=0.01,
                                             alpha=np.repeat(2, d),
                                             beta=np.repeat(0.1, d),
                                             a=1, b=1, phi=0.995)

    # Point estimates: average of each chain over the kept iterations.
    lam_estimated = np.average(lam, axis=1)
    t_estimated = np.average(t, axis=1)
    p_estimated = np.average(p)

    # Plot the series.
    # NOTE(review): pyplot keeps a reference to every figure it creates until
    # it is closed; in a long-running app this may accumulate memory.
    fig, axs = plt.subplots(nrows=2)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    ax1_left = axs[0]
    ax2_left = axs[1]
    ax1_right = ax1_left.twinx()
    ax2_right = ax2_left.twinx()
    ax1_left.plot(s, color='red', label="Susceptible")
    ax1_right.plot(i, color='blue', label="Infected")
    ax1_left.legend(loc=2)
    ax1_right.legend(loc=1)
    delta_i = -np.diff(s)
    ax2_left.plot(delta_i, color="blue", label="Newly Infected Individuals")
    ax2_right.plot(i, color='blue', linestyle='dashed', label="Infected")
    ax2_left.legend(loc=2)
    ax2_right.legend(loc=1)
    # Display obtained breakpoints on plot.
    for estimated_breakpoint in t_estimated:
        ax1_right.axvline(estimated_breakpoint, color="green")
        ax2_right.axvline(estimated_breakpoint, color="green")

    # Get output strings (one line per component, no trailing newline).
    lam_string = "\n".join(
        f"Component {j}: {round(lam_component, 4)}"
        for j, lam_component in enumerate(lam_estimated, start=1)
    )
    t_string = "\n".join(
        f"Breakpoint {j}: {int(round(t_component, 0))}"
        for j, t_component in enumerate(t_estimated, start=1)
    )
    p_string = f"{round(p_estimated, 4)}"

    return fig, lam_string, t_string, p_string
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
    confirmed_path = "confirmed.csv"
    deaths_path = "deaths.csv"
    recovered_path = "recovered.csv"
    population_path = "population.csv"
    # Resolve the data directory relative to this file instead of the current
    # working directory, so the app also starts when launched from elsewhere.
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             "data")
    # `parser` is a module-level name on purpose: `sample` reads it.
    parser = Parser(os.path.join(data_path, confirmed_path),
                    os.path.join(data_path, deaths_path),
                    os.path.join(data_path, recovered_path),
                    os.path.join(data_path, population_path))
    countries = parser.countries

    # Inputs
    dropdown = gr.Dropdown(choices=countries, value="Germany",
                           label="Select the Country")
    slider = gr.Slider(minimum=1, maximum=5, value=3, step=1,
                       label="Select the Number of Breakpoints")
    n_iterations = gr.Number(value=10000, precision=0,
                             label="Select the Number of iterations")
    burnin = gr.Number(value=1000, precision=0,
                       label="Select the Number of Burn-In Iterations",
                       info="Such iterations will be discarded.")

    # Outputs
    plot = gr.Plot(label="Results")
    lam = gr.Text(label="Estimated Lambda")
    t = gr.Text(label="Estimated Breakpoints")
    p = gr.Text(label="Estimated Recovery Probability")

    # Wire the four inputs to `sample` and its four return values to the
    # four output widgets, then start the web UI.
    interface = gr.Interface(sample,
                             inputs=[dropdown, slider, n_iterations, burnin],
                             outputs=[plot, lam, t, p])
    interface.launch()
|
data/confirmed.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/deaths.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/population.csv
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Country,Population
|
2 |
+
Aruba,106585
|
3 |
+
Africa Eastern and Southern,685112705
|
4 |
+
Afghanistan,38972230
|
5 |
+
Africa Western and Central,466189102
|
6 |
+
Angola,33428486
|
7 |
+
Albania,2837849
|
8 |
+
Andorra,77700
|
9 |
+
Arab World,449228296
|
10 |
+
United Arab Emirates,9287289
|
11 |
+
Argentina,45376763
|
12 |
+
Armenia,2805608
|
13 |
+
American Samoa,46189
|
14 |
+
Antigua and Barbuda,92664
|
15 |
+
Australia,25655289
|
16 |
+
Austria,8916864
|
17 |
+
Azerbaijan,10093121
|
18 |
+
Burundi,12220227
|
19 |
+
Belgium,11538604
|
20 |
+
Benin,12643123
|
21 |
+
Burkina Faso,21522626
|
22 |
+
Bangladesh,167420951
|
23 |
+
Bulgaria,6934015
|
24 |
+
Bahrain,1477469
|
25 |
+
Bahamas,406471
|
26 |
+
Bosnia and Herzegovina,3318407
|
27 |
+
Belarus,9379952
|
28 |
+
Belize,394921
|
29 |
+
Bermuda,63893
|
30 |
+
Bolivia,11936162
|
31 |
+
Brazil,213196304
|
32 |
+
Barbados,280693
|
33 |
+
Brunei,441725
|
34 |
+
Bhutan,772506
|
35 |
+
Botswana,2546402
|
36 |
+
Central African Republic,5343020
|
37 |
+
Canada,38037204
|
38 |
+
Central Europe and the Baltics,102180124
|
39 |
+
Switzerland,8638167
|
40 |
+
Channel Islands,171113
|
41 |
+
Chile,19300315
|
42 |
+
China,1411100000
|
43 |
+
Cote d'Ivoire,26811790
|
44 |
+
Cameroon,26491087
|
45 |
+
"Congo, Dem. Rep.",92853164
|
46 |
+
"Congo, Rep.",5702174
|
47 |
+
Colombia,50930662
|
48 |
+
Comoros,806166
|
49 |
+
Cabo Verde,582640
|
50 |
+
Costa Rica,5123105
|
51 |
+
Caribbean small states,7444768
|
52 |
+
Cuba,11300698
|
53 |
+
Curacao,154947
|
54 |
+
Cayman Islands,67311
|
55 |
+
Cyprus,1237537
|
56 |
+
Czechia,10697858
|
57 |
+
Germany,83160871
|
58 |
+
Djibouti,1090156
|
59 |
+
Dominica,71995
|
60 |
+
Denmark,5831404
|
61 |
+
Dominican Republic,10999664
|
62 |
+
Algeria,43451666
|
63 |
+
East Asia & Pacific (excluding high income),2116424876
|
64 |
+
Early-demographic dividend,-2147483648
|
65 |
+
East Asia & Pacific,-2147483648
|
66 |
+
Europe & Central Asia (excluding high income),400811771
|
67 |
+
Europe & Central Asia,923103879
|
68 |
+
Ecuador,17588595
|
69 |
+
Egypt,107465134
|
70 |
+
Euro area,342913447
|
71 |
+
Eritrea,3555868
|
72 |
+
Spain,47365655
|
73 |
+
Estonia,1329522
|
74 |
+
Ethiopia,117190911
|
75 |
+
European Union,447692315
|
76 |
+
Fragile and conflict affected situations,979418527
|
77 |
+
Finland,5529543
|
78 |
+
Fiji,920422
|
79 |
+
France,67571107
|
80 |
+
Faroe Islands,52415
|
81 |
+
Micronesia,112106
|
82 |
+
Gabon,2292573
|
83 |
+
United Kingdom,67081000
|
84 |
+
Georgia,3722716
|
85 |
+
Ghana,32180401
|
86 |
+
Gibraltar,32709
|
87 |
+
Guinea,13205153
|
88 |
+
Gambia,2573995
|
89 |
+
Guinea-Bissau,2015828
|
90 |
+
Equatorial Guinea,1596049
|
91 |
+
Greece,10698599
|
92 |
+
Grenada,123663
|
93 |
+
Greenland,56367
|
94 |
+
Guatemala,16858333
|
95 |
+
Guam,169231
|
96 |
+
Guyana,797202
|
97 |
+
High income,1240900955
|
98 |
+
"Hong Kong SAR, China",7481000
|
99 |
+
Honduras,10121763
|
100 |
+
Heavily indebted poor countries (HIPC),838066650
|
101 |
+
Croatia,4047680
|
102 |
+
Haiti,11306801
|
103 |
+
Hungary,9750149
|
104 |
+
IBRD only,-2147483648
|
105 |
+
IDA & IBRD total,-2147483648
|
106 |
+
IDA total,1738306807
|
107 |
+
IDA blend,582637127
|
108 |
+
Indonesia,271857970
|
109 |
+
IDA only,1155669680
|
110 |
+
Isle of Man,84046
|
111 |
+
India,1396387127
|
112 |
+
Ireland,4985382
|
113 |
+
Iran,87290193
|
114 |
+
Iraq,42556984
|
115 |
+
Iceland,366463
|
116 |
+
Israel,9215100
|
117 |
+
Italy,59438851
|
118 |
+
Jamaica,2820436
|
119 |
+
Jordan,10928721
|
120 |
+
Japan,126261000
|
121 |
+
Kazakhstan,18755666
|
122 |
+
Kenya,51985780
|
123 |
+
Kyrgyzstan,6579900
|
124 |
+
Cambodia,16396860
|
125 |
+
Kiribati,126463
|
126 |
+
Saint Kitts and Nevis,47642
|
127 |
+
"Korea, South",51836239
|
128 |
+
Kuwait,4360444
|
129 |
+
Latin America & Caribbean (excluding high income),588808380
|
130 |
+
Laos,7319399
|
131 |
+
Lebanon,5662923
|
132 |
+
Liberia,5087584
|
133 |
+
Libya,6653942
|
134 |
+
Saint Lucia,179237
|
135 |
+
Latin America & Caribbean,650534967
|
136 |
+
Least developed countries: UN classification,1073743450
|
137 |
+
Low income,699186538
|
138 |
+
Liechtenstein,38756
|
139 |
+
Sri Lanka,21919000
|
140 |
+
Lower middle income,-2147483648
|
141 |
+
Low & middle income,-2147483648
|
142 |
+
Lesotho,2254100
|
143 |
+
Late-demographic dividend,-2147483648
|
144 |
+
Lithuania,2794885
|
145 |
+
Luxembourg,630419
|
146 |
+
Latvia,1900449
|
147 |
+
"Macao SAR, China",676283
|
148 |
+
St. Martin (French part),32553
|
149 |
+
Morocco,36688772
|
150 |
+
Monaco,36922
|
151 |
+
Moldova,2635130
|
152 |
+
Madagascar,28225177
|
153 |
+
Maldives,514438
|
154 |
+
Middle East & North Africa,479966649
|
155 |
+
Mexico,125998302
|
156 |
+
Marshall Islands,43413
|
157 |
+
Middle income,-2147483648
|
158 |
+
North Macedonia,2072531
|
159 |
+
Mali,21224040
|
160 |
+
Malta,515332
|
161 |
+
Burma,53423198
|
162 |
+
Middle East & North Africa (excluding high income),411810124
|
163 |
+
Montenegro,621306
|
164 |
+
Mongolia,3294335
|
165 |
+
Northern Mariana Islands,49587
|
166 |
+
Mozambique,31178239
|
167 |
+
Mauritania,4498604
|
168 |
+
Mauritius,1265740
|
169 |
+
Malawi,19377061
|
170 |
+
Malaysia,33199993
|
171 |
+
North America,369602177
|
172 |
+
Namibia,2489098
|
173 |
+
New Caledonia,271130
|
174 |
+
Niger,24333639
|
175 |
+
Nigeria,208327405
|
176 |
+
Nicaragua,6755895
|
177 |
+
Netherlands,17441500
|
178 |
+
Norway,5379475
|
179 |
+
Nepal,29348627
|
180 |
+
Nauru,12315
|
181 |
+
New Zealand,5090200
|
182 |
+
OECD members,1370241530
|
183 |
+
Oman,4543399
|
184 |
+
Other small states,32381190
|
185 |
+
Pakistan,227196741
|
186 |
+
Panama,4294396
|
187 |
+
Peru,33304756
|
188 |
+
Philippines,112190977
|
189 |
+
Palau,17972
|
190 |
+
Papua New Guinea,9749640
|
191 |
+
Poland,37899070
|
192 |
+
Pre-demographic dividend,984213438
|
193 |
+
Puerto Rico,3281538
|
194 |
+
"Korea, Dem. People's Rep.",25867467
|
195 |
+
Portugal,10297081
|
196 |
+
Paraguay,6618695
|
197 |
+
West Bank and Gaza,4803269
|
198 |
+
Pacific island small states,2566819
|
199 |
+
Post-demographic dividend,1117443485
|
200 |
+
French Polynesia,301920
|
201 |
+
Qatar,2760385
|
202 |
+
Romania,19265250
|
203 |
+
Russia,144073139
|
204 |
+
Rwanda,13146362
|
205 |
+
South Asia,1882531620
|
206 |
+
Saudi Arabia,35997107
|
207 |
+
Sudan,44440486
|
208 |
+
Senegal,16436120
|
209 |
+
Singapore,5685807
|
210 |
+
Solomon Islands,691191
|
211 |
+
Sierra Leone,8233970
|
212 |
+
El Salvador,6292731
|
213 |
+
San Marino,34007
|
214 |
+
Somalia,16537016
|
215 |
+
Serbia,6899126
|
216 |
+
Sub-Saharan Africa (excluding high income),1151203345
|
217 |
+
South Sudan,10606227
|
218 |
+
Sub-Saharan Africa,1151301807
|
219 |
+
Small states,42392777
|
220 |
+
Sao Tome and Principe,218641
|
221 |
+
Suriname,607065
|
222 |
+
Slovakia,5458827
|
223 |
+
Slovenia,2102419
|
224 |
+
Sweden,10353442
|
225 |
+
Eswatini,1180655
|
226 |
+
Sint Maarten (Dutch part),42310
|
227 |
+
Seychelles,98462
|
228 |
+
Syria,20772595
|
229 |
+
Turks and Caicos Islands,44276
|
230 |
+
Chad,16644701
|
231 |
+
East Asia & Pacific (IDA & IBRD countries),2090523535
|
232 |
+
Europe & Central Asia (IDA & IBRD countries),462023771
|
233 |
+
Togo,8442580
|
234 |
+
Thailand,71475664
|
235 |
+
Tajikistan,9543207
|
236 |
+
Turkmenistan,6250438
|
237 |
+
Latin America & the Caribbean (IDA & IBRD countries),634680385
|
238 |
+
Timor-Leste,1299995
|
239 |
+
Middle East & North Africa (IDA & IBRD countries),407006855
|
240 |
+
Tonga,105254
|
241 |
+
South Asia (IDA & IBRD),1882531620
|
242 |
+
Sub-Saharan Africa (IDA & IBRD countries),1151301807
|
243 |
+
Trinidad and Tobago,1518147
|
244 |
+
Tunisia,12161723
|
245 |
+
Turkey,84135428
|
246 |
+
Tuvalu,11069
|
247 |
+
Tanzania,61704518
|
248 |
+
Uganda,44404611
|
249 |
+
Ukraine,44132049
|
250 |
+
Upper middle income,-2147483648
|
251 |
+
Uruguay,3429086
|
252 |
+
US,331501080
|
253 |
+
Uzbekistan,34232050
|
254 |
+
Saint Vincent and the Grenadines,104632
|
255 |
+
Venezuela,28490453
|
256 |
+
British Virgin Islands,30910
|
257 |
+
Virgin Islands (U.S.),106290
|
258 |
+
Vietnam,96648685
|
259 |
+
Vanuatu,311685
|
260 |
+
World,-2147483648
|
261 |
+
Samoa,214929
|
262 |
+
Kosovo,1790133
|
263 |
+
Yemen,32284046
|
264 |
+
South Africa,58801927
|
265 |
+
Zambia,18927715
|
266 |
+
Zimbabwe,15669666
|
data/recovered.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/data_parser.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
|
6 |
+
class Parser:
    """Load the confirmed / deaths / recovered time series and populations.

    Attributes set by ``__init__``:
        confirmed, deaths, recovered: DataFrames with dates as row index and
            country names as columns (see ``read_csv``).
        population: DataFrame indexed by country with a "Population" column.
        countries: list of the names present both in the columns of
            ``confirmed`` and in the index of ``population``.
    """

    def __init__(self, filename_confirmed,
                 filename_deaths,
                 filename_recovered,
                 filename_population):
        self.confirmed = self.read_csv(filename_confirmed)
        self.deaths = self.read_csv(filename_deaths)
        self.recovered = self.read_csv(filename_recovered)
        self.population = self.read_population(filename_population)
        self.countries = list(np.intersect1d(self.confirmed.columns.values,
                                             self.population.index.values))

    def read_csv(self, filename):
        """Read a time-series .csv into a (dates x countries) DataFrame.

        The file must have a "Country/Region" column; every other column
        label must be parseable as a date by ``pd.to_datetime``.
        """
        # Create pandas dataframe from .csv
        data = pd.read_csv(filename)

        # Manipulate the dataframe to have dates as row indices and country
        # names as column names
        data = data.set_index("Country/Region")
        data = data.T
        data.index = pd.to_datetime(data.index)

        return data

    def parse_data(self, start_date, end_date, country):
        """Return the infected and removed series of a country.

        Args:
            start_date: first date to return, formatted as YYYY-MM-DD.
            end_date: last date to return, formatted as YYYY-MM-DD.
            country: a name contained in ``self.countries``.

        Returns:
            ``(i, r)``: two Series on the same date index, restricted to
            ``start_date``..``end_date``; ``r`` is deaths + recovered and
            ``i`` is confirmed - r.

        Raises:
            ValueError: on a malformed date or an unknown country.
        """
        self.validate_date(start_date)
        self.validate_date(end_date)
        self.validate_country(country)

        confirmed = self.confirmed.loc[:end_date, country]
        r = (self.deaths.loc[:end_date, country]
             + self.recovered.loc[:end_date, country])
        # Infected = confirmed - removed, computed directly. Rebuilding it
        # from the cumulative sum of the daily differences would lose the
        # level at the first date and would make `i` one element shorter
        # than `r` whenever start_date is the first available date.
        i = confirmed - r
        return i.loc[start_date:], r.loc[start_date:]

    def read_population(self, filename):
        """Read the population .csv and index it by its "Country" column."""
        # Create pandas dataframe from .csv
        data = pd.read_csv(filename)
        data = data.set_index("Country")

        return data

    def parse_population(self, country):
        """Return the "Population" entry of ``country``."""
        population = self.population.loc[country, "Population"]

        return population

    def validate_date(self, date_text):
        """Raise ValueError unless ``date_text`` matches YYYY-MM-DD."""
        try:
            datetime.strptime(date_text, '%Y-%m-%d')
        except ValueError as err:
            # Keep the original parsing error attached as the cause.
            raise ValueError(
                "Incorrect data format, should be YYYY-MM-DD!") from err

    def validate_country(self, country):
        """Raise ValueError unless ``country`` is in ``self.countries``."""
        if country not in self.countries:
            raise ValueError("Country not in list!")
|
src/sampler.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utility_functions import log_pi_lambda, log_pi_t
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def update_lambda(lam, t, s, i, P, sigma, alpha, beta, phi):
    """Run one component-wise Metropolis-Hastings sweep over lambda.

    Args:
        lam: array of the current values of lambda.
        t: array of the breakpoints.
        s: array of susceptible individuals during time.
        i: array of infected individuals during time.
        P: total number of individuals.
        sigma: scale of the Gaussian random-walk proposal for lambda.
        alpha, beta: hyperparameters of the prior of lambda.
        phi: parameter phi of the model.

    Returns:
        ``(lam, accepted)``: the updated array reshaped to a column, and the
        number of accepted candidates in this sweep.
    """
    current = np.copy(lam)  # Current state of the chain.
    n_components = current.shape[0]
    if n_components == 0:
        # Nothing to update (and nothing to evaluate).
        return current.reshape(-1, 1), 0

    # The log-density of the current state only changes when a candidate is
    # accepted, so it is evaluated once and carried along instead of being
    # recomputed for every component.
    current_log_pi = log_pi_lambda(current, t, s, i, P, alpha, beta, phi)

    accepted = 0  # Count of accepted candidates.
    for j in range(n_components):
        # Tweak the j-th component only.
        candidate = np.copy(current)
        candidate[j] = candidate[j] + sigma * np.random.normal()
        candidate_log_pi = log_pi_lambda(candidate, t, s, i, P,
                                         alpha, beta, phi)

        # Log acceptance ratio (the random-walk proposal is symmetric).
        log_alpha = candidate_log_pi - current_log_pi

        # On acceptance the chain moves to the candidate; otherwise it
        # stays where it is.
        if log_alpha > np.log(np.random.uniform()):
            current = candidate
            current_log_pi = candidate_log_pi
            accepted = accepted + 1

    return current.reshape(-1, 1), accepted
|
46 |
+
|
47 |
+
|
48 |
+
def update_t(lam, t, s, i, P, M, phi):
    """Run one component-wise Metropolis-Hastings sweep over the breakpoints.

    Args:
        lam: array of the values of lambda.
        t: array of the current breakpoints.
        s: array of susceptible individuals during time.
        i: array of infected individuals during time.
        P: total number of individuals.
        M: maximum absolute step of the proposal; each candidate shifts one
            breakpoint by an integer drawn uniformly from -M..M.
        phi: parameter phi of the model.

    Returns:
        ``(t, accepted)``: the updated array reshaped to a column, and the
        number of accepted candidates in this sweep.
    """
    current = np.copy(t)  # Current state of the chain.
    n_components = current.shape[0]
    if n_components == 0:
        # No breakpoints: nothing to update (and nothing to evaluate).
        return current.reshape(-1, 1), 0

    steps = np.arange(-M, M + 1)  # Admissible shifts, hoisted out of the loop.
    # The log-density of the current state only changes when a candidate is
    # accepted, so it is evaluated once and carried along.
    current_log_pi = log_pi_t(lam, current, s, i, P, phi)

    accepted = 0  # Count of accepted candidates.
    for j in range(n_components):
        # Shift the j-th breakpoint only.
        # NOTE(review): the step set contains 0; such a proposal leaves the
        # state unchanged but is still counted as accepted.
        candidate = np.copy(current)
        candidate[j] = candidate[j] + np.random.choice(steps)
        candidate_log_pi = log_pi_t(lam, candidate, s, i, P, phi)

        # Log acceptance ratio (the proposal is symmetric).
        log_alpha = candidate_log_pi - current_log_pi

        # On acceptance the chain moves to the candidate; otherwise it
        # stays where it is.
        if log_alpha > np.log(np.random.uniform()):
            current = candidate
            current_log_pi = candidate_log_pi
            accepted = accepted + 1

    return current.reshape(-1, 1), accepted
|
88 |
+
|
89 |
+
|
90 |
+
def mcmc_sampler(s, i, d, P, n_iterations, burnin, M, sigma,
                 alpha, beta, a, b, phi):
    """Run the hybrid MCMC sampler for p, lambda and the breakpoints.

    p is drawn directly from a Beta distribution at every iteration, while
    lambda and t are updated through ``update_lambda`` and ``update_t``.

    Args:
        s: array of susceptible individuals during time.
        i: array of infected individuals during time.
        d: number of lambda components used to normalise the acceptance
            rate; ``d - 1`` breakpoints are sampled. ``alpha`` and ``beta``
            are presumably arrays of length ``d`` (that is how the app
            calls this function).
        P: total number of individuals.
        n_iterations: number of iterations of the algorithm.
        burnin: number of leading chain states to discard.
        M: algorithm parameter for the proposal of a new candidate t.
        sigma: algorithm parameter for the proposal of a new candidate lambda.
        alpha, beta: hyperparameters of the prior of lambda.
        a, b: hyperparameters of the prior of p.
        phi: parameter phi of the model.

    Returns:
        ``(p, lam, t, lam_ar, t_ar)``: the three simulated chains (one
        column per kept state, initial state included before burn-in) and
        the acceptance rates of lambda and t. A rate is NaN when no
        proposal was made for it (``n_iterations == 0``, or ``d == 1`` for
        t) instead of raising ZeroDivisionError.
    """
    T = s.shape[0] - 1  # Index of the final time instant.
    n_states = n_iterations + 1  # Initial state plus one state per iteration.

    # Initialize the parameters (the order of the random draws is kept).

    # The initial value of p is drawn from the prior distribution.
    p = np.empty((1, n_states))
    p[:, :1] = np.random.beta(a, b, size=(1, 1))
    # Each of the d-1 breakpoints is drawn randomly (without replacement)
    # between 1 and T-1. The obtained vector is then sorted to make sure
    # that t_1 < t_2 < ... < t_(d-1).
    t_init = np.sort(np.random.choice(np.arange(1, T), size=d-1,
                                      replace=False))
    t = np.empty((t_init.shape[0], n_states), dtype=t_init.dtype)
    t[:, 0] = t_init
    # Each of the lambda_i's is drawn independently.
    # NOTE(review): np.random.gamma takes (shape, scale), while the prior
    # term in log_pi_lambda uses beta as a rate; if the intent is to draw
    # from the prior this should probably be scale = 1 / beta - confirm.
    lam_init = np.reshape(np.random.gamma(alpha, beta), -1)
    lam = np.empty((lam_init.shape[0], n_states))
    lam[:, 0] = lam_init

    # Compute the hyperparameters of the posterior of p.
    a_new = a + i[0] - i[-1] + s[0] - s[-1]
    b_new = b + np.sum(i[1:]) + s[-1] - s[0]

    # Initialize the count of accepted candidates for lambda and t.
    a_lam = 0
    a_t = 0

    # Run the chain. The chains are preallocated, so each iteration writes
    # one column instead of copying the whole history with np.hstack.
    for k in range(n_iterations):
        # Update p by sampling from its posterior.
        p[:, k + 1:k + 2] = np.random.beta(a_new, b_new, size=(1, 1))
        # Update lam via Metropolis-Hastings step.
        new_lam, accepted_lam = update_lambda(lam[:, k], t[:, k], s, i, P,
                                              sigma, alpha, beta, phi)
        # Update t via Metropolis-Hastings step (uses the previous lambda).
        new_t, accepted_t = update_t(lam[:, k], t[:, k], s, i, P, M, phi)
        lam[:, k + 1] = new_lam[:, 0]
        t[:, k + 1] = new_t[:, 0]

        # Update the counts of accepted candidates for lambda and t.
        a_lam = a_lam + accepted_lam
        a_t = a_t + accepted_t

    # Compute the acceptance rates for lambda and t, guarding the cases in
    # which no proposal was made at all.
    if n_iterations > 0 and d > 0:
        lam_ar = a_lam / n_iterations / d
    else:
        lam_ar = np.nan
    if n_iterations > 0 and d > 1:
        t_ar = a_t / n_iterations / (d-1)
    else:
        t_ar = np.nan

    # Discard burn-in iterations.
    p = p[:, burnin:]
    lam = lam[:, burnin:]
    t = t[:, burnin:]

    return p, lam, t, lam_ar, t_ar
|
src/utility_functions.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.special import gammaln
|
3 |
+
|
4 |
+
|
5 |
+
def lambda_time(lam, t, time):
    """Evaluate the piecewise-constant function lambda(t).

    Args:
        lam, t: array-likes defining the function: ``lam[k]`` is returned
            for every instant that has exactly ``k`` breakpoints of ``t``
            less than or equal to it (``t`` must be sorted).
        time: scalar or array-like of time instants at which lambda is
            evaluated.

    Returns:
        The value(s) of lambda at ``time``; the function is vectorized in
        ``time``.
    """
    # np.asarray lets plain lists be used for lam as well as arrays.
    segment = np.searchsorted(t, time, side="right")
    values = np.asarray(lam)[segment]
    return values
|
17 |
+
|
18 |
+
|
19 |
+
def compute_kappa_time(s_t, i_t, lam, t, time, phi, P):
    """Compute the value of kappa at the given time instants.

    Args:
        s_t: number of susceptible individuals.
        i_t: number of infected individuals.
        lam, t: arrays defining the (piecewise constant) function lambda(t).
        time: time instants at which kappa is evaluated.
        phi: parameter phi of the model.
        P: total number of individuals.

    Returns:
        ``(1/phi - 1) * s_t * (1 - exp(-lambda(time) * i_t / P))``.
        The function is vectorized in ``time``, ``s_t`` and ``i_t``, which
        are required to have the same dimension.
    """
    exposure = np.multiply(lambda_time(lam, t, time), i_t) / P
    # 1 - exp(-x) written as -expm1(-x): when the infected are a tiny
    # fraction of P, x is close to 0 and the direct form loses precision.
    p_si_t = -np.expm1(-exposure)
    kappa_time = (1/phi - 1) * np.multiply(s_t, p_si_t)
    return kappa_time
|
38 |
+
|
39 |
+
|
40 |
+
def log_pi_lambda(lam, t, s, i, P, alpha, beta, phi):
    """Compute the (log-) full-conditional of the parameter vector lambda.

    Args:
        lam: array of the values of lambda.
        t: array of the breakpoints.
        s: array of susceptible individuals during time.
        i: array of infected individuals during time.
        P: total number of individuals.
        alpha, beta: hyperparameters of the prior of lambda.
        phi: parameter phi of the model.

    Returns:
        The (log-) full-conditional of lambda evaluated at ``lam``, or
        ``-inf`` when some component of ``lam`` is not positive (i.e. the
        vector is not admissible).
    """
    # np.NINF was removed in NumPy 2.0; -np.inf works on every version.
    if not np.all(lam > 0):
        return -np.inf

    T = s.shape[0] - 1  # Index of the final time instant.
    time = np.arange(T + 1)  # Array of all time instants.

    kappa_vec = compute_kappa_time(s[:-1], i[:-1], lam,
                                   t, time[:-1], phi, P)
    # Likelihood part: -np.diff(s) is the number of new infections per step.
    log_likelihood = np.sum(gammaln(-np.diff(s) + kappa_vec)
                            + kappa_vec * np.log(1 - phi)
                            - gammaln(kappa_vec))
    # Prior part: log(lam ** (alpha - 1)) written as (alpha - 1) * log(lam)
    # so that the power cannot overflow or underflow before the log.
    log_prior = np.sum((alpha - 1) * np.log(lam)
                       - np.multiply(beta, lam))
    return log_likelihood + log_prior
|
70 |
+
|
71 |
+
|
72 |
+
def log_pi_t(lam, t, s, i, P, phi):
    """Compute the (log-) full-conditional of the parameter vector t.

    Args:
        lam: array of the values of lambda.
        t: array of the breakpoints.
        s: array of susceptible individuals during time.
        i: array of infected individuals during time.
        P: total number of individuals.
        phi: parameter phi of the model.

    Returns:
        The (log-) full-conditional of t evaluated at ``t``, or ``-inf``
        when ``t`` is not admissible, i.e. unless
        0 < t_1 < t_2 < ... < t_(d-1) < T. An empty ``t`` (no breakpoints)
        is treated as admissible.
    """
    T = s.shape[0] - 1  # Index of the final time instant.
    time = np.arange(T + 1)  # Array of all time instants.

    # np.NINF was removed in NumPy 2.0; -np.inf works on every version.
    # The emptiness check comes first so that t[0] is never read on an
    # empty vector.
    if len(t) > 0 and not (np.all(np.diff(t) > 0)
                           and t[0] > 0 and t[-1] < T):
        return -np.inf

    kappa_vec = compute_kappa_time(s[:-1], i[:-1], lam,
                                   t, time[:-1], phi, P)
    # -np.diff(s) is the number of new infections per step.
    result = np.sum(gammaln(-np.diff(s) + kappa_vec)
                    + kappa_vec * np.log(1 - phi)
                    - gammaln(kappa_vec))
    return result
|
100 |
+
|
101 |
+
|
102 |
+
def simulate_data(T, lam, t, s_0, i_0, p_r, phi):
    """Simulate the susceptible / infected series according to the model.

    At every step the removals are drawn from a binomial on the current
    infected with probability ``p_r`` and the new infections from a
    negative binomial with parameters ``(kappa, 1 - phi)``.

    Args:
        T: index of the final time instant.
        lam, t: (true) arrays defining the function lambda(t).
        s_0: initial number of susceptible individuals.
        i_0: initial number of infected individuals.
        p_r: (true) probability of removal from infected population.
        phi: parameter phi of the model.

    Returns:
        ``(s, i)``: arrays of susceptible and infected individuals during
        time, each starting with the initial value.
        NOTE(review): the loop runs over the T + 1 instants 0..T, so the
        arrays have T + 2 entries - confirm this is intended.
    """
    # Compute the total number of individuals.
    P = s_0 + i_0

    # Initialize the paths of s and i. They are collected in lists and
    # converted once at the end, instead of reallocating with np.append.
    s_path = [s_0]
    i_path = [i_0]

    time = np.arange(T + 1)
    for t_i in time:
        s_now, i_now = s_path[-1], i_path[-1]
        # Draw a realization of delta_r.
        delta_r = np.random.binomial(i_now, p_r)
        # Compute the kappa parameter at time t_i.
        kappa = compute_kappa_time(s_now, i_now, lam, t, t_i, phi, P)
        # Draw a realization of delta_i. kappa is 0 when nobody is infected
        # or susceptible; the sampler rejects a non-positive first argument,
        # and the only possible outcome in that case is no new infection.
        if kappa == 0:
            delta_i = 0
        else:
            delta_i = np.random.negative_binomial(kappa, 1 - phi)

        # Update s and i according to the model.
        s_path.append(s_now - delta_i)
        i_path.append(i_now + delta_i - delta_r)
    return np.array(s_path), np.array(i_path)
|