import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

def load_model(model_path, head_path):
    """Load the sentence encoder, token classification head, and tokenizer (CPU only)."""
    try:
        model = SentenceTransformer(model_path)
        classification_head = nn.Linear(model.get_sentence_embedding_dimension(), 5)
        classification_head.load_state_dict(torch.load(head_path, map_location=torch.device('cpu')))
        tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        device = torch.device('cpu')
        model.to(device)
        classification_head.to(device)
        # Inference only: disable dropout and other training-time behaviour.
        model.eval()
        classification_head.eval()
        return model, classification_head, tokenizer, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

def predict_spans(full_text, model, classification_head, tokenizer, device,
                  window_size=384, stride=256, min_span_length=3):
    """Slide a character window over the text and classify each token into a resume field."""
    # Per-class confidence thresholds: a token is kept only if its top
    # probability exceeds the threshold of its predicted class.
    class_thresholds = {
        0: 0.8,
        1: 0.7,
        2: 0.75,
        3: 0.7,
        4: 0.8
    }
    label_map = {
        0: 'personal_information',
        1: 'skills',
        2: 'education',
        3: 'experience',
        4: 'certification'
    }
    results = []
    full_text = full_text.strip()
    # Overlapping character windows (stride < window_size) so spans that cross
    # a window boundary can still be recovered and merged later.
    for i in range(0, len(full_text), stride):
        window_text = full_text[i:i+window_size]
        encoding = tokenizer(
            window_text,
            max_length=window_size,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            model_output = model({
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask']
            })
            token_embeddings = model_output['token_embeddings']
            token_logits = classification_head(token_embeddings)
            token_probs = torch.softmax(token_logits, dim=2)
        offset_mapping = encoding['offset_mapping'][0].cpu().numpy()
        current_span = None
        for token_idx, (start, end) in enumerate(offset_mapping):
            # Special and padding tokens map to the empty (0, 0) offset; skip them.
            if start == end == 0:
                continue
            probs = token_probs[0, token_idx]
            max_prob, pred_label = torch.max(probs, dim=0)
            max_prob = max_prob.item()
            pred_label = pred_label.item()
            if max_prob > class_thresholds[pred_label]:
                token_text = window_text[start:end]
                # token_text is a raw character slice and never carries the '##'
                # continuation marker, so read it from the tokenizer's token string.
                token_str = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0][token_idx].item())
                if token_str.startswith('##'):
                    # Subword continuation: extend the current span without a space.
                    if current_span and current_span['label'] == label_map[pred_label]:
                        current_span['text'] += token_text
                        current_span['position'] = (current_span['position'][0], i+end)
                        current_span['confidence'] = max(current_span['confidence'], max_prob)
                    continue
                # Same-class tokens at most 2 characters apart continue the current span.
                if (current_span and
                        current_span['label'] == label_map[pred_label] and
                        (i+start - current_span['position'][1]) <= 2):
                    current_span['text'] += ' ' + token_text
                    current_span['position'] = (current_span['position'][0], i+end)
                    current_span['confidence'] = max(current_span['confidence'], max_prob)
                else:
                    if current_span:
                        results.append(current_span)
                    current_span = {
                        'text': token_text,
                        'label': label_map[pred_label],
                        'confidence': max_prob,
                        'position': (i+start, i+end)
                    }
            else:
                if current_span:
                    results.append(current_span)
                current_span = None
        # Flush the span still open when the window ends.
        if current_span:
            results.append(current_span)
    # Keep spans that are either long enough or predicted with very high confidence.
    filtered_results = []
    for span in results:
        clean_text = span['text'].strip()
        if len(clean_text.split()) >= min_span_length or span['confidence'] > 0.9:
            span['text'] = clean_text
            filtered_results.append(span)
    # Merge same-label spans that are adjacent or overlapping; overlapping
    # windows can detect the same region twice.
    merged_results = []
    filtered_results.sort(key=lambda x: x['position'][0])
    for span in filtered_results:
        if not merged_results:
            merged_results.append(span)
        else:
            last = merged_results[-1]
            if (span['label'] == last['label'] and
                    span['position'][0] <= last['position'][1] + 5):
                merged_text = last['text'] + ' ' + span['text']
                merged_results[-1] = {
                    'text': merged_text,
                    'label': span['label'],
                    'confidence': max(last['confidence'], span['confidence']),
                    'position': (last['position'][0], span['position'][1])
                }
            else:
                merged_results.append(span)
    # Cap each span at 15 whitespace-separated tokens to keep the output short.
    for span in merged_results:
        tokens = span['text'].split()
        if len(tokens) > 15:
            span['text'] = ' '.join(tokens[:15])
    return merged_results

def format_results(spans):
    """Group spans by label and sort each group by confidence, highest first."""
    formatted = {}
    for span in spans:
        label = span['label']
        if label not in formatted:
            formatted[label] = []
        formatted[label].append(span)
    for label in formatted:
        formatted[label].sort(key=lambda x: x['confidence'], reverse=True)
    return formatted

def format_final_output(formatted_results):
    """Flatten the grouped spans into a single labelled string."""
    final_output = []
    for label, items in formatted_results.items():
        # Keep only the best match for personal information, the top three otherwise.
        top_n = 1 if label == 'personal_information' else 3
        label_upper = label.upper()
        for item in items[:top_n]:
            final_output.append(f"{label_upper}: {item['text']}")
    return " ".join(final_output)