import os
import json
import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

# ✅ Set up directories
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)  # Ensure data directory exists

def log(message):
    print(f"βœ… {message}")

# ✅ Load datasets from stored JSON files
def load_local_dataset(dataset_name):
    file_path = os.path.join(DATA_DIR, f"{dataset_name}.json")
    
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)
        log(f"πŸ“‚ Loaded {dataset_name} from {file_path}")
        return data
    else:
        log(f"❌ ERROR: {dataset_name} file not found!")
        return None

# ✅ Load all datasets from storage
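# Each JSON file is expected to hold a list of records in the shape that
# create_embeddings() below handles (e.g. "dialog" rows carry a list of
# utterances, "multiwoz" rows a turns/utterance structure).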
datasets = {
    "sales": load_local_dataset("sales"),
    "blended": load_local_dataset("blended"),
    "dialog": load_local_dataset("dialog"),
    "multiwoz": load_local_dataset("multiwoz"),
}

# ✅ Load MiniLM model for embeddings
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
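# all-MiniLM-L6-v2 yields 384-dimensional sentence embeddings once token vectors are mean-pooled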

def embed_text(texts):
    """Generate mean-pooled embeddings for a batch of texts."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state
    # Mean-pool over real tokens only so padding does not dilute the embedding
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).cpu().numpy()

# ✅ Batch processing function
def create_embeddings(dataset_name, dataset, batch_size=100):
    """Extracts texts, embeds them in batches, and logs progress."""
    log(f"πŸ“₯ Creating embeddings for {dataset_name}...")
    
    if dataset_name == "sales":
        texts = [" ".join(row.values()) for row in dataset]
    elif dataset_name == "blended":
        texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset]
    elif dataset_name == "dialog":
        texts = [" ".join(row["dialog"]) for row in dataset]
    elif dataset_name == "multiwoz":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset]
    else:
        log(f"⚠️ Warning: Dataset {dataset_name} format unknown!")
        texts = []

    log(f"✅ Extracted {len(texts)} texts from {dataset_name}.")

    if not texts:
        # Nothing to embed; np.vstack below would fail on an empty list
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    # Process in batches
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        batch_embeddings = embed_text(batch)
        all_embeddings.append(batch_embeddings)
        
        # ✅ Log progress
        log(f"🚀 Processed {i + len(batch)}/{len(texts)} embeddings for {dataset_name}...")

    # Convert list of numpy arrays to a single numpy array
    all_embeddings = np.vstack(all_embeddings)
    return all_embeddings

# ✅ Save embeddings to FAISS with unique filename
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    index_file = f"{index_name}.faiss"
    
    # ✅ Check whether a previous FAISS index exists; append to it if so
    if os.path.exists(index_file):
        log("🔄 Loading existing FAISS index to append...")
        index = faiss.read_index(index_file)
    else:
        index = faiss.IndexFlatL2(embeddings.shape[1])

    index.add(np.asarray(embeddings, dtype=np.float32))

    faiss.write_index(index, index_file)  # ✅ Save FAISS index
    log(f"✅ Saved FAISS index: {index_file}")

# ✅ Run embeddings process
for name, dataset in datasets.items():
    if dataset:  # Skip if dataset failed to load
        embeddings = create_embeddings(name, dataset, batch_size=100)
        save_embeddings_to_faiss(embeddings, index_name=name)
        log(f"βœ… Embeddings for {name} saved to FAISS.")