Update app.py
app.py
CHANGED
@@ -1,199 +1,343 @@
  import gradio as gr
  import pandas as pd
- import …
- import …
- … (the rest of old lines 3–23 are not rendered in the diff view) …
  else:
- … (old lines 25–84 are not rendered in the diff view) …
- # --- Main Atlas Generation Logic ---
- def generate_atlas(
-     dataset_name: str,
-     split: str,
-     text_column: str,
-     sample_size: int,
-     model_name: str,
-     umap_neighbors: int,
-     umap_min_dist: float,
-     progress=gr.Progress(track_tqdm=True)
- ):
-     """
-     Loads data, computes embeddings, and serves the Embedding Atlas UI.
-     """
-     if not all([dataset_name, split, text_column]):
-         raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")
- … (old lines 102–108 are not rendered in the diff view) …
-     if sample_size > 0 and sample_size < len(df):
-         progress(0.1, desc=f"Sampling {sample_size} rows...")
-         df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
-     if text_column not in df.columns:
-         raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
-     progress(0.2, desc="Computing embeddings and UMAP...")
-     x_col = find_column_name(df.columns, "projection_x")
-     y_col = find_column_name(df.columns, "projection_y")
-     neighbors_col = find_column_name(df.columns, "__neighbors")
-     try:
-         compute_text_projection(
-             df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
-             umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
-         )
-     except Exception as e:
-         raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")
-     progress(0.8, desc="Preparing Atlas data source...")
-     id_col = find_column_name(df.columns, "_row_index")
-     df[id_col] = range(df.shape[0])
-     metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
-     hasher = Hasher()
-     hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
-     identifier = hasher.hexdigest()
-     atlas_dataset = DataSource(identifier, df, metadata)
-     progress(0.9, desc="Mounting visualization UI...")
-     static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
-     mount_path = f"/{uuid.uuid4().hex}"
-     atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
-     # --- THIS IS THE FINAL, CORRECT METHOD ---
-     # The 'app' object from gr.Blocks() is a FastAPI app. We use its standard 'mount' method.
-     app.mount(mount_path, atlas_app)
-     progress(1.0, desc="Done!")
-     iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
-     return gr.HTML(iframe_html)
- # --- Gradio UI Definition ---
- with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
-     gr.Markdown("# Embedding Atlas Explorer")
-     gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")
      with gr.Row():
          with gr.Column(scale=1):
-             gr. …
- … (old lines 161–165 are not rendered in the diff view) …
-             gr. …
- … (old lines 167–175 are not rendered in the diff view) …
-         with gr.Column(scale=3):
-             gr.Markdown("### 3. Explore Atlas")
-             output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
-     # --- Chained Event Listeners for Dynamic UI (Corrected Logic) ---
-     hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
-     dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
-     # The columns are populated only AFTER a split is chosen.
-     split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
-     # --- Button Click Event ---
-     generate_button.click(
-         fn=generate_atlas,
-         inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
-         outputs=[output_html],
-     )
  if __name__ == "__main__":
- … (old line 199 is not rendered in the diff view) …

+ import torch
+ from torch.optim import AdamW
+ from torch.utils.data import DataLoader
+ from transformers import (
+     GPT2Config,
+     AutoModelForCausalLM,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizerFast,
+     DataCollatorWithPadding,
+     GenerationConfig,
+ )
+ from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from datasets import Dataset, load_dataset
+ import os
+ import torch.nn as nn
  import gradio as gr
  import pandas as pd
+ import time
+ import re
+ from sentence_transformers import SentenceTransformer
+ import faiss
+
+ # --- Configuration & Global State ---
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ SFT_MODEL_PATH = "./sft_model"
+ RAG_INDEX_PATH = "faiss_index.bin"
+ TOKENIZER_FILE_PATH = "savant_tokenizer.json"
+ CACHE_DIR = "./hf_cache"
+
+ SAVANT_MODEL = None
+ SAVANT_TOKENIZER = None
+ RAG_DATABASE = None
+ full_dataset_for_rag = None
+
+ # --- RAG Database Class ---
+ class VectorDatabase:
+     def __init__(self, embedder_model_name='all-MiniLM-L6-v2'):
+         self.embedder = SentenceTransformer(embedder_model_name, device=str(DEVICE), cache_folder=CACHE_DIR)
+         self.index = None
+         self.documents = []
+
+     def build_index(self, texts):
+         print("Building RAG vector index...")
+         self.documents = texts
+         embeddings = self.embedder.encode(texts, convert_to_tensor=True, show_progress_bar=True)
+         self.index = faiss.IndexFlatL2(embeddings.shape[1])
+         self.index.add(embeddings.cpu().numpy())
+         print(f"RAG Index built with {len(self.documents)} documents.")
+
+     def save_index(self, path):
+         if self.index:
+             faiss.write_index(self.index, path)
+             print(f"RAG Index saved to {path}")
+
+     def search(self, query, k=3):
+         if self.index is None:
+             return []
+         query_embedding = self.embedder.encode([query], convert_to_tensor=True)
+         distances, indices = self.index.search(query_embedding.cpu().numpy(), k)
+         return [self.documents[i] for i in indices[0]]
+
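VectorDatabase is a thin wrapper over a flat L2 FAISS index. A minimal usage sketch (the file name demo_index.bin is illustrative, and the reload step uses faiss.read_index, which this app never calls because the index is rebuilt on every training run):

db = VectorDatabase()
db.build_index([
    "Question: 2 + 2?\nAnswer: 4",
    "Question: 3 * 3?\nAnswer: 9",
])
print(db.search("what is two plus two", k=1))  # nearest stored Q/A string

db.save_index("demo_index.bin")
# Note: only the vectors are persisted; self.documents is not, so a reloaded
# index can return row positions but not the original texts unless they are
# saved separately.
restored = VectorDatabase()
restored.index = faiss.read_index("demo_index.bin")
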
+ # --- Core Logic ---
+ def create_tokenizer_file_from_dataset(dataset, save_path=TOKENIZER_FILE_PATH):
+     corpus_path = "temp_corpus.txt"
+     with open(corpus_path, "w", encoding="utf-8") as f:
+         for item in dataset:
+             if item and item.get('question') and item.get('answer'):
+                 f.write(str(item['question']) + " " + str(item['answer']) + "\n")
+
+     raw_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+     raw_tokenizer.pre_tokenizer = Whitespace()
+     special_tokens = ["[UNK]", "[PAD]", "<|startoftext|>", "<|endoftext|>"]
+     trainer = BpeTrainer(vocab_size=8192, special_tokens=special_tokens)
+     raw_tokenizer.train(files=[corpus_path], trainer=trainer)
+     os.remove(corpus_path)
+     raw_tokenizer.save(save_path)
+     return save_path
+
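The JSON file this returns can be wrapped in a PreTrainedTokenizerFast, which is exactly what run_sft_phase does below. A small round-trip sketch under that assumption (the sample sentence is arbitrary):

tok = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_FILE_PATH)
tok.pad_token = "[PAD]"
ids = tok("Natalia sold 48 clips in April.")["input_ids"]
print(ids)              # BPE ids drawn from the 8192-token vocabulary
print(tok.decode(ids))  # decodes back to whitespace-normalized text
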
+ def create_seed_model(config):
+     return AutoModelForCausalLM.from_config(config)
+
+ def get_sft_curriculum():
+     return [{"name": "Phase 1: Foundational Math (SFT)", "dataset_name": "openai/gsm8k", "dataset_config": "main", "data_slice": slice(0, 100), "epochs": 1, "learning_rate": 5e-5, "mastery_threshold": 16.0}]
+
+ def get_rl_curriculum():
+     return {"name": "Phase 2: Problem Solving (RL)", "dataset_name": "openai/gsm8k", "dataset_config": "main", "data_slice": slice(1000, 1100), "num_steps": 100}
+
+ def get_folder_files(folder_path):
+     if not os.path.isdir(folder_path): return []
+     return [os.path.join(folder_path, f) for f in os.listdir(folder_path)]
+
+ def extract_answer(text):
+     text = str(text)
+     match = re.search(r'\\boxed\{([^}]*)\}', text)
+     if match:
+         ans = match.group(1).strip().replace(",", "")
+         try: return float(ans)
+         except ValueError: return None
+     matches = re.findall(r'(\d+\.?\d*|\.\d+)', text)
+     if matches:
+         try: return float(matches[-1])
+         except ValueError: return None
+     return None
+
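extract_answer prefers a LaTeX \boxed{...} value and otherwise falls back to the last number in the string; GSM8K reference answers end in "#### <number>", so on ground truth it is the fallback branch that usually fires. Illustrative checks, assuming the helper above is in scope:

assert extract_answer(r"so the total is \boxed{72}") == 72.0                   # boxed value wins
assert extract_answer("She sells 16 - 3 - 4 = 9 duck eggs. #### 18") == 18.0  # last number
assert extract_answer("no numbers at all") is None
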
+ # --- Master Training Process ---
+ def run_sft_phase(artifact_files):
+     global full_dataset_for_rag
+     log_text = "--- Starting Phase 1: Supervised Fine-Tuning (SFT) ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     stage = get_sft_curriculum()[0]
+     full_dataset_for_rag = load_dataset(stage['dataset_name'], name=stage['dataset_config'], split='train', cache_dir=CACHE_DIR)
+     tokenizer_file = create_tokenizer_file_from_dataset(full_dataset_for_rag)
+     artifact_files.append(tokenizer_file)
+
+     wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+     wrapped_tokenizer.pad_token = "[PAD]"
+
+     log_text += f"✓ Tokenizer created.\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     config = GPT2Config(vocab_size=wrapped_tokenizer.vocab_size, n_positions=512, n_layer=4, n_head=4, n_embd=256, pad_token_id=wrapped_tokenizer.pad_token_id, eos_token_id=wrapped_tokenizer.eos_token_id)
+
+     model = create_seed_model(config).to(DEVICE)
+     log_text += f"✓ Seed Model (GPT-2 Style) Initialized.\n\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     data_collator = DataCollatorForLanguageModeling(tokenizer=wrapped_tokenizer, mlm=False)
+     stage_ds = full_dataset_for_rag.select(range(stage['data_slice'].start, stage['data_slice'].stop))
+
+     tokenized_dataset = stage_ds.map(lambda ex: wrapped_tokenizer([q + " " + a for q, a in zip(ex['question'], ex['answer'])], truncation=True, padding="max_length", max_length=512), batched=True, remove_columns=stage_ds.column_names)
+
+     dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
+     optimizer = AdamW(model.parameters(), lr=stage['learning_rate'])
+
+     loss_history = []
+     avg_epoch_loss = float('inf')
+
+     for epoch in range(stage['epochs']):
+         base_log_for_epoch = log_text + f" Starting SFT Epoch {epoch+1}/{stage['epochs']}...\n"
+         yield base_log_for_epoch, None, gr.Group(visible=False), artifact_files
+         for batch_idx, batch in enumerate(dataloader):
+             batch = {k: v.to(DEVICE) for k, v in batch.items()}
+             optimizer.zero_grad()
+             outputs = model(**batch)
+             loss = outputs.loss
+             loss.backward()
+             optimizer.step()
+             if (batch_idx + 1) % 20 == 0:
+                 yield base_log_for_epoch + f" - Batch {batch_idx+1}/{len(dataloader)}\n", None, gr.Group(visible=False), artifact_files
+         avg_epoch_loss = loss.item()
+         loss_history.append({"Phase": "SFT", "Epoch": epoch, "Loss": avg_epoch_loss})
+         loss_df = pd.DataFrame(loss_history)
+         log_text += f" Epoch {epoch+1}/{stage['epochs']} complete. Loss: {avg_epoch_loss:.4f}\n"
+         yield log_text, gr.LinePlot(loss_df, x="Epoch", y="Loss", color="Phase"), gr.Group(visible=False), artifact_files
+
+     log_text += f"✓ SFT Phase Complete. Final Loss: {avg_epoch_loss:.4f}\n"
+     if avg_epoch_loss < stage['mastery_threshold']:
+         log_text += f"✓ SFT Mastery Gate PASSED! Saving model...\n\n"
+         model.save_pretrained(SFT_MODEL_PATH)
+         wrapped_tokenizer.save_pretrained(SFT_MODEL_PATH)
+         artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
+         artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return model, wrapped_tokenizer, log_text, artifact_files
      else:
+         log_text += f"✗ SFT Mastery Gate FAILED. Stopping.\n"
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return None, None, log_text, artifact_files
+
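A note on the collator used above: with mlm=False, DataCollatorForLanguageModeling pads each batch and clones input_ids into labels, setting pad positions to -100 so the cross-entropy loss ignores them (the model applies the one-token shift internally). A small illustration, assuming the wrapped_tokenizer and data_collator built in the SFT phase:

features = [wrapped_tokenizer(t) for t in ["2 + 2 = 4", "a noticeably longer example sentence"]]
batch = data_collator(features)
print(batch["input_ids"].shape)  # both rows padded to the longer length
print(batch["labels"][0])        # same ids as input_ids, pads replaced by -100
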
+ def run_rl_phase(sft_model, tokenizer, initial_log_text, rl_dataset_slice, artifact_files):
+     global SAVANT_MODEL, SAVANT_TOKENIZER
+     log_text = initial_log_text
+     rl_curriculum = get_rl_curriculum()
+
+     log_text += f"--- Starting Phase 2: Reinforcement Learning (RL) ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     def tokenize_query(examples):
+         return tokenizer(examples["question"], truncation=True, max_length=512)
+
+     rl_dataset = rl_dataset_slice.map(tokenize_query, batched=True)
+
+     ppo_config = PPOConfig(
+         learning_rate=1.41e-5,
+         batch_size=4,
+         mini_batch_size=4,
+         ppo_epochs=4,
+     )
+
+     model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
+
+     data_collator = DataCollatorWithPadding(tokenizer)
+     ppo_trainer = PPOTrainer(
+         config=ppo_config,
+         model=model_with_value_head,
+         ref_model=None,
+         tokenizer=tokenizer,
+         dataset=rl_dataset,
+         data_collator=data_collator,
+     )
+
+     # FIXED: Create our own DataLoader to ensure the 'answer' column is preserved in the batch.
+     # The PPOTrainer's internal dataloader drops all columns except for model inputs.
+     dataloader = DataLoader(
+         rl_dataset,
+         batch_size=ppo_config.batch_size,
+         collate_fn=data_collator
+     )
+
+     generation_kwargs = {"max_new_tokens": 100, "pad_token_id": tokenizer.pad_token_id, "do_sample": True, "temperature": 0.7}
+     log_text += "✓ PPOTrainer instantiated. Starting PPO training loop...\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     # FIXED: Iterate over our custom dataloader, not ppo_trainer.dataloader
+     for step, batch in enumerate(dataloader):
+         if step >= rl_curriculum["num_steps"]:
+             break
+
+         query_tensors = batch["input_ids"].to(DEVICE)
+         # This will now work correctly because our dataloader preserves the 'answer' column.
+         ground_truth_answers = batch["answer"]
+
+         response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
+         response_texts = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

+         rewards = []
+         for resp_text, gt_answer in zip(response_texts, ground_truth_answers):
+             pred = extract_answer(resp_text)
+             true = extract_answer(gt_answer)
+             reward_value = 1.0 if pred is not None and true is not None and abs(pred - true) < 1e-2 else 0.0
+             rewards.append(torch.tensor(reward_value, device=DEVICE))

+         stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
+
+         if step % 10 == 0:
+             mean_reward = stats.get("ppo/returns/mean", torch.tensor(0.0)).item()
+             log_text += f"Step {step}/{rl_curriculum['num_steps']}, Mean Reward: {mean_reward:.2f}\n"
+             yield log_text, None, gr.Group(visible=False), artifact_files

+     log_text += "\n✓ Reinforcement Learning Phase Complete! Saving final policy model...\n"
+
+     ppo_trainer.save_model(SFT_MODEL_PATH)
+
+     SAVANT_MODEL = ppo_trainer.model
+     SAVANT_TOKENIZER = tokenizer
+
+     artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
+     artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
+
+     log_text += f"✓ Final RL-tuned model saved to {SFT_MODEL_PATH}\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     return log_text, artifact_files
+
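One caveat on the PPO loop above: in the trl releases this code appears to target, PPOTrainer.step expects Python lists of 1-D tensors (one entry per sample) rather than a single padded 2-D batch. If step rejects the batched tensors, a plausible adaptation (hypothetical, not part of this commit) is:

# Hypothetical fix: split the (batch, seq_len) tensors into per-sample
# 1-D tensors; padding could additionally be stripped from each query
# using the attention mask before the split.
query_list = [q for q in query_tensors]
response_list = [r for r in response_tensors]
stats = ppo_trainer.step(query_list, response_list, rewards)
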
+ def run_full_training():
+     global RAG_DATABASE, SAVANT_MODEL, SAVANT_TOKENIZER, full_dataset_for_rag
+     artifact_files = []
+
+     sft_model, tokenizer, log_text, artifact_files = yield from run_sft_phase(artifact_files)
+
+     if sft_model is None:
+         yield log_text, None, gr.Group(visible=False), artifact_files
+         return
+
+     log_text += "--- Building RAG Knowledge Base ---\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     rag_texts = [f"Question: {q}\nAnswer: {a}" for q, a in zip(full_dataset_for_rag['question'], full_dataset_for_rag['answer'])]
+     RAG_DATABASE = VectorDatabase()
+     RAG_DATABASE.build_index(rag_texts)
+     RAG_DATABASE.save_index(RAG_INDEX_PATH)
+     artifact_files.append(RAG_INDEX_PATH)
+
+     log_text += "✓ RAG Knowledge Base Built.\n\n"
+     yield log_text, None, gr.Group(visible=False), artifact_files
+
+     rl_curriculum = get_rl_curriculum()
+     rl_dataset_slice = full_dataset_for_rag.select(range(*rl_curriculum['data_slice'].indices(len(full_dataset_for_rag))))
+
+     final_log_text, final_artifact_files = yield from run_rl_phase(sft_model, tokenizer, log_text, rl_dataset_slice, artifact_files)
+
+     final_log_text += "\n--- Curriculum Completed: RAG-Powered Savant Ready! ---\n"
+
+     SAVANT_MODEL = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
+     SAVANT_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(SFT_MODEL_PATH)
+
+     yield final_log_text, None, gr.Group(visible=True), final_artifact_files
+
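run_full_training chains the two phase generators with `yield from`, which both relays their streamed UI updates and captures the values they return. A minimal, self-contained illustration of that relay pattern:

def inner():
    yield "working..."          # streamed to the caller
    return "result"             # captured by `yield from`

def outer():
    value = yield from inner()  # relays "working...", then binds "result"
    yield f"done: {value}"

print(list(outer()))            # ['working...', 'done: result']
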
+ def run_savant_inference(user_question):
+     if SAVANT_MODEL is None or SAVANT_TOKENIZER is None or RAG_DATABASE is None:
+         return "Model or RAG Database not ready. Please run the training first."
+
+     retrieved_context = RAG_DATABASE.search(user_question, k=3)
+     context_str = "\n".join(retrieved_context)
+     augmented_prompt = f"Context:\n{context_str}\n\nQuestion:\n{user_question}\n\nAnswer:"
+
+     inputs = SAVANT_TOKENIZER(augmented_prompt, return_tensors="pt").to(DEVICE)
+
+     with torch.no_grad():
+         response_ids = SAVANT_MODEL.generate(**inputs, max_new_tokens=150, pad_token_id=SAVANT_TOKENIZER.pad_token_id, do_sample=True, temperature=0.7, top_p=0.9)
+
+     response_text = SAVANT_TOKENIZER.decode(response_ids[0], skip_special_tokens=True)
+
+     if "Answer:" in response_text:
+         final_answer = response_text.split("Answer:")[-1].strip()
+     else:
+         final_answer = response_text
+
+     return f"**Savant's Answer:**\n{final_answer}\n\n**Context Used by RAG:**\n- {retrieved_context[0]}\n- {retrieved_context[1]}\n- {retrieved_context[2]}"
+
+ # --- GRADIO UI DEFINITION ---
+ with gr.Blocks(theme=gr.themes.Soft(), title="Savant-Garde Dashboard") as demo:
+     gr.Markdown("# The Savant-Garde: RAG + SFT + RL Savant Factory")
+
      with gr.Row():
          with gr.Column(scale=1):
+             start_button = gr.Button("🚀 Begin Full Training", variant="primary")
+             log_output = gr.Textbox(label="Live Training Log", interactive=False, lines=25, max_lines=25)
+
+         with gr.Column(scale=1):
+             gr.Markdown("### Dashboard Visualizations")
+             loss_plot = gr.LinePlot(show_label=False)
+             with gr.Accordion("Downloadable Artifacts", open=True):
+                 file_output = gr.File(label="Generated Files", file_count="multiple", interactive=False)
+
+     with gr.Group(visible=False) as interaction_group:
+         gr.Markdown("---")
+         gr.Markdown("### 🧠 Interact with the RAG-Powered Savant")
+         with gr.Row():
+             question_input = gr.Textbox(label="Ask a Math Question", placeholder="e.g., Janet has 2 apples...")
+             ask_button = gr.Button("Ask Savant")
+         savant_answer = gr.Markdown()

+     start_button.click(fn=run_full_training, inputs=None, outputs=[log_output, loss_plot, interaction_group, file_output])
+     ask_button.click(fn=run_savant_inference, inputs=question_input, outputs=savant_answer)

  if __name__ == "__main__":
+     demo.launch(debug=True)
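Because run_full_training is a generator, Gradio streams each yielded (log, plot, group, files) tuple into the four outputs wired on start_button, in order, updating the page after every yield. A minimal sketch of the same streaming pattern, independent of this app:

import time
import gradio as gr

def slow_count():
    # Each yield pushes an intermediate value to the Textbox below.
    for i in range(3):
        time.sleep(1)
        yield f"step {i + 1}/3"

with gr.Blocks() as mini:
    btn = gr.Button("Run")
    box = gr.Textbox(label="Progress")
    btn.click(fn=slow_count, inputs=None, outputs=box)

# mini.launch()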