# sms_agent/embeddings.py
import os
import json
import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

# ✅ Set up directories
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)  # Ensure the data directory exists


def log(message):
    """Print a status message with a checkmark prefix."""
    print(f"✅ {message}")

# ✅ Load datasets from stored JSON files
def load_local_dataset(dataset_name):
    """Load a dataset from data/<dataset_name>.json, returning None if the file is missing."""
    file_path = os.path.join(DATA_DIR, f"{dataset_name}.json")
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)
        log(f"📂 Loaded {dataset_name} from {file_path}")
        return data
    else:
        log(f"❌ ERROR: {dataset_name} file not found!")
        return None

# ✅ Load all datasets from storage
datasets = {
    "sales": load_local_dataset("sales"),
    "blended": load_local_dataset("blended"),
    "dialog": load_local_dataset("dialog"),
    "multiwoz": load_local_dataset("multiwoz"),
}

# ✅ Load MiniLM model for embeddings
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


def embed_text(texts):
    """Generate embeddings for a batch of texts."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        # Simple mean pooling over all token positions of the last hidden state
        embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
    return embeddings
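

# Optional alternative (a sketch only, not used by the pipeline below): all-MiniLM-L6-v2
# is usually pooled with an attention-mask-weighted mean so padding tokens do not dilute
# the embedding. The helper name `embed_text_masked` is hypothetical, added for illustration.
def embed_text_masked(texts):
    """Mean-pool the last hidden state while ignoring padding positions via the attention mask."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state           # (batch, seq_len, dim)
    mask = inputs["attention_mask"].unsqueeze(-1).float()    # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                      # zero out padding, then sum
    counts = mask.sum(dim=1).clamp(min=1e-9)                 # number of real tokens per row
    return (summed / counts).cpu().numpy()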


# ✅ Batch processing function
def create_embeddings(dataset_name, dataset, batch_size=100):
    """Extract texts, embed them in batches, and log progress."""
    log(f"📥 Creating embeddings for {dataset_name}...")

    if dataset_name == "sales":
        texts = [" ".join(row.values()) for row in dataset]
    elif dataset_name == "blended":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset]
    elif dataset_name == "dialog":
        texts = [" ".join(row["dialog"]) for row in dataset]
    elif dataset_name == "multiwoz":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset]
    else:
        log(f"⚠️ Warning: Dataset {dataset_name} format unknown!")
        texts = []

    log(f"✅ Extracted {len(texts)} texts from {dataset_name}.")

    if not texts:
        # Guard against np.vstack on an empty list below
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    # Process in batches
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        batch_embeddings = embed_text(batch)
        all_embeddings.append(batch_embeddings)

        # ✅ Log progress
        log(f"🚀 Processed {i + len(batch)}/{len(texts)} embeddings for {dataset_name}...")

    # Convert the list of numpy arrays into a single numpy array
    all_embeddings = np.vstack(all_embeddings)
    return all_embeddings


# ✅ Save embeddings to a FAISS index with a unique filename
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    """Append embeddings to an existing FAISS index, or create a new flat L2 index."""
    index_file = f"{index_name}.faiss"

    # ✅ Append to a previous FAISS index if one exists, otherwise create a new one
    if os.path.exists(index_file):
        log("🔄 Loading existing FAISS index to append...")
        index = faiss.read_index(index_file)
    else:
        index = faiss.IndexFlatL2(embeddings.shape[1])

    index.add(embeddings.astype(np.float32))
    faiss.write_index(index, index_file)  # ✅ Save FAISS index
    log(f"✅ Saved FAISS index: {index_file}")
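

# Optional variant (a sketch only, not used by the pipeline below): for cosine-similarity
# search one could L2-normalize the vectors and store them in an inner-product index
# instead of IndexFlatL2. The helper name `save_embeddings_cosine` is hypothetical.
def save_embeddings_cosine(embeddings, index_name="my_embeddings_cosine"):
    """Save L2-normalized embeddings to an inner-product FAISS index (cosine similarity)."""
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    faiss.normalize_L2(vectors)                   # normalize rows in place
    index = faiss.IndexFlatIP(vectors.shape[1])   # inner product == cosine on unit vectors
    index.add(vectors)
    faiss.write_index(index, f"{index_name}.faiss")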


# ✅ Run the embedding pipeline for every dataset
for name, dataset in datasets.items():
    if dataset:  # Skip datasets that failed to load
        embeddings = create_embeddings(name, dataset, batch_size=100)
        save_embeddings_to_faiss(embeddings, index_name=name)
        log(f"✅ Embeddings for {name} saved to FAISS.")
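

# A minimal retrieval sketch, assuming an index such as sales.faiss was written above and
# that queries are embedded with the same embed_text used for indexing. The example query
# string, the filename, and the choice of k=5 are illustrative only.
example_index_file = "sales.faiss"
if os.path.exists(example_index_file):
    index = faiss.read_index(example_index_file)
    query_vec = embed_text(["Do you offer a discount for annual plans?"]).astype(np.float32)
    distances, ids = index.search(query_vec, 5)  # top-5 nearest rows by L2 distance
    log(f"🔎 Example search over {example_index_file}: ids={ids[0].tolist()}")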