import os |
import torch |
from transformers import AutoTokenizer, PreTrainedTokenizerFast, AutoConfig |
from torch.nn import functional as F |
import numpy as np |
from sklearn.metrics.pairwise import cosine_similarity |
class SentenceEmbeddingModel(torch.nn.Module):
    """Transformer backbone + attention-masked mean pooling + L2 normalisation (inference).

    The backbone is built from ``config`` with ``AutoModel.from_config``, which
    does not load pretrained weights, so callers are expected to restore weights
    afterwards (e.g. with ``load_state_dict``).
    """

    def __init__(self, config):
        super().__init__()
        # Imported lazily so the module can be imported without building a model.
        from transformers import AutoModel

        self.transformer = AutoModel.from_config(config)
        # Informational only: forward() always applies mean pooling.
        self.pooling_mode = 'mean'

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` along dim 1, counting only positions where ``mask`` is non-zero."""
        weights = mask.unsqueeze(-1).expand(hidden.size()).float()
        totals = (hidden * weights).sum(dim=1)
        # Clamp so a row with no selected tokens divides by a tiny number instead of zero.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return totals / counts

    def forward(self, input_ids, attention_mask):
        """Return one unit-length (L2) embedding per row of ``input_ids``.

        Args:
            input_ids: token ids forwarded to the backbone.
            attention_mask: mask forwarded to the backbone; its non-zero entries
                also select which token states enter the mean.
        """
        hidden = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
        return F.normalize(self._masked_mean(hidden, attention_mask), p=2, dim=1)
class SentenceEmbedder:
    """Load a trained sentence-embedding checkpoint and expose encode / similarity / search.

    ``model_path`` is a directory containing tokenizer files (see
    ``_load_tokenizer`` for the formats tried), a config readable by
    ``AutoConfig.from_pretrained`` and an ``embedding_model.pt`` checkpoint.
    """

    # Special tokens assumed by the non-AutoTokenizer fallbacks.
    _SPECIAL_TOKENS = dict(
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
    )

    def __init__(self, model_path):
        """Load tokenizer, config and weights from ``model_path``.

        Raises:
            ValueError: if no tokenizer could be loaded.
            RuntimeError: if the config or checkpoint cannot be loaded (or, from
                ``load_state_dict``, if the checkpoint keys do not match the model).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = self._load_tokenizer(model_path)
        config = self._load_config(model_path)
        model_info = self._load_checkpoint(model_path)

        self.model = SentenceEmbeddingModel(config)
        # The checkpoint is either a dict wrapping 'model_state_dict' or a bare state dict.
        if 'model_state_dict' in model_info:
            self.model.load_state_dict(model_info['model_state_dict'])
        else:
            self.model.load_state_dict(model_info)
        self.model.to(self.device)
        self.model.eval()

        self.embedding_dim = model_info.get('embedding_dim', config.hidden_size)
        print(f"Model loaded successfully with embedding dimension: {self.embedding_dim}")

    # ------------------------------------------------------------------ loading

    @classmethod
    def _load_tokenizer(cls, model_path):
        """Return a tokenizer, trying AutoTokenizer, SentencePiece, tokenizer.json, then a file scan.

        Raises:
            ValueError: if every strategy fails.
        """
        try:
            print(f"Trying AutoTokenizer from {model_path}")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            print(f"Successfully loaded tokenizer with AutoTokenizer, vocab size: {tokenizer.vocab_size}")
            return tokenizer
        except Exception as e:
            print(f"AutoTokenizer failed: {e}")

        spm_model_path = os.path.join(model_path, "sentencepiece.bpe.model")
        if os.path.exists(spm_model_path):
            try:
                print(f"Trying to load SentencePiece model from {spm_model_path}")
                tokenizer = cls._load_sentencepiece_tokenizer(spm_model_path)
                print(f"Successfully loaded SentencePiece tokenizer, vocab size: {tokenizer.vocab_size}")
                return tokenizer
            except Exception as e:
                print(f"SentencePiece loading failed: {e}")

        tokenizer_json_path = os.path.join(model_path, "tokenizer.json")
        if os.path.exists(tokenizer_json_path):
            try:
                print(f"Trying to load tokenizer from {tokenizer_json_path}")
                tokenizer = cls._load_fast_tokenizer(tokenizer_json_path)
                print(f"Successfully loaded tokenizer with PreTrainedTokenizerFast, vocab size: {tokenizer.vocab_size}")
                return tokenizer
            except Exception as e:
                print(f"PreTrainedTokenizerFast failed: {e}")

        tokenizer = cls._load_tokenizer_from_candidates(model_path)
        if tokenizer is None:
            raise ValueError("Could not load tokenizer from any available source. Please check the model directory.")
        return tokenizer

    @classmethod
    def _load_tokenizer_from_candidates(cls, model_path):
        """Scan ``model_path`` for tokenizer-looking files; return the first that loads, else None."""
        keywords = ('token', 'vocab', 'sentencepiece', 'bpe')
        print("Searching for any tokenizer files in the directory...")
        try:
            # Sorted so the fallback picks files in a reproducible order.
            candidate_files = [
                os.path.join(model_path, name)
                for name in sorted(os.listdir(model_path))
                if any(keyword in name.lower() for keyword in keywords)
                and os.path.isfile(os.path.join(model_path, name))
            ]
        except OSError as e:
            print(f"Error searching for tokenizer files: {e}")
            return None

        if candidate_files:
            print(f"Found potential tokenizer files: {candidate_files}")
        for file_path in candidate_files:
            if file_path.endswith('.json'):
                loader, label = cls._load_fast_tokenizer, "tokenizer"
            elif file_path.endswith('.model'):
                # Previously this branch loaded a bare SentencePieceProcessor and never
                # assigned a tokenizer; wrap it properly instead.
                loader, label = cls._load_sentencepiece_tokenizer, "SentencePiece"
            else:
                continue
            try:
                tokenizer = loader(file_path)
            except Exception as file_e:
                print(f"Failed to load {file_path}: {file_e}")
                continue
            print(f"Successfully loaded {label} from {file_path}")
            return tokenizer
        return None

    @classmethod
    def _load_fast_tokenizer(cls, tokenizer_file):
        """Build a ``PreTrainedTokenizerFast`` from a ``tokenizer.json``-style file."""
        return PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_file,
            model_max_length=512,
            **cls._SPECIAL_TOKENS,
        )

    @classmethod
    def _load_sentencepiece_tokenizer(cls, spm_model_path):
        """Wrap a raw SentencePiece ``.model`` file in a slow ``PreTrainedTokenizer``."""
        # Only needed on this fallback path, so imported lazily.
        import sentencepiece as spm
        from transformers import PreTrainedTokenizer

        sp_model = spm.SentencePieceProcessor()
        sp_model.Load(spm_model_path)
        special_tokens = cls._SPECIAL_TOKENS

        class SentencePieceTokenizer(PreTrainedTokenizer):
            def __init__(self, sp_model):
                # Must be set BEFORE super().__init__(): recent transformers releases
                # register special tokens inside PreTrainedTokenizer.__init__, which
                # calls get_vocab() and therefore needs sp_model already.
                self.sp_model = sp_model
                super().__init__(**special_tokens)

            @property
            def vocab_size(self):
                return self.sp_model.GetPieceSize()

            def get_vocab(self):
                vocab = {self.sp_model.IdToPiece(i): i for i in range(self.vocab_size)}
                vocab.update(self.added_tokens_encoder)
                return vocab

            def _tokenize(self, text):
                return self.sp_model.EncodeAsPieces(text)

            def _convert_token_to_id(self, token):
                return self.sp_model.PieceToId(token)

            def _convert_id_to_token(self, index):
                return self.sp_model.IdToPiece(index)

        return SentencePieceTokenizer(sp_model)

    @staticmethod
    def _load_config(model_path):
        """Load the model config; raise RuntimeError (chained) on failure."""
        try:
            print(f"Loading config from {model_path}")
            config = AutoConfig.from_pretrained(model_path)
            print(f"Config loaded with hidden_size={config.hidden_size}")
        except Exception as e:
            print(f"Error loading config: {e}")
            raise RuntimeError("Could not load model configuration") from e
        return config

    def _load_checkpoint(self, model_path):
        """Load ``embedding_model.pt`` onto ``self.device``.

        Raises:
            RuntimeError: if the file cannot be loaded or the result has no ``keys()``.
        """
        model_path_pt = os.path.join(model_path, 'embedding_model.pt')
        try:
            try:
                # SECURITY: weights_only=False unpickles arbitrary objects; only load
                # checkpoints from a trusted source.
                model_info = torch.load(model_path_pt, map_location=self.device, weights_only=False)
            except TypeError:
                # Older torch releases do not accept the weights_only keyword.
                model_info = torch.load(model_path_pt, map_location=self.device)
            print(f"Model info keys: {list(model_info.keys())}")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            raise RuntimeError(f"Could not load model weights: {e}") from e
        return model_info

    # ---------------------------------------------------------------- inference

    @staticmethod
    def _l2_normalize(x):
        """Scale each row of ``x`` to unit L2 norm; all-zero rows stay zero (as sklearn does)."""
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return x / norms

    def encode(self, sentences, batch_size=32, max_length=128):
        """Encode sentences to embeddings.

        Args:
            sentences: a single string or an iterable of strings.
            batch_size: sentences per forward pass; must be >= 1.
            max_length: tokenizer truncation length per sentence.

        Returns:
            ``np.ndarray`` with one row per sentence. Empty input yields an
            array of shape ``(0, self.embedding_dim)``.

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        sentences = [sentences] if isinstance(sentences, str) else list(sentences)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not sentences:
            # np.vstack([]) would raise; return an empty, correctly-shaped matrix instead.
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        all_embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            encoded_input = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                embeddings = self.model(encoded_input['input_ids'], encoded_input['attention_mask'])
            all_embeddings.append(embeddings.cpu().numpy())
        return np.vstack(all_embeddings)

    def compute_similarity(self, sentences1, sentences2=None):
        """Compute cosine similarity between sentences.

        Returns:
            If ``sentences2`` is None: the square matrix of pairwise similarities
            within ``sentences1``. Otherwise: a 1-D array whose i-th entry is the
            similarity of ``sentences1[i]`` and ``sentences2[i]``.

        Raises:
            ValueError: if both inputs are given and their lengths differ
                (previously the longer one was silently truncated).
        """
        embeddings1 = self._l2_normalize(self.encode(sentences1))
        if sentences2 is None:
            return embeddings1 @ embeddings1.T
        embeddings2 = self._l2_normalize(self.encode(sentences2))
        if len(embeddings1) != len(embeddings2):
            raise ValueError(
                f"sentences1 and sentences2 must have the same length "
                f"({len(embeddings1)} != {len(embeddings2)})"
            )
        # Row-wise dot product of unit vectors == per-pair cosine, without a Python loop.
        return np.sum(embeddings1 * embeddings2, axis=1)

    def search(self, query, documents, top_k=5):
        """Return up to ``top_k`` documents most similar to ``query``, best first.

        Each result is ``{'document': <document>, 'score': <float cosine>}``.
        Returns ``[]`` when ``documents`` is empty or ``top_k`` <= 0.
        """
        documents = list(documents)
        # Guard: slicing with [-0:] used to return *every* document for top_k=0.
        if top_k <= 0 or not documents:
            return []
        query_embedding = self._l2_normalize(self.encode([query]))[0]
        document_embeddings = self._l2_normalize(self.encode(documents))
        similarities = document_embeddings @ query_embedding
        top_indices = np.argsort(-similarities, kind="stable")[:top_k]
        return [
            {'document': documents[idx], 'score': float(similarities[idx])}
            for idx in top_indices
        ]
def _most_similar_pairs(sim_matrix, k):
    """Return up to ``k`` index pairs ``(i, j)`` with ``i < j``, highest similarity first.

    Only the strict upper triangle is considered, so each unordered pair of a
    symmetric matrix appears once and the diagonal is never reported.
    """
    rows, cols = np.triu_indices(sim_matrix.shape[0], k=1)
    order = np.argsort(-sim_matrix[rows, cols], kind="stable")[:k]
    return [(int(rows[idx]), int(cols[idx])) for idx in order]


def main(model_path="output/hindi-sentence-embeddings-from-scratch/final", mode="similarity"):
    """Demo: similarity report or semantic search over sample Hindi sentences.

    Args:
        model_path: directory passed to ``SentenceEmbedder``.
        mode: ``'similarity'`` prints the pairwise similarity matrix and the top
            pairs; ``'search'`` ranks ``document_corpus`` for a few sample queries.

    Raises:
        ValueError: for an unknown ``mode`` (checked before the model is loaded).
    """
    if mode not in ('similarity', 'search'):
        raise ValueError(f"Unknown mode: {mode!r} (expected 'similarity' or 'search')")

    model = SentenceEmbedder(model_path)
    sentences = [
        'मुझे हिंदी भाषा बहुत पसंद है।',
        'मैं हिंदी भाषा सीख रहा हूँ।',
        'भारत एक विशाल देश है।',
        'भारत में बहुत सारी भाषाएँ बोली जाती हैं।',
        'आज मौसम बहुत अच्छा है।',
        'कल बारिश होगी।',
        'दिल्ली भारत की राजधानी है।',
        'मुंबई भारत का आर्थिक केंद्र है।',
        'भारतीय खाना बहुत स्वादिष्ट होता है।',
        'मैं आज बाजार जाऊंगा।'
    ]
    document_corpus = [
        'हिंदी भारत की आधिकारिक भाषा है।',
        'भारत में अनेक भाषाएँ बोली जाती हैं।',
        'दिल्ली भारत की राजधानी है।',
        'मुंबई भारत का सबसे बड़ा शहर है।',
        'हिमालय पर्वत भारत के उत्तर में स्थित है।',
        'गंगा नदी भारत की सबसे पवित्र नदी है।',
        'भारतीय संस्कृति बहुत समृद्ध है।',
        'भारत में अनेक त्योहार मनाए जाते हैं।',
        'तमिल, तेलुगु, कन्नड़ और मलयालम दक्षिण भारत की प्रमुख भाषाएँ हैं।',
        'आम, अमरूद और केला भारत के लोकप्रिय फल हैं।',
        'भारत में विभिन्न धर्मों के लोग एक साथ रहते हैं।',
        'रामायण और महाभारत भारत के प्रसिद्ध महाकाव्य हैं।'
    ]

    if mode == 'similarity':
        print("Computing similarity matrix...")
        sim_matrix = model.compute_similarity(sentences)
        print("\nSentences:")
        for i, sentence in enumerate(sentences):
            print(f"[{i}] {sentence}")
        print("\nSimilarity matrix:")
        # Scoped print options so precision=2 doesn't leak into global NumPy state.
        with np.printoptions(precision=2):
            print(sim_matrix)
        print("\nMost similar sentence pairs:")
        # Upper triangle only: the old argmax loop printed every pair twice ((i, j) and (j, i)).
        for i, j in _most_similar_pairs(sim_matrix, 5):
            print(f"Similarity: {sim_matrix[i, j]:.4f}")
            print(f"Sentence 1: {sentences[i]}")
            print(f"Sentence 2: {sentences[j]}")
            print("---")
    else:
        # 'search' mode: previously document_corpus was built but never used.
        for query in (sentences[0], sentences[3], sentences[6]):
            print(f"\nQuery: {query}")
            for rank, hit in enumerate(model.search(query, document_corpus, top_k=3), start=1):
                print(f"{rank}. [{hit['score']:.4f}] {hit['document']}")
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not when it is imported.
    main()