Andrew Green committed on
Commit
fb58829
·
1 Parent(s): 57b0310

Somewhat working prototype

Browse files
Files changed (1) hide show
  1. app.py +291 -0
app.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import polars as pl
5
+ from datetime import datetime
6
+ from functools import lru_cache
7
+ from transformers import pipeline
8
+ from typing import Dict
9
+
10
# Map the raw pipeline output labels to human-readable curation classes.
label_lookup = dict(
    LABEL_0="NOT_CURATEABLE",
    LABEL_1="CURATEABLE",
)
14
+
15
+
16
@spaces.GPU
@lru_cache
def get_pipeline():
    """
    Build the HF text-classification pipeline for the pombe curation model.

    Wrapped in lru_cache so the model is only downloaded and loaded once
    per process; subsequent calls return the same pipeline object.
    """
    print("fetching model and building pipeline")
    return pipeline(model="afg1/pombe_curation_fold_0")
25
+
26
+
27
+
28
+
29
+
30
@spaces.GPU
def classify_abstracts(abstracts: Dict[str, str]) -> list:
    """
    Classify abstracts as CURATEABLE / NOT_CURATEABLE.

    Args:
        abstracts: mapping of PMID -> abstract text.

    Returns:
        The pipeline's list of prediction dicts, each augmented in place
        with a readable 'label' (via label_lookup) and its 'pmid'.

    Fixes: the return annotation previously said ``-> None`` although the
    function returns the classification list, and the loop variable
    shadowed the builtin ``abs``.
    """
    pipe = get_pipeline()
    pmids = list(abstracts.keys())
    classification = pipe(list(abstracts.values()))

    # dict order is stable, so predictions line up with the pmid list
    for pmid, prediction in zip(pmids, classification):
        prediction['label'] = label_lookup[prediction['label']]
        prediction['pmid'] = pmid

    return classification
41
+
42
+ import gradio as gr
43
+ import requests
44
+ import xml.etree.ElementTree as ET
45
+ import time
46
+ from typing import List, Tuple, Dict
47
+
48
@lru_cache
def fetch_latest_canto_dump() -> pl.DataFrame:
    """
    Read the latest PomBase canto publication dump directly from the URL.

    Cached with lru_cache so repeated searches in one session reuse the
    first download.
    """
    return pl.read_csv(
        "https://curation.pombase.org/kmr44/canto_pombe_pubs.tsv",
        separator='\t',
    )
55
+
56
+
57
def filter_new_hits(canto_pmcids: pl.DataFrame, new_pmcids: List[str]) -> List[str]:
    """
    Return only the PMIDs from the search that canto does not already have.

    The candidate PMIDs are wrapped in a dataframe and anti-joined against
    the canto dump: rows with no match in canto survive.
    """
    candidates = pl.DataFrame({"pmid": new_pmcids})
    not_yet_curated = candidates.join(canto_pmcids, on="pmid", how="anti")
    return not_yet_curated.get_column("pmid").to_list()
68
+
69
+
70
def _extract_abstract(article: ET.Element) -> str:
    """Pull the abstract text out of one <PubmedArticle> element."""
    abstract_element = article.find(".//Abstract/AbstractText")
    if abstract_element is None:
        return "No abstract available"

    # Structured abstracts carry a Label attribute on each section;
    # join the sections as "Label: text" lines.
    if 'Label' in abstract_element.attrib:
        sections = article.findall(".//Abstract/AbstractText")
        return "\n".join(
            f"{section.attrib.get('Label', 'Abstract')}: {section.text}"
            for section in sections
            if section.text is not None
        )

    # Simple abstract
    return abstract_element.text


def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200) -> Dict[str, str]:
    """
    Fetch abstracts for a list of PMIDs in batches via NCBI efetch.

    Args:
        pmids (List[str]): List of PMIDs to fetch abstracts for
        batch_size (int): Number of PMIDs to process per batch

    Returns:
        Dict[str, str]: Dictionary mapping PMIDs to their abstracts
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    all_abstracts = {}

    # Process PMIDs in batches
    for i in range(0, len(pmids), batch_size):
        batch_pmids = pmids[i:i + batch_size]
        pmids_string = ",".join(batch_pmids)

        print(f"Processing batch {i//batch_size + 1} of {(len(pmids) + batch_size - 1)//batch_size}")

        params = {
            "db": "pubmed",
            "id": pmids_string,
            "retmode": "xml",
            "rettype": "abstract"
        }

        try:
            response = requests.get(base_url, params=params)
            response.raise_for_status()

            # Parse XML response
            root = ET.fromstring(response.content)

            for article in root.findall(".//PubmedArticle"):
                pmid_element = article.find(".//PMID")
                # Fix: a missing PMID element previously raised
                # AttributeError, which the broad except below turned into
                # silently dropping the ENTIRE batch; now we skip just
                # the malformed article.
                if pmid_element is None or pmid_element.text is None:
                    continue
                all_abstracts[pmid_element.text] = _extract_abstract(article)

            # Respect NCBI's rate limits (~3 requests/second)
            time.sleep(0.34)

        except requests.exceptions.RequestException as e:
            print(f"Error accessing PubMed API for batch {i//batch_size + 1}: {str(e)}")
            continue
        except ET.ParseError as e:
            print(f"Error parsing PubMed response for batch {i//batch_size + 1}: {str(e)}")
            continue
        except Exception as e:
            print(f"Unexpected error in batch {i//batch_size + 1}: {str(e)}")
            continue
    print("All abstracts retrieved")
    return all_abstracts
144
+
145
def chunk_search(query: str, year_start: int, year_end: int) -> List[str]:
    """
    Run one PubMed esearch restricted to a publication-date year range.

    Returns the matching PMIDs, capped at esearch's per-query maximum
    of 9999 results.
    """
    params = {
        "db": "pubmed",
        "term": f"{query} AND {year_start}:{year_end}[dp]",
        "retmax": 9999,  # Maximum allowed per query
        "retmode": "xml",
    }

    response = requests.get(
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
        params=params,
    )
    response.raise_for_status()

    tree = ET.fromstring(response.content)
    return [id_elem.text for id_elem in tree.findall(".//Id")]
168
+
169
def search_pubmed(query: str, start_year: int, end_year: int):
    """
    Search PubMed for ALL matching PMIDs by breaking the search into
    year chunks, filter out already-curated papers, classify the
    remaining abstracts, and write the results to a CSV.

    This is a generator: it yields (status message, DownloadButton)
    pairs so Gradio can stream progress into the UI.

    Fixes over the prototype:
      * ``return <value>`` inside a generator discards the value (it only
        becomes the StopIteration payload), so the "no results" and error
        messages never reached the UI - they are now yielded.
      * the chunk loop clobbered the start_year/end_year parameters.
      * the interactive download button was yielded BEFORE the CSV file
        was written; the file is now written first.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    all_pmids = []

    yield "Loading current canto dump...", gr.DownloadButton(visible=True, interactive=False)
    # canto stores pmids as "PMID:12345" - keep just the numeric part
    canto_pmids = fetch_latest_canto_dump().select("pmid").with_columns(pl.col("pmid").str.split(":").list.last())

    try:
        # First, get the total count of hits for the whole query
        params = {
            "db": "pubmed",
            "term": query,
            "retmax": 0,
            "retmode": "xml"
        }

        response = requests.get(base_url, params=params)
        response.raise_for_status()

        root = ET.fromstring(response.content)
        total_count = int(root.find(".//Count").text)
        if total_count == 0:
            yield "No results found.", gr.DownloadButton(visible=True, interactive=False)
            return
        print(total_count)

        # Break the search into year chunks to stay under esearch's
        # 9999-results-per-query cap
        year_chunks = []
        chunk_size = 5  # Number of years per chunk

        for year in range(start_year, end_year + 1, chunk_size):
            year_chunks.append((year, min(year + chunk_size - 1, end_year)))

        # Search each year chunk (fresh names: do not shadow the params)
        for chunk_start, chunk_end in year_chunks:
            yield f"Searching years {chunk_start}-{chunk_end}...", gr.DownloadButton(visible=True, interactive=False)

            try:
                all_pmids.extend(chunk_search(query, chunk_start, chunk_end))

                # Status update
                yield f"Retrieved {len(all_pmids)} total results so far...", gr.DownloadButton(visible=True, interactive=False)

                # Respect NCBI's rate limits
                time.sleep(0.34)

            except Exception as e:
                print(f"Error processing years {chunk_start}-{chunk_end}: {str(e)}")
                continue

        uncurated_pmid = filter_new_hits(canto_pmids, all_pmids)
        final_message = f"Retrieved {len(uncurated_pmid)} uncurated pmids!"
        yield final_message, gr.DownloadButton(visible=True, interactive=False)

        abstracts = fetch_abstracts_batch(uncurated_pmid)
        yield f"Fetched {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        classifications = pl.DataFrame(classify_abstracts(abstracts))
        print(classifications)
        yield f"Classified {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        classification_date = datetime.today().strftime('%Y%m%d')
        csv_filename = f"classified_pmids_{classification_date}.csv"
        yield "Write csv file...", gr.DownloadButton(visible=True, interactive=False)
        # Write the file before handing out an interactive download button
        classifications.write_csv(csv_filename)

        yield final_message, gr.DownloadButton(visible=True, value=csv_filename, interactive=True)

    except requests.exceptions.RequestException as e:
        yield f"Error accessing PubMed API: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except ET.ParseError as e:
        yield f"Error parsing PubMed response: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except Exception as e:
        yield f"Unexpected error: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
247
+
248
def download_file():
    """Return an enabled download button once results are available."""
    enabled_button = gr.DownloadButton("Download results", visible=True, interactive=True)
    return enabled_button
250
+
251
+
252
+ # Create Gradio interface
253
def create_interface():
    """Assemble the Gradio Blocks UI wiring search, status and download."""
    with gr.Blocks() as demo:
        gr.Markdown("## PomBase PubMed PMID Search")
        gr.Markdown("Enter a search term to find ALL relevant PubMed articles. Large searches may take several minutes.")
        gr.Markdown("We then filter for new pmids, then classify them with a transformer model.")

        with gr.Row():
            search_input = gr.Textbox(label="Search Term", placeholder="Enter search terms...", lines=1)
            search_button = gr.Button("Search")

        with gr.Row():
            # Allow the sliders to reach into next year for ahead-of-print dates
            year_cap = datetime.now().year + 1
            start_year = gr.Slider(label="Start year", minimum=1900, maximum=year_cap, value=1900)
            end_year = gr.Slider(label="End year", minimum=1900, maximum=year_cap, value=year_cap)

        with gr.Row():
            status_output = gr.Textbox(label="Status", value="Ready to search...")

        with gr.Row():
            download_btn = gr.DownloadButton("Download results", visible=True, interactive=False)

        download_btn.click(download_file, None, download_btn)

        search_button.click(
            fn=search_pubmed,
            inputs=[search_input, start_year, end_year],
            outputs=[status_output, download_btn],
        )

    return demo
288
+
289
# Guard the launch so importing this module (e.g. in tests) does not
# start the server; HF Spaces runs app.py as __main__, so behavior there
# is unchanged.
if __name__ == "__main__":
    app = create_interface()
    app.launch()