Update dataset text extraction to the proper format for embedding, and add print statements to report progress
Browse files- embeddings.py +44 -4
embeddings.py
CHANGED
@@ -27,12 +27,52 @@ def embed_text(texts):
|
|
27 |
embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
|
28 |
return embeddings
|
29 |
|
|
|
30 |
# ✅ Extract and embed the datasets
|
31 |
def create_embeddings(dataset_name, dataset):
|
32 |
-
print(f"Creating embeddings for {dataset_name}...")
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# ✅ Save embeddings to a database
|
38 |
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
|
|
|
27 |
embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
|
28 |
return embeddings
|
29 |
|
30 |
+
|
31 |
# ✅ Extract and embed the datasets
|
32 |
def create_embeddings(dataset_name, dataset):
    """Extract one text string per record from a known dataset's ``train`` split.

    Despite its name, this function does not compute embeddings: it flattens
    each record into a single space-joined string and returns the list, which
    the caller is expected to embed afterwards.

    Args:
        dataset_name: Identifier used to pick the extraction rule. Recognised
            values are ``"goendalf666/sales-conversations"``,
            ``"AlekseyKorshuk/persona-chat"``, ``"blended_skill_talk"``,
            ``"daily_dialog"`` and ``"multi_woz_v22"``.
        dataset: Mapping with a ``"train"`` entry that holds the records.

    Returns:
        A list of strings, one per extracted record. The list is empty when
        ``dataset_name`` is not recognised.
    """
    print(f"🔥 Creating embeddings for {dataset_name}...")

    if dataset_name == "goendalf666/sales-conversations":
        # Rows may carry None for missing turns; str.join would raise
        # TypeError on them, so only the turns that are present are joined.
        texts = [
            " ".join(value for value in row.values() if value is not None)
            for row in dataset["train"]
        ]

    elif dataset_name == "AlekseyKorshuk/persona-chat":
        # The "utterances" column may hold one list of utterance dicts per
        # dialogue. Accept both that nested shape and a flat sequence of
        # utterance dicts, producing one text per utterance in either case.
        texts = []
        for entry in dataset["train"]["utterances"]:
            utterances = [entry] if isinstance(entry, dict) else entry
            texts.extend(
                " ".join(utterance["candidates"]) for utterance in utterances
            )

    elif dataset_name == "blended_skill_talk":
        # All free messages come first, followed by all guided messages.
        texts = [
            " ".join(row["free_messages"] + row["guided_messages"])
            for row in dataset["train"]
        ]

    elif dataset_name == "daily_dialog":
        texts = [" ".join(row["dialog"]) for row in dataset["train"]]

    elif dataset_name == "multi_woz_v22":
        texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]

    else:
        print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
        texts = []

    # ✅ Verify dataset extraction
    if not texts:
        print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
    else:
        print(f"✅ Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")

    return texts
|
61 |
+
|
62 |
+
# ✅ Embed and store in FAISS
|
63 |
+
# For every loaded dataset: extract its texts, embed them, and hand the
# vectors to the FAISS helper.
# NOTE(review): `save_embeddings_to_faiss` is defined further down in this
# file than this loop; confirm it is already bound when this code runs.
for name, dataset in datasets.items():
    texts = create_embeddings(name, dataset)

    if not texts:
        # Nothing was extracted, so there is nothing to embed or store.
        print(f"⚠️ Skipping embedding for {name} (No valid texts).")
        continue

    # ✅ Only embed if texts exist
    embeddings = embed_text(texts)
    print(f"✅ Generated embeddings shape: {embeddings.shape}")

    # NOTE(review): no index_name is passed, so every dataset goes through the
    # helper's default name; confirm later datasets do not replace earlier ones.
    index = save_embeddings_to_faiss(embeddings)
    print(f"✅ Embeddings for {name} saved to FAISS.")
|
74 |
+
|
75 |
+
|
76 |
|
77 |
# ✅ Save embeddings to a database
|
78 |
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
|