import logging
import re
from typing import List

import pdfplumber
import torch
from fastapi import UploadFile
from gliner import GLiNER
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Entity(BaseModel):
    """An extracted entity with its character offsets and surrounding context."""

    entity: str
    context: str
    start: int
    end: int

# Curated medical labels
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event"
]

# Check for GPU availability (prefer Apple MPS, then CUDA, then CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger.info(f"Using device: {device}")

# Initialize model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device)  # Move model to GPU if available


def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and a token limit.
    We use 700 tokens to leave some margin for the model's special tokens.

    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of tokens per chunk

    Returns:
        List[str]: List of text chunks
    """
    # Split into sentences (simple approach)
    sentences = re.split(r'(?<=[.!?])\s+', text)

    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Rough estimation of tokens (words + punctuation)
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))

        if current_length + sentence_tokens > max_tokens:
            if current_chunk:  # Save the current chunk if it exists
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_length = 0

        # Note: a single sentence longer than max_tokens still becomes its own
        # (oversized) chunk, since sentences are never split further
        current_chunk.append(sentence)
        current_length += sentence_tokens

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks
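
# Illustrative worked example (not in the original source): with max_tokens=4,
# "First sentence. Second sentence!" splits into two chunks, since each sentence
# counts as roughly three tokens under the regex estimate above and adding the
# second sentence would push the first chunk past the limit:
#
#   chunk_text("First sentence. Second sentence!", max_tokens=4)
#   # -> ["First sentence.", "Second sentence!"]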


def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: List of extracted entities with their context
    """
    logger.debug(f"Starting extraction for file: {file.filename}")

    try:
        # Read the uploaded file object directly with pdfplumber
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")

            # Join all pages into a single string; pages with no extractable
            # text return None, so substitute an empty string for those
            pdf_text = " ".join((p.extract_text() or "") for p in pdf.pages)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")

            # Split text into chunks
            text_chunks = chunk_text(pdf_text)
            logger.info(f"Split text into {len(text_chunks)} chunks")

            # Extract entities from each chunk
            all_entities = []
            base_offset = 0  # Track the absolute position in the original text

            for chunk in text_chunks:
                # Extract entities using GLiNER
                chunk_entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)

                # Process entities from this chunk
                for ent in chunk_entities:
                    if len(ent["text"]) <= 2:  # Skip very short entities
                        continue

                    # Locate the entity within the chunk; note this anchors on the
                    # first occurrence, so repeated mentions map to the same span
                    start_idx = chunk.find(ent["text"])
                    if start_idx != -1:
                        all_entities.append(Entity(
                            entity=ent["text"],
                            context="",  # Will be filled in later
                            start=base_offset + start_idx,
                            end=base_offset + start_idx + len(ent["text"])
                        ))

                base_offset += len(chunk) + 1  # +1 for the space between chunks

            # Now get context for all entities using the complete original text
            final_entities = []
            for ent in all_entities:
                # Get surrounding context from the complete text
                context_start = max(0, ent.start - 50)
                context_end = min(len(pdf_text), ent.end + 50)
                context = pdf_text[context_start:context_end]

                final_entities.append(Entity(
                    entity=ent.entity,
                    context=context,
                    start=ent.start,
                    end=ent.end
                ))

            logger.info(f"Returning {len(final_entities)} processed entities")
            return final_entities

    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise
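

# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# A minimal FastAPI wiring showing how extract_entities_from_pdf could be exposed
# over HTTP. The `app` object, the "/extract" path, and the `extract` handler name
# are hypothetical, chosen here for demonstration only. The route is declared as a
# plain `def` so FastAPI runs the blocking model call in its threadpool.
from fastapi import FastAPI, File

app = FastAPI()


@app.post("/extract", response_model=List[Entity])
def extract(file: UploadFile = File(...)) -> List[Entity]:
    """Accept a PDF upload and return the extracted medical entities."""
    return extract_entities_from_pdf(file)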