SnoopKilla committed on
Commit
406ac25
·
1 Parent(s): f55aa25

Gradio APP

Browse files
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ from src.sampler import mcmc_sampler
7
+ from src.data_parser import Parser
8
+ matplotlib.use('Agg')
9
+ font = {'size': 30}
10
+ matplotlib.rc('font', **font)
11
+
12
+
13
def sample(country, d, n_iterations, burnin):
    """Run the MCMC sampler for one country and summarise the results.

    The observation window is fixed to 2020-03-01 .. 2020-06-15. The data
    are read through the module-level ``parser`` object.

    Parameters
    ----------
    country : str
        Country whose series and population are looked up via ``parser``.
    d : int
        Number of breakpoints requested by the user. The sampler is called
        with ``d + 1`` lambda components, because it places one breakpoint
        fewer than the number of components it is given.
    n_iterations : int
        Number of MCMC iterations.
    burnin : int
        Number of initial iterations to discard.

    Returns
    -------
    tuple
        ``(fig, lam_string, t_string, p_string)``: the matplotlib figure and
        the textual summaries of the estimated lambda components, the
        estimated breakpoints and the estimated recovery probability.

    Raises
    ------
    ValueError
        If ``burnin`` is negative or larger than ``n_iterations`` (the
        chain would be empty after discarding the burn-in).
    """
    # Numeric widget values are coerced defensively; the sampler uses them
    # as array sizes and loop bounds.
    d, n_iterations, burnin = int(d), int(n_iterations), int(burnin)
    if not 0 <= burnin <= n_iterations:
        raise ValueError("Burn-in must be between 0 and the number of "
                         "iterations!")

    # d breakpoints split the time axis into d + 1 segments, each with its
    # own lambda.
    n_components = d + 1

    P = parser.parse_population(country)
    start_date = "2020-03-01"
    end_date = "2020-06-15"
    i, r = parser.parse_data(start_date, end_date, country)
    i, r = i.values, r.values
    s = np.repeat(P, i.shape[0]) - i - r

    p, lam, t, _lam_ar, _t_ar = mcmc_sampler(
        s, i, n_components, P, n_iterations, burnin,
        M=3, sigma=0.01,
        alpha=np.repeat(2, n_components),
        beta=np.repeat(0.1, n_components),
        a=1, b=1, phi=0.995)

    # Point estimates: average of each chain after burn-in.
    lam_estimated = np.average(lam, axis=1)
    t_estimated = np.average(t, axis=1)
    p_estimated = np.average(p)

    # Plot the series.
    fig, axs = plt.subplots(nrows=2)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    ax1_left = axs[0]
    ax2_left = axs[1]
    ax1_right = ax1_left.twinx()
    ax2_right = ax2_left.twinx()
    ax1_left.plot(s, color='red', label="Susceptible")
    ax1_right.plot(i, color='blue', label="Infected")
    ax1_left.legend(loc=2)
    ax1_right.legend(loc=1)
    delta_i = -np.diff(s)
    ax2_left.plot(delta_i, color="blue", label="Newly Infected Individuals")
    ax2_right.plot(i, color='blue', linestyle='dashed', label="Infected")
    ax2_left.legend(loc=2)
    ax2_right.legend(loc=1)
    # Display obtained breakpoints on plot.
    for t_mean in t_estimated:
        ax1_right.axvline(t_mean, color="green")
        ax2_right.axvline(t_mean, color="green")

    # Get output strings
    lam_string = "\n".join(
        f"Component {j}: {round(lam_component, 4)}"
        for j, lam_component in enumerate(lam_estimated, start=1))
    t_string = "\n".join(
        f"Breakpoint {j}: {int(round(t_component, 0))}"
        for j, t_component in enumerate(t_estimated, start=1))
    p_string = f"{round(p_estimated, 4)}"

    # pyplot keeps a reference to every figure it creates; unregister this
    # one so repeated requests do not accumulate figures in memory. The
    # Figure object itself stays usable for rendering.
    plt.close(fig)

    return fig, lam_string, t_string, p_string
65
+
66
+
67
if __name__ == "__main__":
    confirmed_path = "confirmed.csv"
    deaths_path = "deaths.csv"
    recovered_path = "recovered.csv"
    population_path = "population.csv"
    # Resolve the data directory next to this script rather than in the
    # current working directory, so the app also starts when it is launched
    # from another directory.
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             "data")
    parser = Parser(os.path.join(data_path, confirmed_path),
                    os.path.join(data_path, deaths_path),
                    os.path.join(data_path, recovered_path),
                    os.path.join(data_path, population_path))
    countries = parser.countries

    # Preselect Germany when it is available; otherwise fall back to the
    # first available country (or no preselection if the list is empty).
    if "Germany" in countries:
        default_country = "Germany"
    else:
        default_country = countries[0] if countries else None

    # Inputs
    dropdown = gr.Dropdown(choices=countries, value=default_country,
                           label="Select the Country")
    slider = gr.Slider(minimum=1, maximum=5, value=3, step=1,
                       label="Select the Number of Breakpoints")
    n_iterations = gr.Number(value=10000, precision=0,
                             label="Select the Number of iterations")
    burnin = gr.Number(value=1000, precision=0,
                       label="Select the Number of Burn-In Iterations",
                       info="Such iterations will be discarded.")

    # Outputs
    plot = gr.Plot(label="Results")
    lam = gr.Text(label="Estimated Lambda")
    t = gr.Text(label="Estimated Breakpoints")
    p = gr.Text(label="Estimated Recovery Probability")
    interface = gr.Interface(sample,
                             inputs=[dropdown, slider, n_iterations, burnin],
                             outputs=[plot, lam, t, p])
    interface.launch()
data/confirmed.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/deaths.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/population.csv ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Country,Population
2
+ Aruba,106585
3
+ Africa Eastern and Southern,685112705
4
+ Afghanistan,38972230
5
+ Africa Western and Central,466189102
6
+ Angola,33428486
7
+ Albania,2837849
8
+ Andorra,77700
9
+ Arab World,449228296
10
+ United Arab Emirates,9287289
11
+ Argentina,45376763
12
+ Armenia,2805608
13
+ American Samoa,46189
14
+ Antigua and Barbuda,92664
15
+ Australia,25655289
16
+ Austria,8916864
17
+ Azerbaijan,10093121
18
+ Burundi,12220227
19
+ Belgium,11538604
20
+ Benin,12643123
21
+ Burkina Faso,21522626
22
+ Bangladesh,167420951
23
+ Bulgaria,6934015
24
+ Bahrain,1477469
25
+ Bahamas,406471
26
+ Bosnia and Herzegovina,3318407
27
+ Belarus,9379952
28
+ Belize,394921
29
+ Bermuda,63893
30
+ Bolivia,11936162
31
+ Brazil,213196304
32
+ Barbados,280693
33
+ Brunei,441725
34
+ Bhutan,772506
35
+ Botswana,2546402
36
+ Central African Republic,5343020
37
+ Canada,38037204
38
+ Central Europe and the Baltics,102180124
39
+ Switzerland,8638167
40
+ Channel Islands,171113
41
+ Chile,19300315
42
+ China,1411100000
43
+ Cote d'Ivoire,26811790
44
+ Cameroon,26491087
45
+ "Congo, Dem. Rep.",92853164
46
+ "Congo, Rep.",5702174
47
+ Colombia,50930662
48
+ Comoros,806166
49
+ Cabo Verde,582640
50
+ Costa Rica,5123105
51
+ Caribbean small states,7444768
52
+ Cuba,11300698
53
+ Curacao,154947
54
+ Cayman Islands,67311
55
+ Cyprus,1237537
56
+ Czechia,10697858
57
+ Germany,83160871
58
+ Djibouti,1090156
59
+ Dominica,71995
60
+ Denmark,5831404
61
+ Dominican Republic,10999664
62
+ Algeria,43451666
63
+ East Asia & Pacific (excluding high income),2116424876
64
+ Early-demographic dividend,-2147483648
65
+ East Asia & Pacific,-2147483648
66
+ Europe & Central Asia (excluding high income),400811771
67
+ Europe & Central Asia,923103879
68
+ Ecuador,17588595
69
+ Egypt,107465134
70
+ Euro area,342913447
71
+ Eritrea,3555868
72
+ Spain,47365655
73
+ Estonia,1329522
74
+ Ethiopia,117190911
75
+ European Union,447692315
76
+ Fragile and conflict affected situations,979418527
77
+ Finland,5529543
78
+ Fiji,920422
79
+ France,67571107
80
+ Faroe Islands,52415
81
+ Micronesia,112106
82
+ Gabon,2292573
83
+ United Kingdom,67081000
84
+ Georgia,3722716
85
+ Ghana,32180401
86
+ Gibraltar,32709
87
+ Guinea,13205153
88
+ Gambia,2573995
89
+ Guinea-Bissau,2015828
90
+ Equatorial Guinea,1596049
91
+ Greece,10698599
92
+ Grenada,123663
93
+ Greenland,56367
94
+ Guatemala,16858333
95
+ Guam,169231
96
+ Guyana,797202
97
+ High income,1240900955
98
+ "Hong Kong SAR, China",7481000
99
+ Honduras,10121763
100
+ Heavily indebted poor countries (HIPC),838066650
101
+ Croatia,4047680
102
+ Haiti,11306801
103
+ Hungary,9750149
104
+ IBRD only,-2147483648
105
+ IDA & IBRD total,-2147483648
106
+ IDA total,1738306807
107
+ IDA blend,582637127
108
+ Indonesia,271857970
109
+ IDA only,1155669680
110
+ Isle of Man,84046
111
+ India,1396387127
112
+ Ireland,4985382
113
+ Iran,87290193
114
+ Iraq,42556984
115
+ Iceland,366463
116
+ Israel,9215100
117
+ Italy,59438851
118
+ Jamaica,2820436
119
+ Jordan,10928721
120
+ Japan,126261000
121
+ Kazakhstan,18755666
122
+ Kenya,51985780
123
+ Kyrgyzstan,6579900
124
+ Cambodia,16396860
125
+ Kiribati,126463
126
+ Saint Kitts and Nevis,47642
127
+ "Korea, South",51836239
128
+ Kuwait,4360444
129
+ Latin America & Caribbean (excluding high income),588808380
130
+ Laos,7319399
131
+ Lebanon,5662923
132
+ Liberia,5087584
133
+ Libya,6653942
134
+ Saint Lucia,179237
135
+ Latin America & Caribbean,650534967
136
+ Least developed countries: UN classification,1073743450
137
+ Low income,699186538
138
+ Liechtenstein,38756
139
+ Sri Lanka,21919000
140
+ Lower middle income,-2147483648
141
+ Low & middle income,-2147483648
142
+ Lesotho,2254100
143
+ Late-demographic dividend,-2147483648
144
+ Lithuania,2794885
145
+ Luxembourg,630419
146
+ Latvia,1900449
147
+ "Macao SAR, China",676283
148
+ St. Martin (French part),32553
149
+ Morocco,36688772
150
+ Monaco,36922
151
+ Moldova,2635130
152
+ Madagascar,28225177
153
+ Maldives,514438
154
+ Middle East & North Africa,479966649
155
+ Mexico,125998302
156
+ Marshall Islands,43413
157
+ Middle income,-2147483648
158
+ North Macedonia,2072531
159
+ Mali,21224040
160
+ Malta,515332
161
+ Burma,53423198
162
+ Middle East & North Africa (excluding high income),411810124
163
+ Montenegro,621306
164
+ Mongolia,3294335
165
+ Northern Mariana Islands,49587
166
+ Mozambique,31178239
167
+ Mauritania,4498604
168
+ Mauritius,1265740
169
+ Malawi,19377061
170
+ Malaysia,33199993
171
+ North America,369602177
172
+ Namibia,2489098
173
+ New Caledonia,271130
174
+ Niger,24333639
175
+ Nigeria,208327405
176
+ Nicaragua,6755895
177
+ Netherlands,17441500
178
+ Norway,5379475
179
+ Nepal,29348627
180
+ Nauru,12315
181
+ New Zealand,5090200
182
+ OECD members,1370241530
183
+ Oman,4543399
184
+ Other small states,32381190
185
+ Pakistan,227196741
186
+ Panama,4294396
187
+ Peru,33304756
188
+ Philippines,112190977
189
+ Palau,17972
190
+ Papua New Guinea,9749640
191
+ Poland,37899070
192
+ Pre-demographic dividend,984213438
193
+ Puerto Rico,3281538
194
+ "Korea, Dem. People's Rep.",25867467
195
+ Portugal,10297081
196
+ Paraguay,6618695
197
+ West Bank and Gaza,4803269
198
+ Pacific island small states,2566819
199
+ Post-demographic dividend,1117443485
200
+ French Polynesia,301920
201
+ Qatar,2760385
202
+ Romania,19265250
203
+ Russia,144073139
204
+ Rwanda,13146362
205
+ South Asia,1882531620
206
+ Saudi Arabia,35997107
207
+ Sudan,44440486
208
+ Senegal,16436120
209
+ Singapore,5685807
210
+ Solomon Islands,691191
211
+ Sierra Leone,8233970
212
+ El Salvador,6292731
213
+ San Marino,34007
214
+ Somalia,16537016
215
+ Serbia,6899126
216
+ Sub-Saharan Africa (excluding high income),1151203345
217
+ South Sudan,10606227
218
+ Sub-Saharan Africa,1151301807
219
+ Small states,42392777
220
+ Sao Tome and Principe,218641
221
+ Suriname,607065
222
+ Slovakia,5458827
223
+ Slovenia,2102419
224
+ Sweden,10353442
225
+ Eswatini,1180655
226
+ Sint Maarten (Dutch part),42310
227
+ Seychelles,98462
228
+ Syria,20772595
229
+ Turks and Caicos Islands,44276
230
+ Chad,16644701
231
+ East Asia & Pacific (IDA & IBRD countries),2090523535
232
+ Europe & Central Asia (IDA & IBRD countries),462023771
233
+ Togo,8442580
234
+ Thailand,71475664
235
+ Tajikistan,9543207
236
+ Turkmenistan,6250438
237
+ Latin America & the Caribbean (IDA & IBRD countries),634680385
238
+ Timor-Leste,1299995
239
+ Middle East & North Africa (IDA & IBRD countries),407006855
240
+ Tonga,105254
241
+ South Asia (IDA & IBRD),1882531620
242
+ Sub-Saharan Africa (IDA & IBRD countries),1151301807
243
+ Trinidad and Tobago,1518147
244
+ Tunisia,12161723
245
+ Turkey,84135428
246
+ Tuvalu,11069
247
+ Tanzania,61704518
248
+ Uganda,44404611
249
+ Ukraine,44132049
250
+ Upper middle income,-2147483648
251
+ Uruguay,3429086
252
+ US,331501080
253
+ Uzbekistan,34232050
254
+ Saint Vincent and the Grenadines,104632
255
+ Venezuela,28490453
256
+ British Virgin Islands,30910
257
+ Virgin Islands (U.S.),106290
258
+ Vietnam,96648685
259
+ Vanuatu,311685
260
+ World,-2147483648
261
+ Samoa,214929
262
+ Kosovo,1790133
263
+ Yemen,32284046
264
+ South Africa,58801927
265
+ Zambia,18927715
266
+ Zimbabwe,15669666
data/recovered.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/data_parser.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from datetime import datetime
4
+
5
+
6
class Parser:
    """Load the epidemic time series and the population table from CSV files.

    Attributes are filled in by ``__init__``:

    - ``confirmed``, ``deaths``, ``recovered``: data frames with dates as
      row index and country names as columns;
    - ``population``: data frame indexed by country with a "Population"
      column;
    - ``countries``: names present both in the confirmed-cases columns and
      in the population index.
    """

    def __init__(self, filename_confirmed,
                 filename_deaths,
                 filename_recovered,
                 filename_population):
        """Read the four CSV files and compute the list of usable countries."""
        self.confirmed = self.read_csv(filename_confirmed)
        self.deaths = self.read_csv(filename_deaths)
        self.recovered = self.read_csv(filename_recovered)
        self.population = self.read_population(filename_population)
        self.countries = list(np.intersect1d(self.confirmed.columns.values,
                                             self.population.index.values))

    def read_csv(self, filename):
        """Read a time-series CSV into a date-indexed data frame.

        The file must have a "Country/Region" column; every other column
        label is parsed as a date.
        """
        # Create pandas dataframe from .csv
        data = pd.read_csv(filename)

        # Manipulate the dataframe to have dates as row indices and country
        # names as column names
        data = data.set_index("Country/Region")
        data = data.T
        data.index = pd.to_datetime(data.index)

        return data

    def parse_data(self, start_date, end_date, country):
        """Return the infected and removed series for one country.

        Parameters
        ----------
        start_date, end_date : str
            Inclusive bounds of the window, formatted as YYYY-MM-DD.
        country : str
            Must be one of ``self.countries``.

        Returns
        -------
        tuple of pandas.Series
            ``(i, r)`` where ``r`` is deaths + recovered and ``i`` is
            confirmed - ``r``, both restricted to the requested window and
            therefore of equal length.

        Raises
        ------
        ValueError
            If a date is malformed or the country is unknown.
        """
        self.validate_date(start_date)
        self.validate_date(end_date)
        self.validate_country(country)

        confirmed = self.confirmed.loc[:end_date, country]
        r = (self.deaths.loc[:end_date, country]
             + self.recovered.loc[:end_date, country])
        # Compute the currently infected directly from the cumulative
        # counts. Rebuilding them from a cumulative sum of daily differences
        # would lose the first day's baseline and drop the first date.
        i = (confirmed - r).astype(int)
        return i[start_date:], r[start_date:]

    def read_population(self, filename):
        """Read the population CSV and index it by its "Country" column."""
        # Create pandas dataframe from .csv
        data = pd.read_csv(filename)
        data = data.set_index("Country")

        return data

    def parse_population(self, country):
        """Return the "Population" entry of ``country``."""
        population = self.population.loc[country, "Population"]

        return population

    def validate_date(self, date_text):
        """Raise ``ValueError`` unless ``date_text`` is formatted YYYY-MM-DD."""
        try:
            datetime.strptime(date_text, '%Y-%m-%d')
        except ValueError as err:
            raise ValueError(
                "Incorrect data format, should be YYYY-MM-DD!") from err

    def validate_country(self, country):
        """Raise ``ValueError`` unless ``country`` is in ``self.countries``."""
        if country not in self.countries:
            raise ValueError("Country not in list!")
src/sampler.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.utility_functions import log_pi_lambda, log_pi_t
2
+ import numpy as np
3
+
4
+
5
def update_lambda(lam, t, s, i, P, sigma, alpha, beta, phi):
    # This function updates the parameter vector lambda.
    # INPUT:
    # - lam: array of the values of lambda;
    # - t: array of the breakpoints;
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time;
    # - P: total number of individuals;
    # - sigma: algorithm parameter for the proposal of a new candidate lambda;
    # - alpha, beta: hyperparameters of the prior of lambda;
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - lam: updated array, as a column vector;
    # - accepted: number of accepted candidates.
    # NOTES: The update is done component-wise in a sequential manner.
    # The input array lam is not modified.

    current = np.copy(lam)  # Get the current state of the chain.

    # The log-density of the current state only changes when a candidate is
    # accepted, so it is cached instead of being recomputed for every
    # component. It is evaluated lazily, on the first proposal.
    current_log_pi = None

    # For every component of the parameter vector, we tweak such component
    # according to the chosen proposal and then update the chain according
    # to the computed acceptance rate.
    accepted = 0  # Initialize the count of accepted candidates.
    for j in range(current.shape[0]):
        if current_log_pi is None:
            current_log_pi = log_pi_lambda(current, t, s, i, P,
                                           alpha, beta, phi)

        # Tweak the j-th component (Gaussian random walk).
        candidate = np.copy(current)
        candidate[j] = candidate[j] + sigma * np.random.normal()
        # Compute the (log-) acceptance rate.
        candidate_log_pi = log_pi_lambda(candidate, t, s, i, P,
                                         alpha, beta, phi)
        log_alpha = candidate_log_pi - current_log_pi

        # If the candidate is accepted, we move the chain and increase the
        # count of accepted candidates. Otherwise the chain stays at the
        # current state and the candidate is discarded.
        if log_alpha > np.log(np.random.uniform()):
            current = candidate
            current_log_pi = candidate_log_pi
            accepted = accepted + 1

    return current.reshape(-1, 1), accepted
46
+
47
+
48
def update_t(lam, t, s, i, P, M, phi):
    # This function updates the vector of breakpoints t.
    # INPUT:
    # - lam: array of the values of lambda;
    # - t: array of the breakpoints;
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time;
    # - P: total number of individuals;
    # - M: algorithm parameter for the proposal of a new candidate t
    #   (each component is shifted by an integer drawn uniformly
    #   from -M, ..., M);
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - t: updated array, as a column vector;
    # - accepted: number of accepted candidates.
    # NOTES: The update is done component-wise in a sequential manner.
    # The input array t is not modified.

    current = np.copy(t)  # Get the current state of the chain.

    # The log-density of the current state only changes when a candidate is
    # accepted, so it is cached instead of being recomputed for every
    # component. It is evaluated lazily, so nothing is evaluated when
    # there are no breakpoints to update.
    current_log_pi = None

    # For every component of the parameter vector, we tweak such component
    # according to the chosen proposal and then update the chain according
    # to the computed acceptance rate.
    accepted = 0  # Initialize the count of accepted candidates.
    for j in range(current.shape[0]):
        if current_log_pi is None:
            current_log_pi = log_pi_t(lam, current, s, i, P, phi)

        # Tweak the j-th component.
        candidate = np.copy(current)
        candidate[j] = candidate[j] + np.random.choice(np.arange(-M, M + 1))
        # Compute the (log-) acceptance rate.
        candidate_log_pi = log_pi_t(lam, candidate, s, i, P, phi)
        log_alpha = candidate_log_pi - current_log_pi

        # If the candidate is accepted, we move the chain and increase the
        # count of accepted candidates. Otherwise the chain stays at the
        # current state and the candidate is discarded.
        if log_alpha > np.log(np.random.uniform()):
            current = candidate
            current_log_pi = candidate_log_pi
            accepted = accepted + 1

    return current.reshape(-1, 1), accepted
88
+
89
+
90
def mcmc_sampler(s, i, d, P, n_iterations, burnin, M, sigma,
                 alpha, beta, a, b, phi):
    # This function implements the hybrid MCMC sampler.
    # INPUT:
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time;
    # - d: number of components of lambda (the sampler places
    #   d - 1 breakpoints between them);
    # - P: total number of individuals;
    # - n_iterations: number of iterations for the algorithm;
    # - burnin: number of burnin iterations to discard;
    # - M: algorithm parameter for the proposal of a new candidate t;
    # - sigma: algorithm parameter for the proposal of a new candidate lambda;
    # - alpha, beta: hyperparameters of the prior of lambda;
    # - a, b: hyperparameters of the prior of p;
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - p: simulated chain for the probability of removal from
    #   infected population, one column per retained state;
    # - lam: simulated chain for lambda, one column per retained state;
    # - t: simulated chain for the breakpoints, one column per
    #   retained state;
    # - lam_ar: acceptance rate of the candidates for lambda;
    # - t_ar: acceptance rate of the candidates for t (NaN when there
    #   are no breakpoints or no iterations).

    T = s.shape[0] - 1  # Index of the final time instant.

    # The chains hold the initial state plus one state per iteration. They
    # are allocated once instead of being grown column by column.
    n_states = n_iterations + 1
    p = np.empty((1, n_states))
    lam = np.empty((d, n_states))
    t = np.empty((d - 1, n_states), dtype=int)

    # Initialize the parameters.

    # The initial value of p is drawn from the prior distribution.
    p[0, 0] = np.random.beta(a, b)
    # Each of the d-1 breakpoints (t_i) is drawn randomly (without
    # replacement) between 1 and T-1. The obtained vector is then sorted to
    # make sure that t_1 < t_2 < ... < t_(d-1).
    t[:, 0] = np.sort(np.random.choice(np.arange(1, T), size=d-1,
                                       replace=False))
    # Each of the lambda_i's is drawn independently from a gamma
    # distribution.
    # NOTE(review): np.random.gamma takes (shape, scale), whereas the prior
    # term in log_pi_lambda uses beta as a rate. This only affects the
    # starting point of the chain and is deliberately left unchanged.
    lam[:, 0] = np.random.gamma(alpha, beta)

    # Compute the hyperparameters of the posterior of p.
    a_new = a + i[0] - i[-1] + s[0] - s[-1]
    b_new = b + np.sum(i[1:]) + s[-1] - s[0]

    # Initialize the count of accepted candidates for lambda and t.
    a_lam = 0
    a_t = 0

    # Run the chain
    for k in range(1, n_states):
        # Update p by sampling from its posterior.
        p[0, k] = np.random.beta(a_new, b_new)
        # Update lam via Metropolis-Hastings step.
        new_lam, accepted_lam = update_lambda(lam[:, k - 1], t[:, k - 1],
                                              s, i, P,
                                              sigma, alpha, beta, phi)
        # Update t via Metropolis-Hastings step, conditioning on the lambda
        # that has just been updated so that the pair (lam, t) stored for
        # this iteration comes from one sequential sweep.
        new_t, accepted_t = update_t(new_lam[:, 0], t[:, k - 1],
                                     s, i, P, M, phi)
        lam[:, k] = new_lam[:, 0]
        t[:, k] = new_t[:, 0]

        # Update the counts of accepted candidates for lambda and t.
        a_lam = a_lam + accepted_lam
        a_t = a_t + accepted_t

    # Compute the acceptance rates for lambda and t. A rate is undefined
    # (NaN) when no candidate was ever proposed.
    undefined = float("nan")
    lam_ar = a_lam / n_iterations / d if n_iterations * d else undefined
    t_ar = (a_t / n_iterations / (d-1) if n_iterations * (d-1)
            else undefined)

    # Discard burn-in iterations.
    p = p[:, burnin:]
    lam = lam[:, burnin:]
    t = t[:, burnin:]

    return p, lam, t, lam_ar, t_ar
src/utility_functions.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.special import gammaln
3
+
4
+
5
def lambda_time(lam, t, time):
    # This function evaluates the piecewise constant function lambda(t).
    # INPUT:
    # - lam, t: arrays defining the (piecewise constant) function lambda(t);
    # - time: time instants at which we want to evaluate lambda.
    # OUTPUT:
    # - value of lambda at the requested time instants.
    # NOTES: The function is vectorized in the array time. Indeed,
    # it allows to compute lambda for all the time instants in the array time.

    # For each time instant, count the breakpoints that are <= time: this is
    # the index of the component of lam that is active at that instant.
    segment = np.searchsorted(t, time, side="right")
    return lam[segment]
17
+
18
+
19
def compute_kappa_time(s_t, i_t, lam, t, time, phi, P):
    # This function computes the value of kappa.
    # INPUT:
    # - s_t: number of susceptible individuals;
    # - i_t: number of infected individuals;
    # - lam, t: arrays defining the (piecewise constant) function lambda(t);
    # - time: time instants at which we want to evaluate kappa;
    # - phi: parameter phi of the model;
    # - P: total number of individuals.
    # OUTPUT:
    # - value of kappa at the requested time instants.
    # NOTES: The function is vectorized in the arrays time, s_t and i_t.
    # Indeed, it allows to compute kappa for all the time instants in
    # the array time. It is required that time, s_t and i_t
    # have the same dimension.

    # Value of lambda that is active at each requested time instant.
    rate = lambda_time(lam, t, time)
    # p_si_t = 1 - exp(-lambda(t) * i_t / P).
    exponent = -np.multiply(rate, i_t) / P
    p_si_t = 1 - np.exp(exponent)
    # kappa = (1/phi - 1) * s_t * p_si_t.
    return (1/phi - 1) * np.multiply(s_t, p_si_t)
38
+
39
+
40
def log_pi_lambda(lam, t, s, i, P, alpha, beta, phi):
    # This function computes the (log-) full-conditional
    # of the parameter vector lambda.
    # INPUT:
    # - lam: array of the values of lambda;
    # - t: array of the breakpoints;
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time;
    # - P: total number of individuals;
    # - alpha, beta: hyperparameters of the prior of lambda;
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - result: (log-) full-conditional of lambda evaluated at lam
    #   (-inf if lam is not admissible).

    # The vector lam is admissible only if all its components are positive.
    # (-np.inf is used because the alias np.NINF no longer exists in
    # NumPy 2.)
    if not np.all(lam > 0):
        return -np.inf

    T = s.shape[0] - 1  # Index of the final time instant.
    time = np.arange(T)  # Time instants 0, ..., T-1 (one per transition).

    kappa_vec = compute_kappa_time(s[:-1], i[:-1], lam, t, time, phi, P)
    # Likelihood term: -np.diff(s) is the number of newly infected
    # individuals at each transition.
    log_likelihood = np.sum(gammaln(-np.diff(s) + kappa_vec)
                            + kappa_vec * np.log(1 - phi)
                            - gammaln(kappa_vec))
    # Prior term. (alpha - 1) * log(lam) is used rather than
    # log(lam ** (alpha - 1)) so that the power cannot under/overflow.
    log_prior = np.sum((alpha - 1) * np.log(lam) - np.multiply(beta, lam))
    return log_likelihood + log_prior
70
+
71
+
72
def log_pi_t(lam, t, s, i, P, phi):
    # This function computes the (log-) full-conditional
    # of the parameter vector t.
    # INPUT:
    # - lam: array of the values of lambda;
    # - t: array of the breakpoints (possibly empty);
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time;
    # - P: total number of individuals;
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - result: (log-) full-conditional of t evaluated at t
    #   (-inf if t is not admissible).

    T = s.shape[0] - 1  # Index of the final time instant.

    # The vector t is admissible if 0 < t_1 < t_2 < ... < t_(d-1) < T.
    # An empty vector (no breakpoints) has nothing to violate; it is
    # handled explicitly because t[0] and t[-1] do not exist in that case.
    # (-np.inf is used because the alias np.NINF no longer exists in
    # NumPy 2.)
    if np.size(t) > 0:
        admissible = np.all(np.diff(t) > 0) and t[0] > 0 and t[-1] < T
        if not admissible:
            return -np.inf

    time = np.arange(T)  # Time instants 0, ..., T-1 (one per transition).
    kappa_vec = compute_kappa_time(s[:-1], i[:-1], lam, t, time, phi, P)
    # -np.diff(s) is the number of newly infected individuals at each
    # transition.
    result = np.sum(gammaln(-np.diff(s) + kappa_vec)
                    + kappa_vec * np.log(1 - phi)
                    - gammaln(kappa_vec))
    return result
100
+
101
+
102
def simulate_data(T, lam, t, s_0, i_0, p_r, phi):
    # This function simulates data according to the model: at each step the
    # removals are drawn from a binomial distribution and the new infections
    # from a negative binomial distribution with parameters kappa, 1 - phi.
    # INPUT:
    # - T: index of the final time instant;
    # - lam, t: (true) arrays defining the function lambda(t);
    # - s_0: initial number of susceptible individuals;
    # - i_0: initial number of infected individuals;
    # - p_r: (true) probability of removal from infected population;
    # - phi: parameter phi of the model.
    # OUTPUT:
    # - s: array of susceptible individuals during time;
    # - i: array of infected individuals during time.
    # NOTES: One step is simulated for every time instant 0, ..., T, so the
    # returned arrays hold the initial value plus T + 1 simulated values.
    # NOTE(review): with T described as the index of the final instant one
    # might expect T steps; the existing number of steps is preserved.

    # Compute the total number of individuals.
    P = s_0 + i_0

    # Initialize the series s and i. Plain lists are used while simulating
    # (appending to an array copies it at every step).
    s = [s_0]
    i = [i_0]

    for t_i in np.arange(T + 1):
        # Draw a realization of delta_r.
        delta_r = np.random.binomial(i[-1], p_r)
        # Compute the kappa parameter at time t_i.
        kappa = compute_kappa_time(s[-1], i[-1], lam, t, t_i, phi, P)
        # Draw a realization of delta_i. The negative binomial requires a
        # strictly positive kappa; when kappa is not positive (e.g. nobody
        # is infected any more) there are no new infections.
        if kappa > 0:
            delta_i = np.random.negative_binomial(kappa, 1 - phi)
        else:
            delta_i = 0

        # Update s and i according to the model.
        s.append(s[-1] - delta_i)
        i.append(i[-1] + delta_i - delta_r)
    return np.array(s), np.array(i)