broadfield-dev committed (verified)
Commit 58b8f23 · 1 Parent(s): 79afc90

Update app.py

Files changed (1)
  1. app.py +332 -188
app.py CHANGED
@@ -1,199 +1,343 @@
  import gradio as gr
  import pandas as pd
- from datasets import load_dataset, get_dataset_split_names
- from huggingface_hub import HfApi
- import os
- import pathlib
- import uuid
- import logging
-
- # --- Setup Logging ---
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
- # --- Embedding Atlas Imports ---
- from embedding_atlas.data_source import DataSource
- from embedding_atlas.server import make_server
- from embedding_atlas.projection import compute_text_projection
- from embedding_atlas.utils import Hasher
-
- # --- Helper function from embedding_atlas/cli.py ---
- def find_column_name(existing_names, candidate):
-     """Finds a unique column name, appending '_1', '_2', etc. if the candidate name already exists."""
-     if candidate not in existing_names:
-         return candidate
      else:
-         index = 1
-         while True:
-             s = f"{candidate}_{index}"
-             if s not in existing_names:
-                 return s
-             index += 1
-
- # --- Hugging Face API Helpers for Dynamic UI ---
- hf_api = HfApi()
-
- def get_user_datasets(username: str):
-     """Fetches all public datasets for a given username or organization."""
-     logging.info(f"Fetching datasets for user: {username}")
-     if not username:
-         return gr.update(choices=[], value=None, interactive=False)
-     try:
-         datasets = hf_api.list_datasets(author=username, full=True)
-         dataset_ids = [d.id for d in datasets if not d.private]
-         logging.info(f"Found {len(dataset_ids)} datasets for {username}.")
-         return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
-     except Exception as e:
-         logging.error(f"Failed to fetch datasets for {username}: {e}")
-         gr.Warning(f"Could not fetch datasets for user '{username}'.")
-         return gr.update(choices=[], value=None, interactive=False)
-
- def get_dataset_splits(dataset_id: str):
-     """Gets all available splits for a selected dataset."""
-     logging.info(f"Fetching splits for dataset: {dataset_id}")
-     if not dataset_id:
-         return gr.update(choices=[], value=None, interactive=False)
-     try:
-         splits = get_dataset_split_names(dataset_id)
-         logging.info(f"Found splits for {dataset_id}: {splits}")
-         return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
-     except Exception as e:
-         logging.error(f"Failed to fetch splits for {dataset_id}: {e}")
-         gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'.")
-         return gr.update(choices=[], value=None, interactive=False)
-
- def get_split_columns(dataset_id: str, split: str):
-     """Gets all columns for a selected split by loading one row."""
-     logging.info(f"Fetching columns for: {dataset_id}, split: {split}")
-     if not dataset_id or not split:
-         return gr.update(choices=[], value=None, interactive=False)
-     try:
-         dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
-         first_row = next(iter(dataset_sample))
-         columns = list(first_row.keys())
-         logging.info(f"Found columns: {columns}")
-
-         preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
-         best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
-         logging.info(f"Best default column chosen: {best_col}")

-         return gr.update(choices=columns, value=best_col, interactive=True)
-     except Exception as e:
-         logging.error(f"Failed to get columns for {dataset_id}/{split}: {e}", exc_info=True)
-         gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
-         return gr.update(choices=[], value=None, interactive=False)
-
- # --- Main Atlas Generation Logic ---
- def generate_atlas(
-     dataset_name: str,
-     split: str,
-     text_column: str,
-     sample_size: int,
-     model_name: str,
-     umap_neighbors: int,
-     umap_min_dist: float,
-     progress=gr.Progress(track_tqdm=True)
- ):
-     """
-     Loads data, computes embeddings, and serves the Embedding Atlas UI.
-     """
-     if not all([dataset_name, split, text_column]):
-         raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")

-     progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
-     try:
-         dataset = load_dataset(dataset_name, split=split)
-         df = dataset.to_pandas()
-     except Exception as e:
-         raise gr.Error(f"Failed to load data. Error: {e}")
-
-     if sample_size > 0 and sample_size < len(df):
-         progress(0.1, desc=f"Sampling {sample_size} rows...")
-         df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
-
-     if text_column not in df.columns:
-         raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
-
-     progress(0.2, desc="Computing embeddings and UMAP...")
-
-     x_col = find_column_name(df.columns, "projection_x")
-     y_col = find_column_name(df.columns, "projection_y")
-     neighbors_col = find_column_name(df.columns, "__neighbors")
-
-     try:
-         compute_text_projection(
-             df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
-             umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
-         )
-     except Exception as e:
-         raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")
-
-     progress(0.8, desc="Preparing Atlas data source...")
-     id_col = find_column_name(df.columns, "_row_index")
-     df[id_col] = range(df.shape[0])
-
-     metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
-     hasher = Hasher()
-     hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
-     identifier = hasher.hexdigest()
-     atlas_dataset = DataSource(identifier, df, metadata)
-
-     progress(0.9, desc="Mounting visualization UI...")
-     static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
-     mount_path = f"/{uuid.uuid4().hex}"
-     atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
-
-     # --- THIS IS THE FINAL, CORRECT METHOD ---
-     # The 'app' object from gr.Blocks() is a FastAPI app. We use its standard 'mount' method.
-     app.mount(mount_path, atlas_app)
-
-     progress(1.0, desc="Done!")
-     iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
-     return gr.HTML(iframe_html)
-
- # --- Gradio UI Definition ---
- with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
-     gr.Markdown("# Embedding Atlas Explorer")
-     gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")

      with gr.Row():
          with gr.Column(scale=1):
-             gr.Markdown("### 1. Select Data")
-             hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'")
-             dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False)
-             split_input = gr.Dropdown(label="Select a Split", interactive=False)
-             text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False)
-
-             gr.Markdown("### 2. Configure Visualization")
-             sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100)
-
-             with gr.Accordion("Advanced Settings", open=False):
-                 model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2")
-                 umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.")
-                 umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.")
-
-             generate_button = gr.Button("Generate Atlas", variant="primary")
-
-         with gr.Column(scale=3):
-             gr.Markdown("### 3. Explore Atlas")
-             output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
-
-     # --- Chained Event Listeners for Dynamic UI (Corrected Logic) ---
-     hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
-
-     dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
-
-     # The columns are populated only AFTER a split is chosen.
-     split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
-
-     # --- Button Click Event ---
-     generate_button.click(
-         fn=generate_atlas,
-         inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
-         outputs=[output_html],
-     )

-     # Load initial example data on app load
-     app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)

  if __name__ == "__main__":
-     app.launch(debug=True)
 
+ import torch
+ from torch.optim import AdamW
+ from torch.utils.data import DataLoader
+ from transformers import (
+     GPT2Config,
+     AutoModelForCausalLM,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizerFast,
+     DataCollatorWithPadding,
+     GenerationConfig,
+ )
+ from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from datasets import Dataset, load_dataset
+ import os
+ import torch.nn as nn
  import gradio as gr
  import pandas as pd
+ import time
+ import re
+ from sentence_transformers import SentenceTransformer
+ import faiss
+
+ # --- Configuration & Global State ---
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ SFT_MODEL_PATH = "./sft_model"
+ RAG_INDEX_PATH = "faiss_index.bin"
+ TOKENIZER_FILE_PATH = "savant_tokenizer.json"
+ CACHE_DIR = "./hf_cache"
+
+ SAVANT_MODEL = None
+ SAVANT_TOKENIZER = None
+ RAG_DATABASE = None
+ full_dataset_for_rag = None
+
+ # --- RAG Database Class ---
+ class VectorDatabase:
+     def __init__(self, embedder_model_name='all-MiniLM-L6-v2'):
+         self.embedder = SentenceTransformer(embedder_model_name, device=str(DEVICE), cache_folder=CACHE_DIR)
+         self.index = None
+         self.documents = []
+
+     def build_index(self, texts):
+         print("Building RAG vector index...")
+         self.documents = texts
+         embeddings = self.embedder.encode(texts, convert_to_tensor=True, show_progress_bar=True)
+         self.index = faiss.IndexFlatL2(embeddings.shape[1])
+         self.index.add(embeddings.cpu().numpy())
+         print(f"RAG Index built with {len(self.documents)} documents.")
+
+     def save_index(self, path):
+         if self.index:
+             faiss.write_index(self.index, path)
+             print(f"RAG Index saved to {path}")
+
+     def search(self, query, k=3):
+         if self.index is None:
+             return []
+         query_embedding = self.embedder.encode([query], convert_to_tensor=True)
+         distances, indices = self.index.search(query_embedding.cpu().numpy(), k)
+         return [self.documents[i] for i in indices[0]]
+
+ # --- Core Logic ---
+ def create_tokenizer_file_from_dataset(dataset, save_path=TOKENIZER_FILE_PATH):
+     corpus_path = "temp_corpus.txt"
+     with open(corpus_path, "w", encoding="utf-8") as f:
+         for item in dataset:
+             if item and item.get('question') and item.get('answer'):
+                 f.write(str(item['question']) + " " + str(item['answer']) + "\n")
+
+     raw_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+     raw_tokenizer.pre_tokenizer = Whitespace()
+     special_tokens = ["[UNK]", "[PAD]", "<|startoftext|>", "<|endoftext|>"]
+     trainer = BpeTrainer(vocab_size=8192, special_tokens=special_tokens)
+     raw_tokenizer.train(files=[corpus_path], trainer=trainer)
+     os.remove(corpus_path)
+     raw_tokenizer.save(save_path)
+     return save_path
+
+ def create_seed_model(config):
+     return AutoModelForCausalLM.from_config(config)
+
+ def get_sft_curriculum():
+     return [{"name": "Phase 1: Foundational Math (SFT)", "dataset_name": "openai/gsm8k", "dataset_config": "main", "data_slice": slice(0, 100), "epochs": 1, "learning_rate": 5e-5, "mastery_threshold": 16.0}]
+
+ def get_rl_curriculum():
+     return {"name": "Phase 2: Problem Solving (RL)", "dataset_name": "openai/gsm8k", "dataset_config": "main", "data_slice": slice(1000, 1100), "num_steps": 100}
+
+ def get_folder_files(folder_path):
+     if not os.path.isdir(folder_path): return []
+     return [os.path.join(folder_path, f) for f in os.listdir(folder_path)]
+
+ def extract_answer(text):
+     text = str(text)
+     match = re.search(r'\\boxed\{([^}]*)\}', text)
+     if match:
+         ans = match.group(1).strip().replace(",", "")
+         try: return float(ans)
+         except ValueError: return None
+     matches = re.findall(r'(\d+\.?\d*|\.\d+)', text)
+     if matches:
+         try: return float(matches[-1])
+         except ValueError: return None
+     return None
+
+ # --- Master Training Process ---
+ def run_sft_phase(artifact_files):
+     global full_dataset_for_rag
+     log_text = "--- Starting Phase 1: Supervised Fine-Tuning (SFT) ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     stage = get_sft_curriculum()[0]
+     full_dataset_for_rag = load_dataset(stage['dataset_name'], name=stage['dataset_config'], split='train', cache_dir=CACHE_DIR)
+     tokenizer_file = create_tokenizer_file_from_dataset(full_dataset_for_rag)
+     artifact_files.append(tokenizer_file)
+
+     wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+     wrapped_tokenizer.pad_token = "[PAD]"
+
+     log_text += f"✓ Tokenizer created.\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     config = GPT2Config(vocab_size=wrapped_tokenizer.vocab_size, n_positions=512, n_layer=4, n_head=4, n_embd=256, pad_token_id=wrapped_tokenizer.pad_token_id, eos_token_id=wrapped_tokenizer.eos_token_id)
+
+     model = create_seed_model(config).to(DEVICE)
+     log_text += f"✓ Seed Model (GPT-2 Style) Initialized.\n\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     data_collator = DataCollatorForLanguageModeling(tokenizer=wrapped_tokenizer, mlm=False)
+     stage_ds = full_dataset_for_rag.select(range(stage['data_slice'].start, stage['data_slice'].stop))
+
+     tokenized_dataset = stage_ds.map(lambda ex: wrapped_tokenizer([q + " " + a for q, a in zip(ex['question'], ex['answer'])], truncation=True, padding="max_length", max_length=512), batched=True, remove_columns=stage_ds.column_names)
+
+     dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
+     optimizer = AdamW(model.parameters(), lr=stage['learning_rate'])
+
+     loss_history = []
+     avg_epoch_loss = float('inf')
+
+     for epoch in range(stage['epochs']):
+         base_log_for_epoch = log_text + f" Starting SFT Epoch {epoch+1}/{stage['epochs']}...\n"
+         yield base_log_for_epoch, None, gr.Group(visible=False), artifact_files
+         for batch_idx, batch in enumerate(dataloader):
+             batch = {k: v.to(DEVICE) for k, v in batch.items()}
+             optimizer.zero_grad()
+             outputs = model(**batch)
+             loss = outputs.loss
+             loss.backward()
+             optimizer.step()
+             if (batch_idx + 1) % 20 == 0:
+                 yield base_log_for_epoch + f" - Batch {batch_idx+1}/{len(dataloader)}\n", None, gr.Group(visible=False), artifact_files
+         avg_epoch_loss = loss.item()
+         loss_history.append({"Phase": "SFT", "Epoch": epoch, "Loss": avg_epoch_loss})
+         loss_df = pd.DataFrame(loss_history)
+         log_text += f" Epoch {epoch+1}/{stage['epochs']} complete. Loss: {avg_epoch_loss:.4f}\n"
+         yield log_text, gr.LinePlot(loss_df, x="Epoch", y="Loss", color="Phase"), gr.Group(visible=False), artifact_files
+
+     log_text += f"✓ SFT Phase Complete. Final Loss: {avg_epoch_loss:.4f}\n"
+     if avg_epoch_loss < stage['mastery_threshold']:
+         log_text += f"✓ SFT Mastery Gate PASSED! Saving model...\n\n"
+         model.save_pretrained(SFT_MODEL_PATH)
+         wrapped_tokenizer.save_pretrained(SFT_MODEL_PATH)
+         artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
+         artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return model, wrapped_tokenizer, log_text, artifact_files
      else:
+         log_text += f"✗ SFT Mastery Gate FAILED. Stopping.\n"
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return None, None, log_text, artifact_files
+
+ def run_rl_phase(sft_model, tokenizer, initial_log_text, rl_dataset_slice, artifact_files):
+     global SAVANT_MODEL, SAVANT_TOKENIZER
+     log_text = initial_log_text
+     rl_curriculum = get_rl_curriculum()
+
+     log_text += f"--- Starting Phase 2: Reinforcement Learning (RL) ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     def tokenize_query(examples):
+         return tokenizer(examples["question"], truncation=True, max_length=512)
+
+     rl_dataset = rl_dataset_slice.map(tokenize_query, batched=True)
+
+     ppo_config = PPOConfig(
+         learning_rate=1.41e-5,
+         batch_size=4,
+         mini_batch_size=4,
+         ppo_epochs=4,
+     )
+
+     model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
+
+     data_collator = DataCollatorWithPadding(tokenizer)
+     ppo_trainer = PPOTrainer(
+         config=ppo_config,
+         model=model_with_value_head,
+         ref_model=None,
+         tokenizer=tokenizer,
+         dataset=rl_dataset,
+         data_collator=data_collator,
+     )
+
+     # FIXED: Create our own DataLoader to ensure the 'answer' column is preserved in the batch.
+     # The PPOTrainer's internal dataloader drops all columns except for model inputs.
+     dataloader = DataLoader(
+         rl_dataset,
+         batch_size=ppo_config.batch_size,
+         collate_fn=data_collator
+     )
+
+     generation_kwargs = {"max_new_tokens": 100, "pad_token_id": tokenizer.pad_token_id, "do_sample": True, "temperature": 0.7}
+     log_text += "✓ PPOTrainer instantiated. Starting PPO training loop...\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     # FIXED: Iterate over our custom dataloader, not ppo_trainer.dataloader
+     for step, batch in enumerate(dataloader):
+         if step >= rl_curriculum["num_steps"]:
+             break
+
+         query_tensors = batch["input_ids"].to(DEVICE)
+         # This will now work correctly because our dataloader preserves the 'answer' column.
+         ground_truth_answers = batch["answer"]
+
+         response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
+         response_texts = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

+         rewards = []
+         for resp_text, gt_answer in zip(response_texts, ground_truth_answers):
+             pred = extract_answer(resp_text)
+             true = extract_answer(gt_answer)
+             reward_value = 1.0 if pred is not None and true is not None and abs(pred - true) < 1e-2 else 0.0
+             rewards.append(torch.tensor(reward_value, device=DEVICE))

+         stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
+
+         if step % 10 == 0:
+             mean_reward = stats.get("ppo/returns/mean", torch.tensor(0.0)).item()
+             log_text += f"Step {step}/{rl_curriculum['num_steps']}, Mean Reward: {mean_reward:.2f}\n"
+             yield log_text, None, gr.Group(visible=False), artifact_files

+     log_text += "\n✓ Reinforcement Learning Phase Complete! Saving final policy model...\n"
+
+     ppo_trainer.save_model(SFT_MODEL_PATH)
+
+     SAVANT_MODEL = ppo_trainer.model
+     SAVANT_TOKENIZER = tokenizer
+
+     artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
+     artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
+
+     log_text += f"✓ Final RL-tuned model saved to {SFT_MODEL_PATH}\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     return log_text, artifact_files
+
+ def run_full_training():
+     global RAG_DATABASE, SAVANT_MODEL, SAVANT_TOKENIZER, full_dataset_for_rag
+     artifact_files = []
+
+     sft_model, tokenizer, log_text, artifact_files = yield from run_sft_phase(artifact_files)
+
+     if sft_model is None:
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return
+
+     log_text += "--- Building RAG Knowledge Base ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     rag_texts = [f"Question: {q}\nAnswer: {a}" for q, a in zip(full_dataset_for_rag['question'], full_dataset_for_rag['answer'])]
+     RAG_DATABASE = VectorDatabase()
+     RAG_DATABASE.build_index(rag_texts)
+     RAG_DATABASE.save_index(RAG_INDEX_PATH)
+     artifact_files.append(RAG_INDEX_PATH)
+
+     log_text += "✓ RAG Knowledge Base Built.\n\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     rl_curriculum = get_rl_curriculum()
+     rl_dataset_slice = full_dataset_for_rag.select(range(*rl_curriculum['data_slice'].indices(len(full_dataset_for_rag))))
+
+     final_log_text, final_artifact_files = yield from run_rl_phase(sft_model, tokenizer, log_text, rl_dataset_slice, artifact_files)
+
+     final_log_text += "\n--- Curriculum Completed: RAG-Powered Savant Ready! ---\n"
+
+     SAVANT_MODEL = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
+     SAVANT_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(SFT_MODEL_PATH)
+
+     yield final_log_text, None, gr.Group(visible=True), final_artifact_files
+
+ def run_savant_inference(user_question):
+     if SAVANT_MODEL is None or SAVANT_TOKENIZER is None or RAG_DATABASE is None:
+         return "Model or RAG Database not ready. Please run the training first."
+
+     retrieved_context = RAG_DATABASE.search(user_question, k=3)
+     context_str = "\n".join(retrieved_context)
+     augmented_prompt = f"Context:\n{context_str}\n\nQuestion:\n{user_question}\n\nAnswer:"
+
+     inputs = SAVANT_TOKENIZER(augmented_prompt, return_tensors="pt").to(DEVICE)
+
+     with torch.no_grad():
+         response_ids = SAVANT_MODEL.generate(**inputs, max_new_tokens=150, pad_token_id=SAVANT_TOKENIZER.pad_token_id, do_sample=True, temperature=0.7, top_p=0.9)
+
+     response_text = SAVANT_TOKENIZER.decode(response_ids[0], skip_special_tokens=True)
+
+     if "Answer:" in response_text:
+         final_answer = response_text.split("Answer:")[-1].strip()
+     else:
+         final_answer = response_text
+
+     return f"**Savant's Answer:**\n{final_answer}\n\n**Context Used by RAG:**\n- {retrieved_context[0]}\n- {retrieved_context[1]}\n- {retrieved_context[2]}"
+
+ # --- GRADIO UI DEFINITION ---
+ with gr.Blocks(theme=gr.themes.Soft(), title="Savant-Garde Dashboard") as demo:
+     gr.Markdown("# The Savant-Garde: RAG + SFT + RL Savant Factory")
+
      with gr.Row():
          with gr.Column(scale=1):
+             start_button = gr.Button("🚀 Begin Full Training", variant="primary")
+             log_output = gr.Textbox(label="Live Training Log", interactive=False, lines=25, max_lines=25)
+
+         with gr.Column(scale=1):
+             gr.Markdown("### Dashboard Visualizations")
+             loss_plot = gr.LinePlot(show_label=False)
+             with gr.Accordion("Downloadable Artifacts", open=True):
+                 file_output = gr.File(label="Generated Files", file_count="multiple", interactive=False)
+
+     with gr.Group(visible=False) as interaction_group:
+         gr.Markdown("---")
+         gr.Markdown("### 🧠 Interact with the RAG-Powered Savant")
+         with gr.Row():
+             question_input = gr.Textbox(label="Ask a Math Question", placeholder="e.g., Janet has 2 apples...")
+             ask_button = gr.Button("Ask Savant")
+         savant_answer = gr.Markdown()

+     start_button.click(fn=run_full_training, inputs=None, outputs=[log_output, loss_plot, interaction_group, file_output])
+     ask_button.click(fn=run_savant_inference, inputs=question_input, outputs=savant_answer)

  if __name__ == "__main__":
+     demo.launch(debug=True)