import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer


def load_model(model_path, head_path):
    """Load the sentence-transformer backbone, the token-classification head,
    and the tokenizer, all pinned to CPU."""
    try:
        model = SentenceTransformer(model_path)
        # Linear head mapping token embeddings to the 5 span classes.
        classification_head = nn.Linear(model.get_sentence_embedding_dimension(), 5)
        classification_head.load_state_dict(torch.load(head_path, map_location=torch.device('cpu')))
        tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        device = torch.device('cpu')
        model.to(device)
        classification_head.to(device)
        return model, classification_head, tokenizer, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
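
# A minimal sketch of how a compatible head checkpoint could have been
# produced, assuming the head was trained elsewhere as a plain nn.Linear and
# saved as a state_dict; the filename is a hypothetical placeholder.
#
#     head = nn.Linear(768, 5)  # 768 = all-mpnet-base-v2 embedding dim
#     torch.save(head.state_dict(), 'classification_head.pt')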


def predict_spans(full_text, model, classification_head, tokenizer, device,
                  window_size=384, stride=256, min_span_length=3):
    """Slide a window over full_text, classify every token, and return merged,
    labeled spans as dicts with 'text', 'label', 'confidence', and 'position'."""
    # Per-class confidence thresholds: a token only opens or extends a span if
    # its top class probability clears the threshold for that class.
    class_thresholds = {
        0: 0.8,   # personal_information
        1: 0.7,   # skills
        2: 0.75,  # education
        3: 0.7,   # experience
        4: 0.8    # certification
    }
    label_map = {
        0: 'personal_information',
        1: 'skills',
        2: 'education',
        3: 'experience',
        4: 'certification'
    }
    results = []
    full_text = full_text.strip()
    # NOTE: window_size does double duty as a character count (text slicing)
    # and as the tokenizer max_length (tokens); overlapping windows
    # (stride < window_size) mean the same text can be scored more than once.
    for i in range(0, len(full_text), stride):
        window_text = full_text[i:i + window_size]
        encoding = tokenizer(
            window_text,
            max_length=window_size,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            # SentenceTransformer's forward takes a features dict and returns
            # per-token embeddings under the 'token_embeddings' key.
            model_output = model({
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask']
            })
            token_embeddings = model_output['token_embeddings']
            token_logits = classification_head(token_embeddings)
            token_probs = torch.softmax(token_logits, dim=2)
        offset_mapping = encoding['offset_mapping'][0].cpu().numpy()
        # Token strings are needed to detect WordPiece continuations: the
        # offsets slice the original text, so window_text[start:end] never
        # carries the '##' marker itself.
        token_strings = encoding.tokens(0)
        current_span = None
        for token_idx, (start, end) in enumerate(offset_mapping):
            if start == end == 0:
                continue  # special or padding token
            probs = token_probs[0, token_idx]
            max_prob, pred_label = torch.max(probs, dim=0)
            max_prob = max_prob.item()
            pred_label = pred_label.item()
            if max_prob > class_thresholds[pred_label]:
                token_text = window_text[start:end]
                if token_strings[token_idx].startswith('##'):
                    # Subword continuation: glue it onto the open span without
                    # a space. The sliced text already lacks the '##' prefix.
                    if current_span and current_span['label'] == label_map[pred_label]:
                        current_span['text'] += token_text
                        current_span['position'] = (current_span['position'][0], i + end)
                        current_span['confidence'] = max(current_span['confidence'], max_prob)
                    continue
                if (current_span and
                        current_span['label'] == label_map[pred_label] and
                        (i + start - current_span['position'][1]) <= 2):
                    # Same label and at most 2 characters away: extend the span.
                    current_span['text'] += ' ' + token_text
                    current_span['position'] = (current_span['position'][0], i + end)
                    current_span['confidence'] = max(current_span['confidence'], max_prob)
                else:
                    if current_span:
                        results.append(current_span)
                    current_span = {
                        'text': token_text,
                        'label': label_map[pred_label],
                        'confidence': max_prob,
                        'position': (i + start, i + end)
                    }
            else:
                if current_span:
                    results.append(current_span)
                current_span = None
        # Flush any span still open at the end of this window.
        if current_span:
            results.append(current_span)
    # Keep spans that are long enough (in words) or very high confidence.
    filtered_results = []
    for span in results:
        clean_text = span['text'].strip()
        if len(clean_text.split()) >= min_span_length or span['confidence'] > 0.9:
            span['text'] = clean_text
            filtered_results.append(span)
    # Merge same-label spans that sit within 5 characters of each other in the
    # original text. Overlapping windows can still produce repeated text here.
    merged_results = []
    filtered_results.sort(key=lambda x: x['position'][0])
    for span in filtered_results:
        if not merged_results:
            merged_results.append(span)
        else:
            last = merged_results[-1]
            if (span['label'] == last['label'] and
                    span['position'][0] <= last['position'][1] + 5):
                merged_text = last['text'] + ' ' + span['text']
                merged_results[-1] = {
                    'text': merged_text,
                    'label': span['label'],
                    'confidence': max(last['confidence'], span['confidence']),
                    'position': (last['position'][0], span['position'][1])
                }
            else:
                merged_results.append(span)
    # Cap every merged span at 15 words.
    for span in merged_results:
        tokens = span['text'].split()
        if len(tokens) > 15:
            span['text'] = ' '.join(tokens[:15])
    return merged_results


def format_results(spans):
    """Group spans by label and sort each group by confidence, descending."""
    formatted = {}
    for span in spans:
        label = span['label']
        if label not in formatted:
            formatted[label] = []
        formatted[label].append(span)
    for label in formatted:
        formatted[label].sort(key=lambda x: x['confidence'], reverse=True)
    return formatted


def format_final_output(formatted_results):
    """Render the top spans per label as a single 'LABEL: text' string."""
    final_output = []
    for label, items in formatted_results.items():
        # Keep only the best personal_information span, and the top 3 otherwise.
        top_n = 1 if label == 'personal_information' else 3
        label_upper = label.upper()
        for item in items[:top_n]:
            final_output.append(f"{label_upper}: {item['text']}")
    return " ".join(final_output)