from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import faiss
import torch
import numpy as np

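# Pipeline overview: load several conversational datasets, embed each
# conversation with a MiniLM sentence encoder, and persist one FAISS
# index per dataset for similarity search.
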
def log(message):
    print(f"✅ {message}")


# Conversational datasets to index, keyed by short names that also serve
# as the FAISS index filenames.
datasets = {
    "sales": load_dataset("goendalf666/sales-conversations"),
    "blended": load_dataset("blended_skill_talk"),
    "dialog": load_dataset("daily_dialog"),
    "multiwoz": load_dataset("multi_woz_v22"),
}

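# Note: load_dataset downloads and caches each corpus from the Hugging Face
# Hub on first run. Depending on the installed `datasets` version, the
# script-based corpora (daily_dialog, multi_woz_v22) may also need
# trust_remote_code=True.
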
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

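# all-MiniLM-L6-v2 produces 384-dimensional embeddings; the index dimension
# is read from embeddings.shape[1] below, so the model can be swapped
# without further changes.
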
def embed_text(texts, batch_size=32):
    # Embed in batches so large datasets do not exhaust memory in one call.
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        inputs = tokenizer(texts[i:i + batch_size], padding=True, truncation=True,
                           return_tensors="pt", max_length=512)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        # Mask-aware mean pooling: average over real tokens, not padding.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        all_embeddings.append(((hidden * mask).sum(dim=1) / mask.sum(dim=1)).cpu().numpy())
    return np.concatenate(all_embeddings, axis=0)

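# The flat index below ranks by Euclidean (L2) distance. If cosine similarity
# is preferred, a common variant (a sketch, not what this script does) is to
# L2-normalize the vectors and use an inner-product index:
#   faiss.normalize_L2(embeddings)
#   index = faiss.IndexFlatIP(embeddings.shape[1])
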
def extract_texts(dataset_name, dataset):
    """Flatten each dataset's conversation structure into one string per row."""
    print(f"📥 Extracting texts from {dataset_name}...")

    # Branch on the short keys used in the `datasets` dict above.
    if dataset_name == "sales":
        texts = [" ".join(row.values()) for row in dataset["train"]]
    elif dataset_name == "persona":
        # Only reached if AlekseyKorshuk/persona-chat is added to `datasets`.
        texts = [" ".join(utterance["candidates"]) for utterance in dataset["train"]["utterances"]]
    elif dataset_name == "blended":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
    elif dataset_name == "dialog":
        texts = [" ".join(row["dialog"]) for row in dataset["train"]]
    elif dataset_name == "multiwoz":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
    else:
        print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
        texts = []

    if len(texts) == 0:
        print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
    else:
        print(f"✅ Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")

    return texts


def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    # Defined before the loop below that calls it. A flat index performs
    # exact (brute-force) L2 search.
    print(f"Saving embeddings to {index_name}...")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(np.array(embeddings).astype(np.float32))
    faiss.write_index(index, index_name)
    return index

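# For corpora that outgrow exact search, an approximate index is the usual
# next step (a sketch; nlist=100 is illustrative, not tuned):
#   quantizer = faiss.IndexFlatL2(dim)
#   index = faiss.IndexIVFFlat(quantizer, dim, 100)
#   index.train(vectors)  # IVF indexes must be trained before .add()
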
for name, dataset in datasets.items():
    texts = extract_texts(name, dataset)

    if len(texts) > 0:
        embeddings = embed_text(texts)
        print(f"✅ Generated embeddings shape: {embeddings.shape}")

        # One index file per dataset; a single shared filename would be
        # overwritten on every iteration.
        index = save_embeddings_to_faiss(embeddings, index_name=f"{name}.index")
        print(f"✅ Embeddings for {name} saved to FAISS.")
    else:
        print(f"⚠️ Skipping embedding for {name} (No valid texts).")

# Sanity check: reload each index from disk and report its size.
for name in datasets:
    index = faiss.read_index(f"{name}.index")
    print(f"🔍 FAISS index '{name}' contains {index.ntotal} vectors.")