import gradio as gr import spaces import polars as pl from datetime import datetime from functools import lru_cache from transformers import pipeline from typing import Dict import requests import xml.etree.ElementTree as ET import time from typing import List, Tuple, Dict label_lookup = { "LABEL_0": "NOT_CURATEABLE", "LABEL_1": "CURATEABLE" } @spaces.GPU @lru_cache def get_pipeline(): print("fetching model and building pipeline") model_name = "afg1/pombe_curation_fold_0" pipe = pipeline(model=model_name, task="text-classification") return pipe @spaces.GPU def classify_abstracts(abstracts:Dict[str, str],batch_size=64, progress=gr.Progress()) -> None: pipe = get_pipeline() # return classification results = [] total = len(abstracts) # Convert dictionary to lists of PMIDs and abstracts, preserving order pmids = list(abstracts.keys()) abstract_texts = list(abstracts.values()) # Initialize progress bar progress(0, desc="Starting classification...") # Process in batches for i in range(0, total, batch_size): # Get current batch batch_abstracts = abstract_texts[i:i + batch_size] batch_pmids = pmids[i:i + batch_size] try: # Classify the batch classifications = pipe(batch_abstracts) # Process each result in the batch for pmid, classification in zip(batch_pmids, classifications): results.append({ 'pmid': pmid, 'classification': label_lookup[classification['label']], 'score': classification['score'] }) # Update progress progress(min((i + batch_size) / total, 1.0), desc=f"Classified {min(i + batch_size, total)}/{total} abstracts...") except Exception as e: print(f"Error classifying batch starting at index {i}: {str(e)}") continue progress(1.0, desc="Classification complete!") return results @lru_cache def fetch_latest_canto_dump() -> pl.DataFrame: """ Read the latest pombase canto dump direct from the URL """ url = "https://curation.pombase.org/kmr44/canto_pombe_pubs.tsv" return pl.read_csv(url, separator='\t') def filter_new_hits(canto_pmcids: pl.DataFrame, new_pmcids: List[str]) -> List[str]: """ Convert the list of PMCIDs from the search to a dataframe and do an anti-join to find new stuff """ new_pmids = pl.DataFrame({"pmid": new_pmcids}) uncurated = new_pmids.join(canto_pmcids, on="pmid", how="anti") return uncurated.get_column("pmid").to_list() def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200) -> Dict[str, str]: """ Fetch abstracts for a list of PMIDs in batches Args: pmids (List[str]): List of PMIDs to fetch abstracts for batch_size (int): Number of PMIDs to process per batch Returns: Dict[str, str]: Dictionary mapping PMIDs to their abstracts """ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi" all_abstracts = {} # Process PMIDs in batches for i in range(0, len(pmids), batch_size): batch_pmids = pmids[i:i + batch_size] pmids_string = ",".join(batch_pmids) print(f"Processing batch {i//batch_size + 1} of {(len(pmids) + batch_size - 1)//batch_size}") params = { "db": "pubmed", "id": pmids_string, "retmode": "xml", "rettype": "abstract" } try: response = requests.get(base_url, params=params) response.raise_for_status() # Parse XML response root = ET.fromstring(response.content) # Iterate through each article in the batch for article in root.findall(".//PubmedArticle"): # Get PMID pmid = article.find(".//PMID").text # Find abstract text abstract_element = article.find(".//Abstract/AbstractText") if abstract_element is not None: # Handle structured abstracts if 'Label' in abstract_element.attrib: abstract_sections = article.findall(".//Abstract/AbstractText") abstract_text = "\n".join( f"{section.attrib.get('Label', 'Abstract')}: {section.text}" for section in abstract_sections if section.text is not None ) else: # Simple abstract abstract_text = abstract_element.text else: abstract_text = "" if len(abstract_text) > 0: all_abstracts[pmid] = abstract_text # Respect NCBI's rate limits time.sleep(0.34) except requests.exceptions.RequestException as e: print(f"Error accessing PubMed API for batch {i//batch_size + 1}: {str(e)}") continue except ET.ParseError as e: print(f"Error parsing PubMed response for batch {i//batch_size + 1}: {str(e)}") continue except Exception as e: print(f"Unexpected error in batch {i//batch_size + 1}: {str(e)}") continue print("All abstracts retrieved") return all_abstracts def chunk_search(query: str, year_start: int, year_end: int) -> List[str]: """ Perform a PubMed search for a specific year range """ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" retmax = 9999 # Maximum allowed per query date_query = f"{query} AND {year_start}:{year_end}[dp]" params = { "db": "pubmed", "term": date_query, "retmax": retmax, "retmode": "xml" } response = requests.get(base_url, params=params) response.raise_for_status() root = ET.fromstring(response.content) id_list = root.findall(".//Id") return [id_elem.text for id_elem in id_list] def search_pubmed(query: str, start_year:int, end_year: int) -> Tuple[str, List[str]]: """ Search PubMed and return all matching PMIDs by breaking the search into year chunks """ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" all_pmids = [] yield "Loading current canto dump...", gr.DownloadButton(visible=True, interactive=False) canto_pmids = fetch_latest_canto_dump().select("pmid").with_columns(pl.col("pmid").str.split(":").list.last()) try: # First, get the total count params = { "db": "pubmed", "term": query, "retmax": 0, "retmode": "xml" } response = requests.get(base_url, params=params) response.raise_for_status() root = ET.fromstring(response.content) total_count = int(root.find(".//Count").text) if total_count == 0: return "No results found.", gr.DownloadButton(visible=True, interactive=False) print(total_count) # Break the search into year chunks year_chunks = [] chunk_size = 5 # Number of years per chunk for year in range(start_year, end_year + 1, chunk_size): chunk_end = min(year + chunk_size - 1, end_year) year_chunks.append((year, chunk_end)) # Search each year chunk for start_year, end_year in year_chunks: current_status = f"Searching years {start_year}-{end_year}..." yield current_status, gr.DownloadButton(visible=True, interactive=False) try: chunk_pmids = chunk_search(query, start_year, end_year) all_pmids.extend(chunk_pmids) # Status update yield f"Retrieved {len(all_pmids)} total results so far...", gr.DownloadButton(visible=True, interactive=False) # Respect NCBI's rate limits time.sleep(0.34) except Exception as e: print(f"Error processing years {start_year}-{end_year}: {str(e)}") continue uncurated_pmid = filter_new_hits(canto_pmids, all_pmids) final_message = f"Retrieved {len(uncurated_pmid)} uncurated pmids!" yield final_message, gr.DownloadButton(visible=True, interactive=False) abstracts = fetch_abstracts_batch(uncurated_pmid) yield f"Fetched {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False) classifications = pl.DataFrame(classify_abstracts(abstracts)) print(classifications) yield f"Classified {len(abstracts)} abstracts", gr.DownloadButton(visible=True, interactive=False) classification_date = datetime.today().strftime('%Y%m%d') csv_filename = f"classified_pmids_{classification_date}.csv" yield "Write csv file...", gr.DownloadButton(visible=True, value=csv_filename, interactive=True) classifications.write_csv(csv_filename) yield final_message, gr.DownloadButton(visible=True, value=csv_filename, interactive=True) except requests.exceptions.RequestException as e: return f"Error accessing PubMed API: {str(e)}", all_pmids except ET.ParseError as e: return f"Error parsing PubMed response: {str(e)}", all_pmids except Exception as e: return f"Unexpected error: {str(e)}", all_pmids def download_file(): return gr.DownloadButton("Download results", visible=True, interactive=True) # Create Gradio interface def create_interface(): with gr.Blocks() as app: gr.Markdown("## PomBase PubMed PMID Search") gr.Markdown("Enter a search term to find ALL relevant PubMed articles. Large searches may take several minutes.") gr.Markdown("We then filter for new pmids, then classify them with a transformer model.") with gr.Row(): search_input = gr.Textbox( label="Search Term", placeholder="Enter search terms...", lines=1, value='pombe OR "fission yeast"' ) search_button = gr.Button("Search") with gr.Row(): current_year = datetime.now().year + 1 start_year = gr.Slider(label="Start year", minimum=1900, maximum=current_year, value=2020) end_year = gr.Slider(label="End year", minimum=1900, maximum=current_year, value=current_year) with gr.Row(): status_output = gr.Textbox( label="Status", value="Ready to search..." ) with gr.Row(): d = gr.DownloadButton("Download results", visible=True, interactive=False) with gr.Row(): progress=gr.Progress() d.click(download_file, None, d) search_button.click( fn=search_pubmed, inputs=[search_input, start_year, end_year], outputs=[status_output, d] ) return app # fetch_latest_canto_dump() app = create_interface() app.launch()