new file to create embeddings
embeddings.py
ADDED
@@ -0,0 +1,45 @@
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import faiss
import torch
import numpy as np

# Load datasets
datasets = {
    "sales": load_dataset("goendalf666/sales-conversations"),
    "blended": load_dataset("blended_skill_talk"),
    "dialog": load_dataset("daily_dialog"),
    "multiwoz": load_dataset("multi_woz_v22"),
}

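# NOTE (assumption): blended_skill_talk, daily_dialog and multi_woz_v22 are
# script-based datasets; depending on the installed `datasets` version,
# load_dataset may need trust_remote_code=True to fetch them.
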
# Load MiniLM model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Model for embeddings
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def embed_text(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean-pool over real tokens only; a plain .mean(dim=1) would also average padding positions.
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).cpu().numpy()

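# Embedding an entire split in one forward pass can exhaust memory on larger
# datasets. A minimal batched wrapper (sketch; batch_size of 64 is an
# arbitrary choice):
def embed_text_batched(texts, batch_size=64):
    chunks = [embed_text(texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)]
    return np.concatenate(chunks, axis=0)
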
# Extract and embed the datasets
def create_embeddings(dataset_name, dataset):
    print(f"Creating embeddings for {dataset_name}...")
    # Adjust the field per dataset: the datasets loaded above do not share a
    # flat "text" column (see the extraction sketch below).
    texts = list(dataset["train"]["text"])
    embeddings = embed_text(texts)
    return embeddings

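# One possible way to flatten each dataset's schema to plain strings; the
# field names below are assumptions to verify against each dataset's actual
# columns before relying on them:
def extract_texts(name, dataset):
    train = dataset["train"]
    if name == "dialog":    # daily_dialog: "dialog" is a list of utterances
        return [u for row in train["dialog"] for u in row]
    if name == "multiwoz":  # multi_woz_v22: each row's "turns" holds an "utterance" list
        return [u for row in train["turns"] for u in row["utterance"]]
    if name == "blended":   # blended_skill_talk: two free-form message lists per row
        return [u for row in train for u in row["free_messages"] + row["guided_messages"]]
    # Fallback (e.g. "sales"): join every non-empty column value in the row
    return [" ".join(str(v) for v in row.values() if v) for row in train]
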
# Save embeddings to a database
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    print("Saving embeddings to FAISS...")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # 384-dim for all-MiniLM-L6-v2
    index.add(np.array(embeddings).astype(np.float32))
    faiss.write_index(index, index_name)  # Save FAISS index to file
    return index

# Create embeddings for all datasets
for name, dataset in datasets.items():
    embeddings = create_embeddings(name, dataset)
    # Use a per-dataset file name; a single shared default name would
    # overwrite the previous index on every iteration.
    index = save_embeddings_to_faiss(embeddings, index_name=f"{name}.index")
    print(f"Embeddings for {name} saved to FAISS.")
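
# Usage sketch: read a saved index back and run a nearest-neighbour query.
# "sales.index" is the file written for the "sales" dataset by the loop above;
# the query string is an arbitrary example.
index = faiss.read_index("sales.index")
query = embed_text(["How do I close a deal with a hesitant customer?"]).astype(np.float32)
distances, ids = index.search(query, 5)  # top-5 neighbours by L2 distance
print(ids[0])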