abrah926 committed on
Commit
38fe90f
Β·
verified Β·
1 Parent(s): 2ab9d34

restoring datasets

Browse files
Files changed (1) hide show
  1. app.py +64 -59
app.py CHANGED
@@ -1,78 +1,83 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
  import faiss
 
4
  import numpy as np
5
- import os
6
  import time
7
- import threading # βœ… Run embeddings in parallel
8
-
9
- # βœ… Ensure FAISS is installed
10
- os.system("pip install faiss-cpu")
11
 
12
  def log(message):
13
  print(f"βœ… {message}")
14
 
15
- # βœ… Step 1: Run Embeddings in a Separate Thread
16
- def run_embeddings():
17
- log("πŸš€ Running embeddings script in background...")
18
- import embeddings # βœ… This will automatically run embeddings.py
19
- log("βœ… Embeddings process finished.")
 
 
20
 
21
- embedding_thread = threading.Thread(target=run_embeddings)
22
- embedding_thread.start() # βœ… Start embedding in background
 
 
23
 
24
- # βœ… Step 2: Check FAISS index
25
- def check_faiss():
26
- index_path = "my_embeddings.faiss" # Ensure file has .faiss extension
 
 
 
27
 
28
- if not os.path.exists(index_path):
29
- return "⚠️ No FAISS index found! Embeddings might still be processing."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- try:
32
- index = faiss.read_index(index_path)
33
- num_vectors = index.ntotal
34
- dim = index.d
35
- return f"πŸ“Š FAISS index contains {num_vectors} vectors.\nβœ… Embedding dimension: {dim}"
36
- except Exception as e:
37
- return f"❌ ERROR: Failed to load FAISS index - {e}"
38
 
39
- log("πŸ” Checking FAISS embeddings...")
40
- faiss_status = check_faiss()
41
- log(faiss_status)
 
 
 
42
 
43
- # βœ… Step 3: Initialize Chatbot
44
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
45
 
46
- def respond(message, history, system_message, max_tokens, temperature, top_p):
47
- messages = [{"role": "system", "content": system_message}]
48
 
49
- for val in history:
50
- if val[0]:
51
- messages.append({"role": "user", "content": val[0]})
52
- if val[1]:
53
- messages.append({"role": "assistant", "content": val[1]})
54
 
55
- messages.append({"role": "user", "content": message})
56
- response = ""
 
57
 
58
- for message in client.chat_completions(
59
- messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p
60
- ):
61
- token = message["choices"][0]["delta"]["content"]
62
- response += token
63
- yield response
64
 
65
- # βœ… Step 4: Start Chatbot Interface
66
- demo = gr.ChatInterface(
67
- respond,
68
- additional_inputs=[
69
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
70
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
71
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
72
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
73
- ],
74
- )
75
 
76
- log("βœ… All systems go! Launching chatbot...")
77
- if __name__ == "__main__":
78
- demo.launch()
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModel
3
  import faiss
4
+ import torch
5
  import numpy as np
 
6
  import time
 
 
 
 
7
 
8
def log(message):
    """Write a progress/status line to stdout, prefixed with a ✅ mark."""
    line = f"✅ {message}"
    print(line)
10
 
11
+ # βœ… Load datasets dynamically
12
+ datasets = {
13
+ "sales": load_dataset("goendalf666/sales-conversations"),
14
+ "blended": load_dataset("blended_skill_talk"),
15
+ "dialog": load_dataset("daily_dialog"),
16
+ "multiwoz": load_dataset("multi_woz_v22"),
17
+ }
18
 
19
+ # βœ… Load MiniLM model for embeddings
20
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ model = AutoModel.from_pretrained(model_name)
23
 
24
def embed_text(texts):
    """Generate embeddings for a batch of texts.

    Fix: the previous mean over ``last_hidden_state`` averaged in [PAD]
    token vectors, so a text's embedding changed with the other texts in
    its batch. Mask-aware mean pooling (the recipe from the
    all-MiniLM-L6-v2 model card) weights only real tokens.

    Args:
        texts: list of strings (a single string also works) to embed.

    Returns:
        numpy array of shape (batch, hidden_dim) with one embedding per text.
    """
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (batch, seq, dim)
    # Broadcast the attention mask over the hidden dimension and average
    # only over non-padding positions.
    mask = inputs["attention_mask"].unsqueeze(-1).type_as(hidden)  # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against div-by-zero
    return (summed / counts).cpu().numpy()
30
 
31
+ # βœ… Batch processing function
32
+ def create_embeddings(dataset_name, dataset, batch_size=100):
33
+ log(f"πŸ“₯ Creating embeddings for {dataset_name}...")
34
+
35
+ # βœ… Extract text based on dataset structure
36
+ if dataset_name == "sales":
37
+ texts = [" ".join(row.values()) for row in dataset["train"]]
38
+ elif dataset_name == "blended":
39
+ texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
40
+ elif dataset_name == "dialog":
41
+ texts = [" ".join(row["dialog"]) for row in dataset["train"]]
42
+ elif dataset_name == "multiwoz":
43
+ texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
44
+ else:
45
+ log(f"⚠️ Unknown dataset structure for {dataset_name}!")
46
+ texts = []
47
 
48
+ log(f"βœ… Extracted {len(texts)} texts from {dataset_name}.")
 
 
 
 
 
 
49
 
50
+ # βœ… Process in batches
51
+ all_embeddings = []
52
+ for i in range(0, len(texts), batch_size):
53
+ batch = texts[i : i + batch_size]
54
+ batch_embeddings = embed_text(batch)
55
+ all_embeddings.append(batch_embeddings)
56
 
57
+ # βœ… Log progress
58
+ log(f"πŸš€ Processed {i + len(batch)}/{len(texts)} embeddings for {dataset_name}...")
59
 
60
+ # βœ… Simulate delay for monitoring
61
+ time.sleep(1)
62
 
63
+ # βœ… Convert list of numpy arrays to a single numpy array
64
+ all_embeddings = np.vstack(all_embeddings)
65
+ return all_embeddings
 
 
66
 
67
+ # βœ… Save embeddings to FAISS with unique filename
68
+ def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
69
+ index_file = f"{index_name}.faiss"
70
 
71
+ # βœ… Create new FAISS index
72
+ index = faiss.IndexFlatL2(embeddings.shape[1])
73
+ index.add(np.array(embeddings).astype(np.float32))
 
 
 
74
 
75
+ # βœ… Save FAISS index
76
+ faiss.write_index(index, index_file)
77
+ log(f"βœ… Saved FAISS index: {index_file}")
 
 
 
 
 
 
 
78
 
79
+ # βœ… Run embedding process for all datasets
80
+ for name, dataset in datasets.items():
81
+ embeddings = create_embeddings(name, dataset, batch_size=100)
82
+ save_embeddings_to_faiss(embeddings, index_name=name)
83
+ log(f"βœ… Embeddings for {name} saved to FAISS.")