abrah926 commited on
Commit
663cc24
Β·
verified Β·
1 Parent(s): e925ddf
Files changed (1) hide show
  1. embeddings.py +40 -21
embeddings.py CHANGED
@@ -1,19 +1,36 @@
1
- from datasets import load_dataset
2
- from transformers import AutoTokenizer, AutoModel
3
  import faiss
4
  import torch
5
  import numpy as np
6
- import os
 
 
 
 
7
 
8
  def log(message):
9
  print(f"βœ… {message}")
10
 
11
- # βœ… Load datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  datasets = {
13
- "sales": load_dataset("goendalf666/sales-conversations"),
14
- "blended": load_dataset("blended_skill_talk"),
15
- "dialog": load_dataset("daily_dialog"),
16
- "multiwoz": load_dataset("multi_woz_v22"),
17
  }
18
 
19
  # βœ… Load MiniLM model for embeddings
@@ -30,18 +47,19 @@ def embed_text(texts):
30
 
31
  # βœ… Batch processing function
32
  def create_embeddings(dataset_name, dataset, batch_size=100):
33
- print(f"πŸ“₯ Creating embeddings for {dataset_name}...")
 
34
 
35
- if dataset_name == "goendalf666/sales-conversations":
36
- texts = [" ".join(row.values()) for row in dataset["train"]]
37
- elif dataset_name == "blended_skill_talk":
38
- texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
39
- elif dataset_name == "daily_dialog":
40
- texts = [" ".join(row["dialog"]) for row in dataset["train"]]
41
- elif dataset_name == "multi_woz_v22":
42
- texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
43
  else:
44
- print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
45
  texts = []
46
 
47
  log(f"βœ… Extracted {len(texts)} texts from {dataset_name}.")
@@ -78,7 +96,8 @@ def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
78
 
79
  # βœ… Run embeddings process
80
  for name, dataset in datasets.items():
81
- embeddings = create_embeddings(name, dataset, batch_size=100)
82
- save_embeddings_to_faiss(embeddings, index_name=name)
83
- log(f"βœ… Embeddings for {name} saved to FAISS.")
 
84
 
 
1
+ import os
2
+ import json
3
  import faiss
4
  import torch
5
  import numpy as np
6
+ from transformers import AutoTokenizer, AutoModel
7
+
8
+ # βœ… Set up directories
9
+ DATA_DIR = "data"
10
+ os.makedirs(DATA_DIR, exist_ok=True) # Ensure data directory exists
11
 
12
  def log(message):
13
  print(f"βœ… {message}")
14
 
15
+ # βœ… Load datasets from stored JSON files
16
+ def load_local_dataset(dataset_name):
17
+ file_path = os.path.join(DATA_DIR, f"{dataset_name}.json")
18
+
19
+ if os.path.exists(file_path):
20
+ with open(file_path, "r") as f:
21
+ data = json.load(f)
22
+ log(f"πŸ“‚ Loaded {dataset_name} from {file_path}")
23
+ return data
24
+ else:
25
+ log(f"❌ ERROR: {dataset_name} file not found!")
26
+ return None
27
+
28
+ # βœ… Load all datasets from storage
29
  datasets = {
30
+ "sales": load_local_dataset("sales"),
31
+ "blended": load_local_dataset("blended"),
32
+ "dialog": load_local_dataset("dialog"),
33
+ "multiwoz": load_local_dataset("multiwoz"),
34
  }
35
 
36
  # βœ… Load MiniLM model for embeddings
 
47
 
48
  # βœ… Batch processing function
49
  def create_embeddings(dataset_name, dataset, batch_size=100):
50
+ """Extracts texts, embeds them in batches, and logs progress."""
51
+ log(f"πŸ“₯ Creating embeddings for {dataset_name}...")
52
 
53
+ if dataset_name == "sales":
54
+ texts = [" ".join(row.values()) for row in dataset]
55
+ elif dataset_name == "blended":
56
+ texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset]
57
+ elif dataset_name == "dialog":
58
+ texts = [" ".join(row["dialog"]) for row in dataset]
59
+ elif dataset_name == "multiwoz":
60
+ texts = [" ".join(row["turns"]["utterance"]) for row in dataset]
61
  else:
62
+ log(f"⚠️ Warning: Dataset {dataset_name} format unknown!")
63
  texts = []
64
 
65
  log(f"βœ… Extracted {len(texts)} texts from {dataset_name}.")
 
96
 
97
  # βœ… Run embeddings process
98
  for name, dataset in datasets.items():
99
+ if dataset: # Skip if dataset failed to load
100
+ embeddings = create_embeddings(name, dataset, batch_size=100)
101
+ save_embeddings_to_faiss(embeddings, index_name=name)
102
+ log(f"βœ… Embeddings for {name} saved to FAISS.")
103