# Andrew Green
# Use more recent default start date
# 20fccc0
import gradio as gr
import spaces
import polars as pl
from datetime import datetime
from functools import lru_cache
from transformers import pipeline
from typing import Dict
import requests
import xml.etree.ElementTree as ET
import time
from typing import List, Tuple, Dict
# Raw labels emitted by the classification pipeline -> human-readable labels
label_lookup = dict(
    LABEL_0="NOT_CURATEABLE",
    LABEL_1="CURATEABLE",
)
@spaces.GPU
@lru_cache
def get_pipeline():
    """
    Return the text-classification pipeline for the PomBase curation model.

    ``lru_cache`` means the model is fetched and the pipeline built on the
    first call only; later calls in the same process reuse that object.
    """
    print("fetching model and building pipeline")
    return pipeline(
        task="text-classification",
        model="afg1/pombe_curation_fold_0",
    )
@spaces.GPU
def classify_abstracts(abstracts: Dict[str, str], batch_size=64, progress=gr.Progress()) -> List[Dict[str, object]]:
    """
    Classify abstracts as CURATEABLE / NOT_CURATEABLE with the fine-tuned model.

    Args:
        abstracts: mapping of PMID -> abstract text. Output follows the
            mapping's iteration order.
        batch_size: number of abstracts passed to the pipeline per call.
        progress: Gradio progress tracker, advanced after each successful batch.

    Returns:
        A list of dicts with keys ``'pmid'``, ``'classification'`` and ``'score'``.
        Abstracts from a batch whose pipeline call raised are omitted; the error
        is printed instead (best effort).
    """
    pipe = get_pipeline()
    results = []
    total = len(abstracts)

    # Parallel lists of PMIDs and texts, preserving the dict's order
    pmids = list(abstracts.keys())
    abstract_texts = list(abstracts.values())

    progress(0, desc="Starting classification...")
    for i in range(0, total, batch_size):
        batch_abstracts = abstract_texts[i:i + batch_size]
        batch_pmids = pmids[i:i + batch_size]
        try:
            # truncation=True: an abstract longer than the model's maximum input
            # length would otherwise make the whole batch fail.
            classifications = pipe(batch_abstracts, truncation=True)
        except Exception as e:
            print(f"Error classifying batch starting at index {i}: {str(e)}")
            continue

        for pmid, classification in zip(batch_pmids, classifications):
            label = classification['label']
            results.append({
                'pmid': pmid,
                # Fall back to the raw label rather than dropping the batch on
                # a label missing from the lookup table.
                'classification': label_lookup.get(label, label),
                'score': classification['score'],
            })

        done = min(i + batch_size, total)
        progress(done / total, desc=f"Classified {done}/{total} abstracts...")

    progress(1.0, desc="Classification complete!")
    return results
_CANTO_DUMP_URL = "https://curation.pombase.org/kmr44/canto_pombe_pubs.tsv"
# How long a downloaded Canto dump may be reused before it is fetched again
_CANTO_DUMP_MAX_AGE_SECONDS = 24 * 60 * 60


@lru_cache(maxsize=1)
def _read_canto_dump(refresh_bucket: int) -> pl.DataFrame:
    """Download and parse the Canto dump; ``refresh_bucket`` only varies the cache key."""
    return pl.read_csv(_CANTO_DUMP_URL, separator='\t')


def fetch_latest_canto_dump() -> pl.DataFrame:
    """
    Read the latest PomBase Canto dump directly from its URL.

    The parsed table is cached, but the cache key changes every
    ``_CANTO_DUMP_MAX_AGE_SECONDS`` of Unix time. A long-running app therefore
    re-downloads the dump at most once per such period, instead of keeping the
    first download for the whole process lifetime.
    """
    return _read_canto_dump(int(time.time() // _CANTO_DUMP_MAX_AGE_SECONDS))
def filter_new_hits(canto_pmcids: pl.DataFrame, new_pmcids: List[str]) -> List[str]:
    """
    Return the search hits that are not already present in the Canto dump.

    Despite the parameter names, both inputs hold PubMed IDs (PMIDs), not PMCIDs.

    Args:
        canto_pmcids: DataFrame with a string ``pmid`` column of PMIDs already in
            Canto. Values must be formatted like ``new_pmcids``; ``search_pubmed``
            strips the ``PMID:`` prefix before calling this.
        new_pmcids: PMIDs returned by the PubMed search.

    Returns:
        The entries of ``new_pmcids`` whose PMID does not occur in ``canto_pmcids``.
    """
    if not new_pmcids:
        return []
    # Explicit string dtype so the join key always matches the Canto column,
    # rather than relying on dtype inference from the list contents.
    new_pmids = pl.DataFrame({"pmid": new_pmcids}, schema={"pmid": pl.Utf8})
    uncurated = new_pmids.join(canto_pmcids, on="pmid", how="anti")
    return uncurated.get_column("pmid").to_list()
def _element_text(element: ET.Element) -> str:
    """Return all text inside ``element``, including text nested in inline markup (e.g. ``<i>``)."""
    return "".join(element.itertext()).strip()


def _extract_abstract(article: ET.Element) -> str:
    """Build the abstract text for one ``PubmedArticle`` element ("" if it has none)."""
    sections = article.findall(".//Abstract/AbstractText")
    # Structured abstracts carry a Label on their sections; if any section is
    # labelled, prefix every section with its label (default "Abstract").
    labelled = any("Label" in section.attrib for section in sections)
    parts = []
    for section in sections:
        text = _element_text(section)
        if not text:
            continue
        if labelled:
            parts.append(f"{section.attrib.get('Label', 'Abstract')}: {text}")
        else:
            parts.append(text)
    return "\n".join(parts)


def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200, timeout: float = 60.0) -> Dict[str, str]:
    """
    Fetch abstracts for a list of PMIDs from PubMed EFetch, in batches.

    Failures are best effort: a batch whose request or parsing fails is
    reported with ``print`` and skipped.

    Args:
        pmids (List[str]): List of PMIDs to fetch abstracts for
        batch_size (int): Number of PMIDs to request per EFetch call
        timeout (float): Seconds to wait for each HTTP request before giving up

    Returns:
        Dict[str, str]: Dictionary mapping PMIDs to their abstracts. Articles
        without abstract text are omitted.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    all_abstracts = {}
    n_batches = (len(pmids) + batch_size - 1) // batch_size

    for i in range(0, len(pmids), batch_size):
        batch_no = i // batch_size + 1
        batch_pmids = pmids[i:i + batch_size]
        print(f"Processing batch {batch_no} of {n_batches}")

        params = {
            "db": "pubmed",
            "id": ",".join(batch_pmids),
            "retmode": "xml",
            "rettype": "abstract",
        }
        try:
            response = requests.get(base_url, params=params, timeout=timeout)
            response.raise_for_status()
            root = ET.fromstring(response.content)

            for article in root.findall(".//PubmedArticle"):
                # The article's own PMID; other PMID elements (e.g. in
                # CommentsCorrections) refer to different articles.
                pmid = article.findtext("MedlineCitation/PMID")
                abstract_text = _extract_abstract(article)
                if pmid and abstract_text:
                    all_abstracts[pmid] = abstract_text
        except requests.exceptions.RequestException as e:
            print(f"Error accessing PubMed API for batch {batch_no}: {str(e)}")
        except ET.ParseError as e:
            print(f"Error parsing PubMed response for batch {batch_no}: {str(e)}")
        except Exception as e:
            print(f"Unexpected error in batch {batch_no}: {str(e)}")
        finally:
            # Respect NCBI's rate limits, including after a failed request
            time.sleep(0.34)

    print("All abstracts retrieved")
    return all_abstracts
def chunk_search(query: str, year_start: int, year_end: int, timeout: float = 60.0) -> List[str]:
    """
    Return the PMIDs matching ``query`` with publication year in [year_start, year_end].

    ESearch returns at most 10,000 IDs per request. If the range matches more,
    only the first 10,000 are returned and a warning is printed, so the caller
    can use narrower year ranges.

    Args:
        query: PubMed search expression.
        year_start: first publication year (inclusive).
        year_end: last publication year (inclusive).
        timeout: seconds to wait for the HTTP request.

    Raises:
        requests.exceptions.RequestException: on network or HTTP errors.
        xml.etree.ElementTree.ParseError: if the response is not valid XML.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    retmax = 10000  # ESearch's documented maximum number of IDs per request
    # Parenthesise the user query so the date filter explicitly applies to all of it
    date_query = f"({query}) AND {year_start}:{year_end}[dp]"
    params = {
        "db": "pubmed",
        "term": date_query,
        "retmax": retmax,
        "retmode": "xml",
    }
    response = requests.get(base_url, params=params, timeout=timeout)
    response.raise_for_status()
    root = ET.fromstring(response.content)

    pmids = [id_elem.text for id_elem in root.findall(".//IdList/Id")]
    count = root.findtext("Count")
    if count is not None and int(count) > len(pmids):
        print(
            f"Warning: {count} results for years {year_start}-{year_end} but only "
            f"{len(pmids)} retrieved; use a smaller year range"
        )
    return pmids
def search_pubmed(query: str, start_year: int, end_year: int):
    """
    Search PubMed, drop PMIDs already in Canto, classify the rest and offer a CSV.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    ``(status_message, download_button)`` update for the two output components.
    Errors and early exits are reported with a final ``yield``, because a value
    *returned* from a generator is never shown by Gradio.

    The search is split into 5-year chunks so each ESearch request stays under
    its per-request ID limit.

    Args:
        query: PubMed search expression.
        start_year: first publication year to include (inclusive).
        end_year: last publication year to include (inclusive).
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    all_pmids = []

    def pending():
        # Download button shown but disabled while there is no file to offer
        return gr.DownloadButton(visible=True, interactive=False)

    # range() below needs ints; slider values are cast defensively
    start_year, end_year = int(start_year), int(end_year)
    if start_year > end_year:
        yield f"Start year {start_year} is after end year {end_year}.", pending()
        return

    try:
        yield "Loading current canto dump...", pending()
        canto_pmids = (
            fetch_latest_canto_dump()
            .select("pmid")
            .with_columns(pl.col("pmid").str.split(":").list.last())
        )

        # First, get the total count for the requested years
        params = {
            "db": "pubmed",
            "term": f"({query}) AND {start_year}:{end_year}[dp]",
            "retmax": 0,
            "retmode": "xml",
        }
        response = requests.get(base_url, params=params, timeout=60)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        total_count = int(root.findtext("Count"))
        if total_count == 0:
            yield "No results found.", pending()
            return
        print(total_count)

        # Break the search into year chunks
        chunk_size = 5  # Number of years per chunk
        year_chunks = [
            (year, min(year + chunk_size - 1, end_year))
            for year in range(start_year, end_year + 1, chunk_size)
        ]

        for chunk_start, chunk_end in year_chunks:
            yield f"Searching years {chunk_start}-{chunk_end}...", pending()
            try:
                all_pmids.extend(chunk_search(query, chunk_start, chunk_end))
                yield f"Retrieved {len(all_pmids)} total results so far...", pending()
                # Respect NCBI's rate limits
                time.sleep(0.34)
            except Exception as e:
                print(f"Error processing years {chunk_start}-{chunk_end}: {str(e)}")
                continue

        uncurated_pmid = filter_new_hits(canto_pmids, all_pmids)
        final_message = f"Retrieved {len(uncurated_pmid)} uncurated pmids!"
        yield final_message, pending()
        if not uncurated_pmid:
            return

        abstracts = fetch_abstracts_batch(uncurated_pmid)
        yield f"Fetched {len(abstracts)} abstracts", pending()
        if not abstracts:
            return

        results = classify_abstracts(abstracts)
        if not results:
            yield "Classification produced no results; see the logs for errors.", pending()
            return
        classifications = pl.DataFrame(results)
        print(classifications)
        # Failed batches are dropped by classify_abstracts, so count the results
        yield f"Classified {len(results)} abstracts", pending()

        classification_date = datetime.today().strftime('%Y%m%d')
        csv_filename = f"classified_pmids_{classification_date}.csv"
        yield "Write csv file...", pending()
        # The file must exist before its path is given to the download button
        classifications.write_csv(csv_filename)
        yield final_message, gr.DownloadButton(visible=True, value=csv_filename, interactive=True)
    except requests.exceptions.RequestException as e:
        yield f"Error accessing PubMed API: {str(e)}", pending()
    except ET.ParseError as e:
        yield f"Error parsing PubMed response: {str(e)}", pending()
    except Exception as e:
        yield f"Unexpected error: {str(e)}", pending()
def download_file():
    """Return an update restoring a visible, enabled "Download results" button."""
    button = gr.DownloadButton(
        "Download results",
        interactive=True,
        visible=True,
    )
    return button
# Create Gradio interface
def create_interface():
    """
    Build the PomBase PubMed search UI.

    The Search button runs ``search_pubmed`` with the search term and the
    start/end year sliders. Its streamed updates go to the status textbox and
    the download button.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("## PomBase PubMed PMID Search")
        gr.Markdown("Enter a search term to find ALL relevant PubMed articles. Large searches may take several minutes.")
        gr.Markdown("We then filter for new pmids, then classify them with a transformer model.")

        with gr.Row():
            search_input = gr.Textbox(
                label="Search Term",
                placeholder="Enter search terms...",
                lines=1,
                value='pombe OR "fission yeast"',
            )
            search_button = gr.Button("Search")

        with gr.Row():
            # NOTE: the sliders go up to *next* calendar year, and the end year
            # defaults to that maximum.
            max_year = datetime.now().year + 1
            start_year = gr.Slider(label="Start year", minimum=1900, maximum=max_year, value=2020, step=1)
            end_year = gr.Slider(label="End year", minimum=1900, maximum=max_year, value=max_year, step=1)

        with gr.Row():
            status_output = gr.Textbox(
                label="Status",
                value="Ready to search...",
            )

        with gr.Row():
            # Disabled until search_pubmed yields a finished CSV
            d = gr.DownloadButton("Download results", visible=True, interactive=False)

        d.click(download_file, None, d)
        search_button.click(
            fn=search_pubmed,
            inputs=[search_input, start_year, end_year],
            outputs=[status_output, d],
        )
    return app
# Build the UI at module level so `app` stays importable, but only start the
# server when this file is executed directly.
app = create_interface()
if __name__ == "__main__":
    app.launch()