"""span_classifier.py

Token-level span classifier for CV/resume text: a SentenceTransformer
encoder feeds a linear classification head, and confident contiguous
tokens are merged into labeled spans (personal information, skills,
education, experience, certification).
"""
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
def load_model(model_path, head_path):
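    """Load the SentenceTransformer encoder, its saved linear classification
    head, and the matching tokenizer; everything stays on CPU."""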
    try:
        model = SentenceTransformer(model_path)
        # Linear head maps per-token embeddings to the 5 span classes
        classification_head = nn.Linear(model.get_sentence_embedding_dimension(), 5)
        classification_head.load_state_dict(torch.load(head_path, map_location=torch.device('cpu')))
        tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        device = torch.device('cpu')
        model.to(device)
        classification_head.to(device)
        classification_head.eval()
        return model, classification_head, tokenizer, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
def predict_spans(full_text, model, classification_head, tokenizer, device,
                  window_size=384, stride=256, min_span_length=3):
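    """Slide a character window over full_text, classify every token, and
    return merged spans as dicts with 'text', 'label', 'confidence', and
    'position' (global character offsets).

    Note that window_size doubles as the character window width and the
    tokenizer max_length, and min_span_length is a minimum word count.
    """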
    # Per-class probability thresholds: a token only counts as a detection
    # when its top-class probability clears the threshold for that class
    class_thresholds = {
        0: 0.8,
        1: 0.7,
        2: 0.75,
        3: 0.7,
        4: 0.8
    }
    label_map = {
        0: 'personal_information',
        1: 'skills',
        2: 'education',
        3: 'experience',
        4: 'certification'
    }
    results = []
    full_text = full_text.strip()
    # Slide a character window over the text; because stride < window_size,
    # consecutive windows overlap and may detect the same span twice
    for i in range(0, len(full_text), stride):
        window_text = full_text[i:i+window_size]
        encoding = tokenizer(
            window_text,
            max_length=window_size,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            # SentenceTransformer's forward accepts a features dict and
            # returns per-token vectors under 'token_embeddings'
            model_output = model({
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask']
            })
            token_embeddings = model_output['token_embeddings']
            token_logits = classification_head(token_embeddings)
            token_probs = torch.softmax(token_logits, dim=2)
        # Character offsets of each token, relative to window_text
        offset_mapping = encoding['offset_mapping'][0].cpu().numpy()
        current_span = None
        for token_idx, (start, end) in enumerate(offset_mapping):
            # (0, 0) offsets mark special and padding tokens; skip them
            if start == end == 0:
                continue
            probs = token_probs[0, token_idx]
            max_prob, pred_label = torch.max(probs, dim=0)
            max_prob = max_prob.item()
            pred_label = pred_label.item()
            if max_prob > class_thresholds[pred_label]:
                token_text = window_text[start:end]
                # The offset-mapped text slice never carries the WordPiece
                # '##' marker, so test the token string itself to detect
                # subword continuations
                token_str = tokenizer.convert_ids_to_tokens(
                    encoding['input_ids'][0][token_idx].item())
                if token_str.startswith('##'):
                    if current_span and current_span['label'] == label_map[pred_label]:
                        current_span['text'] += token_text
                        current_span['position'] = (current_span['position'][0], i+end)
                        current_span['confidence'] = max(current_span['confidence'], max_prob)
                    continue
                # Extend the running span when the label matches and the gap
                # to the previous token is at most 2 characters
                if (current_span and
                        current_span['label'] == label_map[pred_label] and
                        (i+start - current_span['position'][1]) <= 2):
                    current_span['text'] += ' ' + token_text
                    current_span['position'] = (current_span['position'][0], i+end)
                    current_span['confidence'] = max(current_span['confidence'], max_prob)
                else:
                    if current_span:
                        results.append(current_span)
                    current_span = {
                        'text': token_text,
                        'label': label_map[pred_label],
                        'confidence': max_prob,
                        'position': (i+start, i+end)
                    }
            else:
                # A low-confidence token ends any open span
                if current_span:
                    results.append(current_span)
                current_span = None
        # Flush the span still open at the end of this window
        if current_span:
            results.append(current_span)
    # Keep spans with at least min_span_length words, plus any span whose
    # confidence is very high regardless of length
    filtered_results = []
    for span in results:
        clean_text = span['text'].strip()
        if len(clean_text.split()) >= min_span_length or span['confidence'] > 0.9:
            span['text'] = clean_text
            filtered_results.append(span)
    # Merge same-label spans that overlap or sit within 5 characters of
    # each other (overlapping windows often report the same span twice)
    merged_results = []
    filtered_results.sort(key=lambda x: x['position'][0])
    for span in filtered_results:
        if not merged_results:
            merged_results.append(span)
        else:
            last = merged_results[-1]
            if (span['label'] == last['label'] and
                    span['position'][0] <= last['position'][1] + 5):
                # A span fully contained in the previous one is a duplicate
                # detection; keep the existing text and the best confidence
                if span['position'][1] <= last['position'][1]:
                    last['confidence'] = max(last['confidence'], span['confidence'])
                    continue
                merged_text = last['text'] + ' ' + span['text']
                merged_results[-1] = {
                    'text': merged_text,
                    'label': span['label'],
                    'confidence': max(last['confidence'], span['confidence']),
                    'position': (last['position'][0], span['position'][1])
                }
            else:
                merged_results.append(span)
    # Cap each span at 15 words to keep the summary output compact
    for span in merged_results:
        tokens = span['text'].split()
        if len(tokens) > 15:
            span['text'] = ' '.join(tokens[:15])
    return merged_results
def format_results(spans):
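    """Group spans by label; each group is sorted by confidence, descending."""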
    formatted = {}
    for span in spans:
        label = span['label']
        if label not in formatted:
            formatted[label] = []
        formatted[label].append(span)
    for label in formatted:
        formatted[label].sort(key=lambda x: x['confidence'], reverse=True)
    return formatted
def format_final_output(formatted_results):
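    """Flatten grouped spans into one string, keeping the single best
    personal_information span and the top 3 spans per other label."""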
    final_output = []
    for label, items in formatted_results.items():
        top_n = 1 if label == 'personal_information' else 3
        label_upper = label.upper()
        for item in items[:top_n]:
            final_output.append(f"{label_upper}: {item['text']}")
    return " ".join(final_output)