abrah926 committed on
Commit
3f5fd34
·
verified ·
1 Parent(s): 54072ad

new file to create embeddings

Browse files
Files changed (1) hide show
  1. embeddings.py +45 -0
embeddings.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import faiss
4
+ import torch
5
+ import numpy as np
6
+
7
+ # ✅ Load datasets
8
# Hub identifiers of the conversational corpora, paired with the short key
# each corpus is stored under in `datasets`.
_DATASET_SOURCES = (
    ("sales", "goendalf666/sales-conversations"),
    ("blended", "blended_skill_talk"),
    ("dialog", "daily_dialog"),
    ("multiwoz", "multi_woz_v22"),
)

# Load every corpus eagerly, in the order listed above.
datasets = {key: load_dataset(repo_id) for key, repo_id in _DATASET_SOURCES}
14
+
15
+ # ✅ Load MiniLM model and tokenizer
16
# Checkpoint used for embeddings; the tokenizer and the encoder are both
# loaded from this same identifier.
model_name = "sentence-transformers/all-MiniLM-L6-v2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
19
+
20
def embed_text(texts, batch_size=64):
    """Embed one or more texts with the module-level tokenizer and model.

    Each text is tokenized (truncated to at most 512 tokens), passed through
    ``model`` and reduced to one vector by averaging its token states. The
    average is weighted by the attention mask, so padding that was added to
    align a batch does not leak into the embeddings of shorter texts.

    Args:
        texts: A single string or a sequence of strings.
        batch_size: Number of texts encoded per forward pass. Bounds peak
            memory when a whole dataset column is passed in.

    Returns:
        A ``numpy.ndarray`` of shape ``(number_of_texts, hidden_size)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # A bare string is treated as a batch of one, as the tokenizer would do.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        # The tokenizer rejects an empty batch; return an empty matrix with
        # the right width instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
            # Zero out padded positions, then divide by the number of real
            # tokens per text (clamped to avoid division by zero).
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        chunks.append(pooled.cpu().numpy())
    return np.concatenate(chunks, axis=0)
25
+
26
+ # ✅ Extract and embed the datasets
27
def create_embeddings(dataset_name, dataset, text_field="text", split="train"):
    """Embed one text column of a loaded dataset.

    Args:
        dataset_name: Label used only in the progress message.
        dataset: Mapping of split name to a table that can be indexed by
            column name (as returned by ``load_dataset``).
        text_field: Column holding the texts to embed. Column names differ
            between datasets, so pass the one that matches the dataset at
            hand; defaults to ``"text"``.
        split: Split to read from; defaults to ``"train"``.

    Returns:
        Whatever ``embed_text`` returns for the selected column.
    """
    print(f"Creating embeddings for {dataset_name}...")
    # Materialize the column so it can be measured and sliced downstream.
    texts = list(dataset[split][text_field])
    return embed_text(texts)
32
+
33
+ # ✅ Save embeddings to a database
34
def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
    """Build a flat L2 FAISS index from embeddings and write it to disk.

    Args:
        embeddings: Array-like of shape ``(n, dim)``. A single 1-D vector is
            accepted and treated as one row.
        index_name: Path of the file the index is written to. An existing
            file at that path is replaced.

    Returns:
        The populated ``faiss.IndexFlatL2`` instance.

    Raises:
        ValueError: If the input is not 1-D or 2-D, or has zero width.
    """
    print("Saving embeddings to FAISS...")
    # Convert to a C-contiguous float32 matrix before reading its shape.
    vectors = np.atleast_2d(np.ascontiguousarray(embeddings, dtype=np.float32))
    if vectors.ndim != 2 or vectors.shape[1] == 0:
        raise ValueError(
            "embeddings must be a non-empty-width 2-D array, "
            f"got shape {vectors.shape}"
        )

    # The index dimensionality is taken from the data, not assumed.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    faiss.write_index(index, str(index_name))  # Save FAISS index to file
    return index
40
+
41
+ # ✅ Create embeddings for all datasets
42
for name, dataset in datasets.items():
    embeddings = create_embeddings(name, dataset)
    # Write one index file per dataset. Relying on the helper's default
    # filename would make every iteration overwrite the previous index.
    index = save_embeddings_to_faiss(
        embeddings, index_name=f"{name}_embeddings.faiss"
    )
    print(f"Embeddings for {name} saved to FAISS.")