import logging
import re
from typing import List

import pdfplumber
import torch
from fastapi import UploadFile
from gliner import GLiNER
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Entity(BaseModel):
    entity: str
    context: str
    start: int
    end: int

# Curated medical labels
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event"
]

# Check for GPU availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")
# Initialize model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device) # Move model to GPU if available
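
# Quick illustration of the GLiNER call used below (a sketch: the example
# sentence and the spans shown are assumptions, not recorded model output):
#
#   spans = gliner_model.predict_entities(
#       "Metformin lowered HbA1c in patients with type 2 diabetes.",
#       MEDICAL_LABELS,
#       threshold=0.7,
#   )
#   # Each span is a dict such as:
#   # {"text": "Metformin", "label": "drug", "start": 0, "end": 9, "score": 0.93}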

def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and the token limit.
    We use 700 tokens to leave some margin for the model's special tokens.

    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of tokens per chunk

    Returns:
        List[str]: List of text chunks
    """
    # Split into sentences (simple approach: break after ., ! or ?)
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        # Rough estimate of tokens (words + punctuation marks)
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))
        if current_length + sentence_tokens > max_tokens:
            if current_chunk:  # Save the current chunk if it exists
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_length = 0
        # Note: a single sentence longer than max_tokens still becomes its own
        # oversized chunk; this simple splitter never breaks inside a sentence.
        current_chunk.append(sentence)
        current_length += sentence_tokens

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks
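
# Worked example of chunk_text (hypothetical input, shown as a sketch):
#
#   chunk_text("First sentence. Second sentence! Third?", max_tokens=6)
#   # -> ["First sentence. Second sentence!", "Third?"]
#   # "Third?" opens a new chunk because its 2 estimated tokens added to the
#   # 6 already accumulated would exceed max_tokens.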

def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.

    Args:
        file (UploadFile): The uploaded PDF file

    Returns:
        List[Entity]: List of extracted entities with their context
    """
    logger.info(f"Starting extraction for file: {file.filename}")
    try:
        # pdfplumber can read the uploaded file's underlying file object directly
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")
            # Join all pages into a single string; extract_text() returns None
            # for pages with no extractable text, so fall back to ""
            pdf_text = " ".join(p.extract_text() or "" for p in pdf.pages)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")

            # Split text into chunks
            text_chunks = chunk_text(pdf_text)
            logger.info(f"Split text into {len(text_chunks)} chunks")

            # Extract entities from each chunk
            all_entities = []
            base_offset = 0  # Absolute position of the current chunk in the original text
            for chunk in text_chunks:
                # Extract entities using GLiNER
                chunk_entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)

                # Process entities from this chunk
                for ent in chunk_entities:
                    if len(ent["text"]) <= 2:  # Skip very short entities
                        continue
                    # GLiNER reports character offsets relative to the chunk;
                    # shift by base_offset to index into the original text.
                    # (Unlike chunk.find(), this stays correct when the same
                    # entity string appears more than once in a chunk.)
                    all_entities.append(Entity(
                        entity=ent["text"],
                        context="",  # Will be filled in below
                        start=base_offset + ent["start"],
                        end=base_offset + ent["end"]
                    ))
                # +1 for the space between chunks; this assumes chunking
                # preserved the original whitespace, so offsets can drift
                # slightly in documents with irregular spacing
                base_offset += len(chunk) + 1

            # Now get context for all entities from the complete original text
            final_entities = []
            for ent in all_entities:
                # Take up to 50 characters of surrounding context on each side
                context_start = max(0, ent.start - 50)
                context_end = min(len(pdf_text), ent.end + 50)
                context = pdf_text[context_start:context_end]
                final_entities.append(Entity(
                    entity=ent.entity,
                    context=context,
                    start=ent.start,
                    end=ent.end
                ))

            logger.info(f"Returning {len(final_entities)} processed entities")
            return final_entities
    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise
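

if __name__ == "__main__":
    # Minimal local smoke test: a sketch assuming a PDF at the hypothetical
    # path "sample.pdf". Constructing UploadFile by hand mirrors what FastAPI
    # does for a real upload (keyword arguments per recent Starlette versions).
    with open("sample.pdf", "rb") as fh:
        upload = UploadFile(file=fh, filename="sample.pdf")
        for ent in extract_entities_from_pdf(upload)[:10]:
            print(f"{ent.entity!r} [{ent.start}:{ent.end}] context={ent.context!r}")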