from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import faiss
import torch
import numpy as np


def log(message):
    print(f"✅ {message}")


# ✅ Load datasets (keys match the names checked in create_embeddings below)
datasets = {
    "goendalf666/sales-conversations": load_dataset("goendalf666/sales-conversations"),
    "blended_skill_talk": load_dataset("blended_skill_talk"),
    "daily_dialog": load_dataset("daily_dialog"),
    "multi_woz_v22": load_dataset("multi_woz_v22"),
}

# ✅ Load MiniLM model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Model for embeddings
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


def embed_text(texts, batch_size=64):
    """Mean-pool the last hidden state to get one vector per text, processed in batches."""
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512)
        with torch.no_grad():
            batch_embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
        all_embeddings.append(batch_embeddings)
    return np.vstack(all_embeddings)


# ✅ Extract the texts from each dataset
def create_embeddings(dataset_name, dataset):
    print(f"📥 Extracting texts from {dataset_name}...")
    if dataset_name == "goendalf666/sales-conversations":
        # Each row is a dict of numbered turn columns; skip empty columns
        texts = [" ".join(str(v) for v in row.values() if v) for row in dataset["train"]]
    elif dataset_name == "AlekseyKorshuk/persona-chat":
        # Only used if the persona-chat dataset is added to the `datasets` dict above
        texts = [" ".join(c for u in row["utterances"] for c in u["candidates"]) for row in dataset["train"]]
    elif dataset_name == "blended_skill_talk":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
    elif dataset_name == "daily_dialog":
        texts = [" ".join(row["dialog"]) for row in dataset["train"]]
    elif dataset_name == "multi_woz_v22":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
    else:
        print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
        texts = []

    # ✅ Verify dataset extraction
    if len(texts) == 0:
        print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
    else:
        print(f"✅ Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")
    return texts


# ✅ Save embeddings to a FAISS index file (defined before it is used below)
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings.index"):
    print("Saving embeddings to FAISS...")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # Dimension taken from the embeddings (384 for MiniLM-L6-v2)
    index.add(np.array(embeddings).astype(np.float32))
    faiss.write_index(index, index_name)  # Save FAISS index to file
    return index


# ✅ Embed each dataset and store it in its own FAISS index
for name, dataset in datasets.items():
    texts = create_embeddings(name, dataset)
    if len(texts) > 0:  # ✅ Only embed if texts exist
        embeddings = embed_text(texts)
        print(f"✅ Generated embeddings shape: {embeddings.shape}")
        index_file = f"{name.replace('/', '_')}.index"  # One index file per dataset
        index = save_embeddings_to_faiss(embeddings, index_file)
        print(f"✅ Embeddings for {name} saved to FAISS ({index_file}).")
    else:
        print(f"⚠️ Skipping embedding for {name} (No valid texts).")

# ✅ Check the FAISS indexes after saving
for name in datasets:
    index_file = f"{name.replace('/', '_')}.index"
    index = faiss.read_index(index_file)  # Load the index
    print(f"📊 FAISS index for {name} contains {index.ntotal} vectors.")  # Check how many embeddings were stored
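
# --- Example (not part of the original script): querying one of the saved indexes ---
# A minimal sketch of how the stored vectors could be searched after the loop above
# has run. The index file name and the query string are assumptions for illustration;
# to recover the matching texts, keep the `texts` list for each dataset alongside its index.
query = "Can you recommend a restaurant in the city centre?"
query_embedding = embed_text([query]).astype(np.float32)

index = faiss.read_index("multi_woz_v22.index")  # Assumes the loop above wrote this file
distances, indices = index.search(query_embedding, 5)  # 5 nearest neighbours by L2 distance
print(f"🔎 Nearest neighbours: {indices[0]}, distances: {distances[0]}")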