from typing import List
from pydantic import BaseModel
import pdfplumber
from fastapi import UploadFile
from gliner import GLiNER
import logging
import torch
import re
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Entity(BaseModel):
    entity: str
    context: str
    start: int
    end: int


# Curated medical labels
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event"
]

# Check for GPU availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger.info(f"Using device: {device}")
# Initialize model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device) # Move model to GPU if available


def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and the token limit.
    We use 700 tokens to leave some margin for the model's special tokens.

    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of tokens per chunk

    Returns:
        List[str]: List of text chunks
    """
    # Split into sentences (simple approach)
    sentences = re.split(r'(?<=[.!?])\s+', text)

    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Rough estimation of tokens (words + punctuation)
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))

        if current_length + sentence_tokens > max_tokens:
            if current_chunk:  # Save current chunk if it exists
                chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_length = 0

        current_chunk.append(sentence)
        current_length += sentence_tokens

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks
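
# Example (illustrative only, not part of the original module): with a small
# token budget, the two sentences below land in separate chunks, because adding
# the second one would push the first chunk past max_tokens.
#
#   chunk_text("Aspirin inhibits COX-1. It reduces inflammation.", max_tokens=6)
#   # -> ['Aspirin inhibits COX-1.', 'It reduces inflammation.']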


def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: List of extracted entities with their context
    """
    logger.debug(f"Starting extraction for file: {file.filename}")
    try:
        # Open the uploaded file directly with pdfplumber
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")

            # Join all pages into a single string; extract_text() can return
            # None for pages without extractable text, so fall back to ""
            pdf_text = " ".join(p.extract_text() or "" for p in pdf.pages)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")

            # Split text into chunks
            text_chunks = chunk_text(pdf_text)
            logger.info(f"Split text into {len(text_chunks)} chunks")

            # Extract entities from each chunk
            all_entities = []
            base_offset = 0  # Keep track of the absolute position in the original text

            for chunk in text_chunks:
                # Extract entities using GLiNER
                chunk_entities = gliner_model.predict_entities(
                    chunk, MEDICAL_LABELS, threshold=0.7
                )

                # Process entities from this chunk
                for ent in chunk_entities:
                    if len(ent["text"]) <= 2:  # Skip very short entities
                        continue

                    # Just store the entity and its position for now
                    start_idx = chunk.find(ent["text"])
                    if start_idx != -1:
                        all_entities.append(Entity(
                            entity=ent["text"],
                            context="",  # Will be filled in later
                            start=base_offset + start_idx,
                            end=base_offset + start_idx + len(ent["text"])
                        ))

                base_offset += len(chunk) + 1  # +1 for the space between chunks

            # Now get context for all entities using the complete original text
            final_entities = []
            for ent in all_entities:
                # Get the surrounding context from the complete text
                context_start = max(0, ent.start - 50)
                context_end = min(len(pdf_text), ent.end + 50)
                context = pdf_text[context_start:context_end]

                final_entities.append(Entity(
                    entity=ent.entity,
                    context=context,
                    start=ent.start,
                    end=ent.end
                ))

            logger.info(f"Returning {len(final_entities)} processed entities")
            return final_entities

    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise
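

# ---------------------------------------------------------------------------
# Usage sketch (kept as a comment so importing this module has no side effects;
# the app and route names below are illustrative assumptions, not the project's
# actual API): one way to expose the extractor through a FastAPI endpoint.
#
#   from fastapi import FastAPI, File
#
#   app = FastAPI()
#
#   @app.post("/extract", response_model=List[Entity])
#   async def extract(file: UploadFile = File(...)) -> List[Entity]:
#       # extract_entities_from_pdf is synchronous; a production endpoint might
#       # offload it, e.g. via fastapi.concurrency.run_in_threadpool.
#       return extract_entities_from_pdf(file)
# ---------------------------------------------------------------------------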