from typing import List
from pydantic import BaseModel
import pdfplumber
from fastapi import UploadFile
from gliner import GLiNER
import logging
import torch
import re

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Entity(BaseModel):
    entity: str
    context: str
    start: int
    end: int

# Curated medical labels
MEDICAL_LABELS = [
    "gene", "protein", "protein_isoform", "cell", "disease",
    "phenotypic_feature", "clinical_finding", "anatomical_entity",
    "pathway", "biological_process", "drug", "small_molecule",
    "food_additive", "chemical_mixture", "molecular_entity",
    "clinical_intervention", "clinical_trial", "hospitalization",
    "geographic_location", "environmental_feature", "environmental_process",
    "publication", "journal_article", "book", "patent", "dataset",
    "study_result", "human", "mammal", "plant", "virus", "bacterium",
    "cell_line", "biological_sex", "clinical_attribute",
    "socioeconomic_attribute", "environmental_exposure", "drug_exposure",
    "procedure", "treatment", "device", "diagnostic_aid", "event"
]

# Check for GPU availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Initialize model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
gliner_model.to(device)  # Move model to GPU if available

def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """
    Split text into chunks that respect sentence boundaries and token limit.
    We use 700 tokens to leave some margin for the model's special tokens.
    
    Args:
        text (str): Input text to chunk
        max_tokens (int): Maximum number of tokens per chunk
        
    Returns:
        List[str]: List of text chunks
    """
    # Split into sentences (simple approach)
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0
    
    for sentence in sentences:
        # Rough estimation of tokens (words + punctuation)
        sentence_tokens = len(re.findall(r'\w+|[^\w\s]', sentence))
        
        if current_length + sentence_tokens > max_tokens:
            if current_chunk:  # Save current chunk if it exists
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_length = 0
        
        current_chunk.append(sentence)
        current_length += sentence_tokens
    
    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    
    return chunks
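
# Example (illustrative): chunk_text("First sentence. Second sentence.")
# returns ["First sentence. Second sentence."], since both sentences fit
# well within the default 700-token budget.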

def extract_entities_from_pdf(file: UploadFile) -> List[Entity]:
    """
    Extract medical entities from a PDF file using GLiNER.
    
    Args:
        file (UploadFile): The uploaded PDF file
        
    Returns:
        List[Entity]: List of extracted entities with their context
    """
    logger.debug(f"Starting extraction for file: {file.filename}")
    
    try:
        # pdfplumber can read straight from the uploaded file's file-like object
        with pdfplumber.open(file.file) as pdf:
            logger.info(f"Successfully opened PDF with {len(pdf.pages)} pages")
            # Join all pages into a single string; extract_text() can return
            # None for pages without a text layer, so fall back to "".
            pdf_text = " ".join((p.extract_text() or "") for p in pdf.pages)
            # Collapse whitespace so the chunk offsets below map back onto pdf_text.
            pdf_text = re.sub(r"\s+", " ", pdf_text)
            logger.info(f"Extracted text length: {len(pdf_text)} characters")
        
        # Split text into chunks
        text_chunks = chunk_text(pdf_text)
        logger.info(f"Split text into {len(text_chunks)} chunks")
        
        # Extract entities from each chunk
        all_entities = []
        base_offset = 0  # Keep track of the absolute position in the original text
        
        for chunk in text_chunks:
            # Extract entities using GLiNER
            chunk_entities = gliner_model.predict_entities(chunk, MEDICAL_LABELS, threshold=0.7)
            
            # Process entities from this chunk
            for ent in chunk_entities:
                if len(ent["text"]) <= 2:  # Skip very short entities
                    continue
                    
                # Just store the entity and its position for now
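                # Note: find() returns only the first occurrence, so repeated
                # mentions of the same entity within a chunk share one offset.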
                start_idx = chunk.find(ent["text"])
                if start_idx != -1:
                    all_entities.append(Entity(
                        entity=ent["text"],
                        context="",  # Will be filled later
                        start=base_offset + start_idx,
                        end=base_offset + start_idx + len(ent["text"])
                    ))
            
            base_offset += len(chunk) + 1  # +1 for the space between chunks
        
        # Now get context for all entities using the complete original text
        final_entities = []
        for ent in all_entities:
            # Get surrounding context from the complete text
            context_start = max(0, ent.start - 50)
            context_end = min(len(pdf_text), ent.end + 50)
            context = pdf_text[context_start:context_end]
            
            final_entities.append(Entity(
                entity=ent.entity,
                context=context,
                start=ent.start,
                end=ent.end
            ))
        
        logger.info(f"Returning {len(final_entities)} processed entities")
        return final_entities
        
    except Exception as e:
        logger.error(f"Error during extraction: {str(e)}", exc_info=True)
        raise
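

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way to
# expose extract_entities_from_pdf as a FastAPI endpoint. The app instance
# and route path are assumptions; adapt them to the real application.
#
#   from fastapi import FastAPI, File
#
#   app = FastAPI()
#
#   @app.post("/extract-entities", response_model=List[Entity])
#   def extract_entities(file: UploadFile = File(...)) -> List[Entity]:
#       return extract_entities_from_pdf(file)
# ---------------------------------------------------------------------------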