abrah926 commited on
Commit
f76d804
Β·
verified Β·
1 Parent(s): dfbcc66

updating proper format to embed the datasets and print statements to get progress

Browse files
Files changed (1) hide show
  1. embeddings.py +44 -4
embeddings.py CHANGED
@@ -27,12 +27,52 @@ def embed_text(texts):
27
  embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
28
  return embeddings
29
 
 
30
  # βœ… Extract and embed the datasets
31
  def create_embeddings(dataset_name, dataset):
32
- print(f"Creating embeddings for {dataset_name}...")
33
- texts = [text for text in dataset["train"]['text']] # Adjust the field depending on dataset structure
34
- embeddings = embed_text(texts)
35
- return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # βœ… Save embeddings to a database
38
  def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
 
27
  embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
28
  return embeddings
29
 
30
+
31
  # βœ… Extract and embed the datasets
32
  def create_embeddings(dataset_name, dataset):
33
+ print(f"πŸ“₯ Creating embeddings for {dataset_name}...")
34
+
35
+ if dataset_name == "goendalf666/sales-conversations":
36
+ texts = [" ".join(row.values()) for row in dataset["train"]]
37
+
38
+ elif dataset_name == "AlekseyKorshuk/persona-chat":
39
+ texts = [" ".join(utterance["candidates"]) for utterance in dataset["train"]["utterances"]]
40
+
41
+ elif dataset_name == "blended_skill_talk":
42
+ texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
43
+
44
+ elif dataset_name == "daily_dialog":
45
+ texts = [" ".join(row["dialog"]) for row in dataset["train"]]
46
+
47
+ elif dataset_name == "multi_woz_v22":
48
+ texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
49
+
50
+ else:
51
+ print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
52
+ texts = []
53
+
54
+ # βœ… Verify dataset extraction
55
+ if len(texts) == 0:
56
+ print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
57
+ else:
58
+ print(f"βœ… Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")
59
+
60
+ return texts
61
+
62
+ # βœ… Embed and store in FAISS
63
+ for name, dataset in datasets.items():
64
+ texts = create_embeddings(name, dataset)
65
+
66
+ if len(texts) > 0: # βœ… Only embed if texts exist
67
+ embeddings = embed_text(texts)
68
+ print(f"βœ… Generated embeddings shape: {embeddings.shape}")
69
+
70
+ index = save_embeddings_to_faiss(embeddings)
71
+ print(f"βœ… Embeddings for {name} saved to FAISS.")
72
+ else:
73
+ print(f"⚠️ Skipping embedding for {name} (No valid texts).")
74
+
75
+
76
 
77
  # βœ… Save embeddings to a database
78
  def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):