import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer


def load_model(model_path, head_path):
    """Load the fine-tuned sentence encoder, its token-classification head, and tokenizer."""
    try:
        model = SentenceTransformer(model_path)
        # 5 output classes; indices follow label_map in predict_spans.
        classification_head = nn.Linear(model.get_sentence_embedding_dimension(), 5)
        classification_head.load_state_dict(torch.load(head_path, map_location=torch.device('cpu')))
        # The tokenizer is hard-coded; it must match the base model that
        # model_path was fine-tuned from (here, all-mpnet-base-v2).
        tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        device = torch.device('cpu')
        model.to(device)
        classification_head.to(device)
        return model, classification_head, tokenizer, device
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


def predict_spans(full_text, model, classification_head, tokenizer, device,
                  window_size=384, stride=256, min_span_length=3):
    """Slide a character window over full_text, classify each token, and
    assemble contiguous same-label tokens into labeled spans."""
    # Per-class confidence thresholds; keys follow label_map below.
    class_thresholds = {0: 0.8, 1: 0.7, 2: 0.75, 3: 0.7, 4: 0.8}
    label_map = {
        0: 'personal_information',
        1: 'skills',
        2: 'education',
        3: 'experience',
        4: 'certification',
    }
    results = []
    full_text = full_text.strip()

    # stride < window_size, so consecutive windows overlap; duplicate
    # detections from the overlap are collapsed by the merge pass below.
    for i in range(0, len(full_text), stride):
        window_text = full_text[i:i + window_size]
        # Note: window_size is a character count, but max_length is a token
        # count; a 384-character window stays well under 384 tokens, so
        # truncation should rarely trigger.
        encoding = tokenizer(
            window_text,
            max_length=window_size,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            model_output = model({
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask'],
            })
            token_embeddings = model_output['token_embeddings']
            token_logits = classification_head(token_embeddings)
            token_probs = torch.softmax(token_logits, dim=2)

        offset_mapping = encoding['offset_mapping'][0].cpu().numpy()
        # Offsets slice the raw window text, which never contains the '##'
        # WordPiece continuation marker, so fetch the token strings as well
        # and test the marker on those.
        token_strings = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0].tolist())

        current_span = None
        for token_idx, (start, end) in enumerate(offset_mapping):
            # Special and padding tokens carry the empty (0, 0) offset.
            if start == end == 0:
                continue
            probs = token_probs[0, token_idx]
            max_prob, pred_label = torch.max(probs, dim=0)
            max_prob = max_prob.item()
            pred_label = pred_label.item()

            if max_prob > class_thresholds[pred_label]:
                token_text = window_text[start:end]
                # Subword continuation: glue onto the current span without a space.
                if token_strings[token_idx].startswith('##'):
                    if current_span and current_span['label'] == label_map[pred_label]:
                        current_span['text'] += token_text
                        current_span['position'] = (current_span['position'][0], i + end)
                        current_span['confidence'] = max(current_span['confidence'], max_prob)
                    continue
                # Same label and at most 2 characters away: extend the span.
                if (current_span
                        and current_span['label'] == label_map[pred_label]
                        and (i + start - current_span['position'][1]) <= 2):
                    current_span['text'] += ' ' + token_text
                    current_span['position'] = (current_span['position'][0], i + end)
                    current_span['confidence'] = max(current_span['confidence'], max_prob)
                else:
                    if current_span:
                        results.append(current_span)
                    current_span = {
                        'text': token_text,
                        'label': label_map[pred_label],
                        'confidence': max_prob,
                        'position': (i + start, i + end),  # offsets into full_text
                    }
            else:
                if current_span:
                    results.append(current_span)
                current_span = None
        if current_span:
            results.append(current_span)

    # Drop short spans unless they are very high confidence.
    filtered_results = []
    for span in results:
        clean_text = span['text'].strip()
        if len(clean_text.split()) >= min_span_length or span['confidence'] > 0.9:
            span['text'] = clean_text
            filtered_results.append(span)

    # Merge same-label spans that are adjacent (within 5 characters); this
    # also collapses duplicates produced by the overlapping windows.
    merged_results = []
    filtered_results.sort(key=lambda x: x['position'][0])
    for span in filtered_results:
        if not merged_results:
            merged_results.append(span)
            continue
        last = merged_results[-1]
        if span['label'] == last['label'] and span['position'][0] <= last['position'][1] + 5:
            merged_results[-1] = {
                'text': last['text'] + ' ' + span['text'],
                'label': span['label'],
                'confidence': max(last['confidence'], span['confidence']),
                'position': (last['position'][0], span['position'][1]),
            }
        else:
            merged_results.append(span)

    # Cap each span at 15 whitespace-delimited words.
    for span in merged_results:
        tokens = span['text'].split()
        if len(tokens) > 15:
            span['text'] = ' '.join(tokens[:15])
    return merged_results
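# Illustrative shape of a predict_spans result (hypothetical values; actual
# text, confidences, and offsets depend on the fine-tuned checkpoint):
#
# [
#     {'text': 'Python SQL Docker', 'label': 'skills',
#      'confidence': 0.91, 'position': (412, 429)},
#     {'text': 'BSc Computer Science University of Leeds', 'label': 'education',
#      'confidence': 0.87, 'position': (610, 651)},
# ]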
def format_results(spans):
    """Group spans by label and sort each group by confidence, descending."""
    formatted = {}
    for span in spans:
        formatted.setdefault(span['label'], []).append(span)
    for label in formatted:
        formatted[label].sort(key=lambda x: x['confidence'], reverse=True)
    return formatted


def format_final_output(formatted_results):
    """Render the top spans per label as a single 'LABEL: text' string."""
    final_output = []
    for label, items in formatted_results.items():
        # Keep only the single best personal-information span; up to 3 otherwise.
        top_n = 1 if label == 'personal_information' else 3
        label_upper = label.upper()
        for item in items[:top_n]:
            final_output.append(f"{label_upper}: {item['text']}")
    return " ".join(final_output)
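
# Minimal end-to-end usage sketch. 'resume-encoder/' and 'head.pt' are
# hypothetical paths; substitute the actual fine-tuned model directory and
# the saved state_dict of the classification head.
if __name__ == '__main__':
    model, classification_head, tokenizer, device = load_model(
        'resume-encoder/', 'head.pt'
    )
    sample_text = (
        "Jane Doe, jane.doe@example.com. Skills: Python, SQL, Docker. "
        "BSc Computer Science, University of Leeds, 2019."
    )
    spans = predict_spans(sample_text, model, classification_head, tokenizer, device)
    grouped = format_results(spans)
    print(format_final_output(grouped))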