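"""Embed several local JSON dialogue datasets (sales, blended, dialog, multiwoz) with a
MiniLM sentence encoder and store the vectors in one FAISS index file per dataset."""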
import os
import json

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)


def log(message):
    """Print a progress message to stdout."""
    print(message)


def load_local_dataset(dataset_name):
    """Load a dataset stored as JSON under DATA_DIR; return None if the file is missing."""
    file_path = os.path.join(DATA_DIR, f"{dataset_name}.json")

    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)
        log(f"Loaded {dataset_name} from {file_path}")
        return data
    else:
        log(f"ERROR: {dataset_name} file not found at {file_path}!")
        return None


datasets = {
    "sales": load_local_dataset("sales"),
    "blended": load_local_dataset("blended"),
    "dialog": load_local_dataset("dialog"),
    "multiwoz": load_local_dataset("multiwoz"),
}
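
# Expected JSON layouts (inferred from the field access in create_embeddings below; the
# exact schemas of the local files are an assumption):
#   sales    -> list of flat dicts; all values are concatenated into one text
#   blended  -> list of dicts with "free_messages" and "guided_messages" lists
#   dialog   -> list of dicts with a "dialog" list of utterances
#   multiwoz -> list of dicts where row["turns"]["utterance"] is a list of utterances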


# Small sentence-embedding model; loads on CPU by default.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


def embed_text(texts):
    """Embed a batch of texts with mean pooling over the model's last hidden state."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean-pool only over real tokens so padding positions do not skew the embedding.
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).cpu().numpy()
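
# Example (illustrative only): embed_text(["hello there", "general kenobi"]) returns a
# (2, 384) float32 array, since all-MiniLM-L6-v2 has a 384-dimensional hidden state.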


def create_embeddings(dataset_name, dataset, batch_size=100):
    """Extract texts for a known dataset format, embed them in batches, and log progress."""
    log(f"Creating embeddings for {dataset_name}...")

    # Each dataset stores its utterances under a different field layout.
    if dataset_name == "sales":
        texts = [" ".join(str(value) for value in row.values()) for row in dataset]
    elif dataset_name == "blended":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset]
    elif dataset_name == "dialog":
        texts = [" ".join(row["dialog"]) for row in dataset]
    elif dataset_name == "multiwoz":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset]
    else:
        log(f"Warning: dataset {dataset_name} has an unknown format!")
        texts = []

    log(f"Extracted {len(texts)} texts from {dataset_name}.")

    # Guard against an empty/unknown dataset so np.vstack below never sees an empty list.
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        all_embeddings.append(embed_text(batch))
        log(f"Processed {i + len(batch)}/{len(texts)} embeddings for {dataset_name}...")

    return np.vstack(all_embeddings)


def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    """Append embeddings to an existing FAISS index on disk, or create a new one.

    IndexFlatL2 performs exact (brute-force) L2 search.
    """
    index_file = f"{index_name}.faiss"
    vectors = np.asarray(embeddings, dtype=np.float32)

    if os.path.exists(index_file):
        log("Loading existing FAISS index to append...")
        index = faiss.read_index(index_file)
    else:
        index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    faiss.write_index(index, index_file)
    log(f"Saved FAISS index: {index_file}")


# Embed each dataset that loaded successfully and persist it to its own FAISS index.
for name, dataset in datasets.items():
    if dataset:
        embeddings = create_embeddings(name, dataset, batch_size=100)
        save_embeddings_to_faiss(embeddings, index_name=name)
        log(f"Embeddings for {name} saved to FAISS.")
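
# Illustrative usage sketch (not part of the pipeline above): querying one of the saved
# indexes. The index name "sales", the query text, and k=5 are assumptions for the example.
#
#     query = embed_text(["looking for pricing information"]).astype(np.float32)
#     index = faiss.read_index("sales.faiss")
#     distances, ids = index.search(query, 5)  # the 5 nearest stored vectors by L2 distance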