from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import faiss
import torch
import numpy as np
def log(message):
    print(f"✅ {message}")
# ✅ Load datasets
datasets = {
    "sales": load_dataset("goendalf666/sales-conversations"),
    "blended": load_dataset("blended_skill_talk"),
    "dialog": load_dataset("daily_dialog"),
    "multiwoz": load_dataset("multi_woz_v22"),
}
# ✅ Load MiniLM model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2" # Model for embeddings
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def embed_text(texts, batch_size=64):
    # Embed in batches so large datasets are not tokenized and pushed through the model in one huge call
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        inputs = tokenizer(texts[i:i + batch_size], padding=True, truncation=True, return_tensors="pt", max_length=512)
        with torch.no_grad():
            # Mean-pool the last hidden state to get one vector per text
            all_embeddings.append(model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy())
    return np.vstack(all_embeddings)
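# Quick sanity check (illustrative, not part of the original script): all-MiniLM-L6-v2
# produces 384-dimensional vectors, so embedding two strings should give shape (2, 384).
sample_vectors = embed_text(["Hello there!", "How can I help you today?"])
print(f"✅ Sanity check embedding shape: {sample_vectors.shape}")  # expected (2, 384)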
# ✅ Extract the raw texts from each dataset
def create_embeddings(dataset_name, dataset):
    print(f"📥 Extracting texts from {dataset_name}...")
    if dataset_name == "sales":
        # goendalf666/sales-conversations: each row is a dict of utterance columns, some of which may be empty
        texts = [" ".join(v for v in row.values() if v) for row in dataset["train"]]
    elif dataset_name == "persona":
        # AlekseyKorshuk/persona-chat (not loaded above; kept in case it is added to `datasets`)
        texts = [" ".join(utterance["candidates"]) for utterance in dataset["train"]["utterances"]]
    elif dataset_name == "blended":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
    elif dataset_name == "dialog":
        texts = [" ".join(row["dialog"]) for row in dataset["train"]]
    elif dataset_name == "multiwoz":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
    else:
        print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
        texts = []
    # ✅ Verify dataset extraction
    if len(texts) == 0:
        print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
    else:
        print(f"✅ Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")
    return texts
# ✅ Save embeddings to a FAISS index on disk (defined before it is used below)
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    print("Saving embeddings to FAISS...")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # all-MiniLM-L6-v2 embeddings are 384-dimensional
    index.add(np.array(embeddings).astype(np.float32))
    faiss.write_index(index, index_name)  # Save FAISS index to file
    return index

# ✅ Embed each dataset and store it in FAISS (one index file per dataset,
# otherwise every dataset would overwrite the same file)
saved_indexes = {}
for name, dataset in datasets.items():
    texts = create_embeddings(name, dataset)
    if len(texts) > 0:  # ✅ Only embed if texts exist
        embeddings = embed_text(texts)
        print(f"✅ Generated embeddings shape: {embeddings.shape}")
        index_file = f"{name}_embeddings"
        save_embeddings_to_faiss(embeddings, index_name=index_file)
        saved_indexes[name] = index_file
        print(f"✅ Embeddings for {name} saved to FAISS.")
    else:
        print(f"⚠️ Skipping embedding for {name} (No valid texts).")
# ✅ Check the FAISS indexes after saving
for name, index_file in saved_indexes.items():
    index = faiss.read_index(index_file)  # Load the index back from disk
    print(f"🔍 FAISS index for {name} contains {index.ntotal} vectors.")  # Check how many embeddings were stored