import time
import xml.etree.ElementTree as ET
from datetime import datetime
from functools import lru_cache
from typing import Dict, Iterator, List, Tuple

import gradio as gr
import polars as pl
import requests
import spaces
from transformers import pipeline
# Map the model's raw output labels to human-readable classes
label_lookup = {
    "LABEL_0": "NOT_CURATEABLE",
    "LABEL_1": "CURATEABLE",
}
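
# The text-classification pipeline returns dicts of the form
# {'label': 'LABEL_0' | 'LABEL_1', 'score': float}; label_lookup maps the
# raw labels onto the curation classes used in the output CSV.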
@lru_cache(maxsize=1)
def get_pipeline():
    """Load the classification model and build the pipeline, caching the result."""
    print("fetching model and building pipeline")
    model_name = "afg1/pombe_curation_fold_0"
    pipe = pipeline(model=model_name, task="text-classification")
    return pipe
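
# Note: with lru_cache the pipeline object is reused across calls, so the
# model is only downloaded and loaded once per process.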
def classify_abstracts(abstracts: Dict[str, str], batch_size: int = 64, progress=gr.Progress()) -> List[Dict]:
    pipe = get_pipeline()
    results = []
    total = len(abstracts)

    # Convert dictionary to lists of PMIDs and abstracts, preserving order
    pmids = list(abstracts.keys())
    abstract_texts = list(abstracts.values())

    # Initialize progress bar
    progress(0, desc="Starting classification...")

    # Process in batches
    for i in range(0, total, batch_size):
        # Get current batch
        batch_abstracts = abstract_texts[i:i + batch_size]
        batch_pmids = pmids[i:i + batch_size]
        try:
            # Classify the batch
            classifications = pipe(batch_abstracts)

            # Process each result in the batch
            for pmid, classification in zip(batch_pmids, classifications):
                results.append({
                    'pmid': pmid,
                    'classification': label_lookup[classification['label']],
                    'score': classification['score']
                })

            # Update progress
            progress(min((i + batch_size) / total, 1.0),
                     desc=f"Classified {min(i + batch_size, total)}/{total} abstracts...")
        except Exception as e:
            print(f"Error classifying batch starting at index {i}: {str(e)}")
            continue

    progress(1.0, desc="Classification complete!")
    return results
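
# Illustrative usage (hypothetical pmid and abstract, not real data):
#   classify_abstracts({"12345678": "We characterised a fission yeast kinase..."})
# returns a list of dicts like
#   [{'pmid': '12345678', 'classification': 'CURATEABLE', 'score': 0.97}]
# where 'score' is the model's confidence in the predicted label.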
def fetch_latest_canto_dump() -> pl.DataFrame:
    """
    Read the latest PomBase Canto dump directly from the URL
    """
    url = "https://curation.pombase.org/kmr44/canto_pombe_pubs.tsv"
    return pl.read_csv(url, separator='\t')
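
# The dump's "pmid" column holds values of the form "PMID:12345"; callers
# comparing against bare numeric PMIDs must strip the "PMID:" prefix first
# (see search_pubmed below).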
def filter_new_hits(canto_pmids: pl.DataFrame, new_pmids: List[str]) -> List[str]:
    """
    Convert the list of PMIDs from the search to a dataframe and anti-join it
    against the Canto dump to find PMIDs not yet curated
    """
    new_pmid_df = pl.DataFrame({"pmid": new_pmids})
    uncurated = new_pmid_df.join(canto_pmids, on="pmid", how="anti")
    return uncurated.get_column("pmid").to_list()
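
# Illustrative example (hypothetical values): if the Canto dump contains
# pmid "111" and the search returned ["111", "222"], the anti-join keeps
# only "222", i.e. the hit with no curation record yet.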
def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200) -> Dict[str, str]:
    """
    Fetch abstracts for a list of PMIDs in batches

    Args:
        pmids (List[str]): List of PMIDs to fetch abstracts for
        batch_size (int): Number of PMIDs to process per batch

    Returns:
        Dict[str, str]: Dictionary mapping PMIDs to their abstracts
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    all_abstracts = {}

    # Process PMIDs in batches
    for i in range(0, len(pmids), batch_size):
        batch_pmids = pmids[i:i + batch_size]
        pmids_string = ",".join(batch_pmids)
        print(f"Processing batch {i//batch_size + 1} of {(len(pmids) + batch_size - 1)//batch_size}")
        params = {
            "db": "pubmed",
            "id": pmids_string,
            "retmode": "xml",
            "rettype": "abstract"
        }
        try:
            response = requests.get(base_url, params=params)
            response.raise_for_status()

            # Parse XML response
            root = ET.fromstring(response.content)

            # Iterate through each article in the batch
            for article in root.findall(".//PubmedArticle"):
                # Get PMID
                pmid = article.find(".//PMID").text

                # Find abstract text
                abstract_element = article.find(".//Abstract/AbstractText")
                if abstract_element is not None:
                    # Handle structured abstracts (e.g. Background/Methods/Results)
                    if 'Label' in abstract_element.attrib:
                        abstract_sections = article.findall(".//Abstract/AbstractText")
                        abstract_text = "\n".join(
                            f"{section.attrib.get('Label', 'Abstract')}: {section.text}"
                            for section in abstract_sections
                            if section.text is not None
                        )
                    else:
                        # Simple abstract; guard against empty elements
                        abstract_text = abstract_element.text or ""
                else:
                    abstract_text = ""

                if len(abstract_text) > 0:
                    all_abstracts[pmid] = abstract_text

            # Respect NCBI's rate limits
            time.sleep(0.34)
        except requests.exceptions.RequestException as e:
            print(f"Error accessing PubMed API for batch {i//batch_size + 1}: {str(e)}")
            continue
        except ET.ParseError as e:
            print(f"Error parsing PubMed response for batch {i//batch_size + 1}: {str(e)}")
            continue
        except Exception as e:
            print(f"Unexpected error in batch {i//batch_size + 1}: {str(e)}")
            continue

    print("All abstracts retrieved")
    return all_abstracts
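
# Illustrative usage (hypothetical PMIDs): fetch_abstracts_batch(["111", "222"])
# returns something like {"111": "abstract text..."}; PMIDs whose records carry
# no abstract are silently dropped from the result.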
def chunk_search(query: str, year_start: int, year_end: int) -> List[str]:
    """
    Perform a PubMed search restricted to a specific publication-date year range
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    retmax = 9999  # esearch returns at most 10,000 IDs per query
    date_query = f"{query} AND {year_start}:{year_end}[dp]"
    params = {
        "db": "pubmed",
        "term": date_query,
        "retmax": retmax,
        "retmode": "xml"
    }
    response = requests.get(base_url, params=params)
    response.raise_for_status()
    root = ET.fromstring(response.content)
    id_list = root.findall(".//Id")
    return [id_elem.text for id_elem in id_list]
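
# For example, chunk_search('pombe OR "fission yeast"', 2020, 2024) issues an
# esearch with term 'pombe OR "fission yeast" AND 2020:2024[dp]', where [dp]
# restricts matches by publication date.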
def search_pubmed(query: str, start_year: int, end_year: int) -> Iterator[Tuple[str, gr.DownloadButton]]:
    """
    Search PubMed and collect all matching PMIDs by breaking the search into
    year chunks. Implemented as a generator so Gradio can stream status
    updates to the UI.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    all_pmids = []

    yield "Loading current canto dump...", gr.DownloadButton(visible=True, interactive=False)
    # The dump stores pmids as e.g. "PMID:12345"; keep only the numeric part
    canto_pmids = fetch_latest_canto_dump().select("pmid").with_columns(pl.col("pmid").str.split(":").list.last())

    try:
        # First, get the total count
        params = {
            "db": "pubmed",
            "term": query,
            "retmax": 0,
            "retmode": "xml"
        }
        response = requests.get(base_url, params=params)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        total_count = int(root.find(".//Count").text)
        if total_count == 0:
            yield "No results found.", gr.DownloadButton(visible=True, interactive=False)
            return
        print(total_count)

        # Break the search into year chunks
        year_chunks = []
        chunk_size = 5  # Number of years per chunk
        for year in range(start_year, end_year + 1, chunk_size):
            chunk_end = min(year + chunk_size - 1, end_year)
            year_chunks.append((year, chunk_end))

        # Search each year chunk; don't shadow the start_year/end_year arguments
        for chunk_start, chunk_end in year_chunks:
            yield f"Searching years {chunk_start}-{chunk_end}...", gr.DownloadButton(visible=True, interactive=False)
            try:
                chunk_pmids = chunk_search(query, chunk_start, chunk_end)
                all_pmids.extend(chunk_pmids)

                # Status update
                yield f"Retrieved {len(all_pmids)} total results so far...", gr.DownloadButton(visible=True, interactive=False)

                # Respect NCBI's rate limits
                time.sleep(0.34)
            except Exception as e:
                print(f"Error processing years {chunk_start}-{chunk_end}: {str(e)}")
                continue

        uncurated_pmids = filter_new_hits(canto_pmids, all_pmids)
        final_message = f"Retrieved {len(uncurated_pmids)} uncurated pmids!"
        yield final_message, gr.DownloadButton(visible=True, interactive=False)

        abstracts = fetch_abstracts_batch(uncurated_pmids)
        yield f"Fetched {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        classifications = pl.DataFrame(classify_abstracts(abstracts))
        print(classifications)
        yield f"Classified {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False)

        # Write the CSV before enabling the download button, so the file exists
        classification_date = datetime.today().strftime('%Y%m%d')
        csv_filename = f"classified_pmids_{classification_date}.csv"
        yield "Writing csv file...", gr.DownloadButton(visible=True, interactive=False)
        classifications.write_csv(csv_filename)
        yield final_message, gr.DownloadButton(visible=True, value=csv_filename, interactive=True)
    except requests.exceptions.RequestException as e:
        yield f"Error accessing PubMed API: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except ET.ParseError as e:
        yield f"Error parsing PubMed response: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
    except Exception as e:
        yield f"Unexpected error: {str(e)}", gr.DownloadButton(visible=True, interactive=False)
def download_file():
    return gr.DownloadButton("Download results", visible=True, interactive=True)
# Create Gradio interface
def create_interface():
    with gr.Blocks() as app:
        gr.Markdown("## PomBase PubMed PMID Search")
        gr.Markdown("Enter a search term to find ALL relevant PubMed articles. Large searches may take several minutes.")
        gr.Markdown("We then filter for new pmids and classify them with a transformer model.")

        with gr.Row():
            search_input = gr.Textbox(
                label="Search Term",
                placeholder="Enter search terms...",
                lines=1,
                value='pombe OR "fission yeast"'
            )
            search_button = gr.Button("Search")

        with gr.Row():
            current_year = datetime.now().year + 1
            start_year = gr.Slider(label="Start year", minimum=1900, maximum=current_year, value=2020)
            end_year = gr.Slider(label="End year", minimum=1900, maximum=current_year, value=current_year)

        with gr.Row():
            status_output = gr.Textbox(
                label="Status",
                value="Ready to search..."
            )

        with gr.Row():
            d = gr.DownloadButton("Download results", visible=True, interactive=False)

        d.click(download_file, None, d)
        search_button.click(
            fn=search_pubmed,
            inputs=[search_input, start_year, end_year],
            outputs=[status_output, d]
        )
    return app
# fetch_latest_canto_dump()
app = create_interface()
app.launch()