File size: 15,324 Bytes
58b8f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb95205
fe5ff1b
58b8f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8c307f
58b8f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98fe021
58b8f23
 
 
 
 
 
701c41c
58b8f23
 
 
 
 
 
fe5ff1b
58b8f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb95205
 
58b8f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701c41c
58b8f23
 
c5a0831
bb95205
58b8f23
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    GPT2Config,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    GenerationConfig,
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from datasets import Dataset, load_dataset
import os
import torch.nn as nn
import gradio as gr
import pandas as pd
import time
import re
from sentence_transformers import SentenceTransformer
import faiss

# --- Configuration & Global State ---
# Filesystem locations for generated artifacts and the Hugging Face cache.
CACHE_DIR = "./hf_cache"
TOKENIZER_FILE_PATH = "savant_tokenizer.json"
SFT_MODEL_PATH = "./sft_model"
RAG_INDEX_PATH = "faiss_index.bin"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Runtime state: filled in by the training pipeline, read by inference.
SAVANT_MODEL = SAVANT_TOKENIZER = RAG_DATABASE = None
full_dataset_for_rag = None

# --- RAG Database Class ---
class VectorDatabase:
    """Sentence-embedding store backed by a flat L2 FAISS index.

    Documents are embedded with a SentenceTransformer model and stored in
    insertion order, so a FAISS row id is also the position of the matching
    text in ``self.documents``.
    """

    def __init__(self, embedder_model_name='all-MiniLM-L6-v2'):
        """Load the embedding model; the index stays empty until ``build_index``."""
        self.embedder = SentenceTransformer(embedder_model_name, device=str(DEVICE), cache_folder=CACHE_DIR)
        # FAISS index; None until build_index() has been called.
        self.index = None
        # Raw texts, positionally aligned with the rows of the index.
        self.documents = []

    def build_index(self, texts):
        """Embed ``texts`` and (re)build the index from scratch.

        Args:
            texts: Iterable of strings to index.
        """
        print("Building RAG vector index...")
        # Materialise once so that a generator argument is not consumed by
        # encode() and then left empty for later lookups.
        self.documents = list(texts)
        embeddings = self.embedder.encode(self.documents, convert_to_tensor=True, show_progress_bar=True)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings.cpu().numpy())
        print(f"RAG Index built with {len(self.documents)} documents.")

    def save_index(self, path):
        """Write the index to ``path``; does nothing if no index was built."""
        # Compare against None explicitly rather than relying on the
        # truthiness of the index object.
        if self.index is not None:
            faiss.write_index(self.index, path)
            print(f"RAG Index saved to {path}")

    def search(self, query, k=3):
        """Return up to ``k`` stored documents closest to ``query``.

        Args:
            query: Text to embed and look up.
            k: Maximum number of documents to return.

        Returns:
            A list of document strings, nearest first. The list is empty when
            no index has been built or ``k`` is not positive, and shorter
            than ``k`` when the index holds fewer than ``k`` documents.
        """
        if self.index is None or k <= 0:
            return []
        query_embedding = self.embedder.encode([query], convert_to_tensor=True)
        _distances, indices = self.index.search(query_embedding.cpu().numpy(), k)
        # FAISS pads missing neighbours with the label -1; without this
        # filter a -1 would silently index the *last* document.
        total = len(self.documents)
        return [self.documents[i] for i in indices[0] if 0 <= i < total]

# --- Core Logic ---
def create_tokenizer_file_from_dataset(dataset, save_path=TOKENIZER_FILE_PATH):
    """Train a BPE tokenizer on the dataset's question/answer text and save it.

    Each item with a truthy ``question`` and ``answer`` contributes one line
    ("<question> <answer>") to a temporary corpus file, which is the training
    input for the tokenizer.

    Args:
        dataset: Iterable of mapping-like items exposing ``question`` and
            ``answer`` via ``.get`` and item access.
        save_path: Destination of the serialized tokenizer JSON.

    Returns:
        ``save_path``, for convenience.
    """
    corpus_path = "temp_corpus.txt"
    try:
        with open(corpus_path, "w", encoding="utf-8") as f:
            for item in dataset:
                if item and item.get('question') and item.get('answer'):
                    f.write(f"{item['question']} {item['answer']}\n")

        raw_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        raw_tokenizer.pre_tokenizer = Whitespace()
        special_tokens = ["[UNK]", "[PAD]", "<|startoftext|>", "<|endoftext|>"]
        trainer = BpeTrainer(vocab_size=8192, special_tokens=special_tokens)
        raw_tokenizer.train(files=[corpus_path], trainer=trainer)
    finally:
        # Remove the scratch corpus even when writing or training fails, so a
        # failed run does not leave a stale file behind.
        if os.path.exists(corpus_path):
            os.remove(corpus_path)
    raw_tokenizer.save(save_path)
    return save_path

def create_seed_model(config):
    """Build a causal language model from ``config`` alone.

    The model is created through ``AutoModelForCausalLM.from_config``; no
    checkpoint path is involved.
    """
    seed_model = AutoModelForCausalLM.from_config(config)
    return seed_model

def get_sft_curriculum():
    """Return the list of SFT stages (currently a single stage).

    Each stage is a dict naming the dataset, the slice of it to train on,
    the epoch count, the learning rate and the loss threshold for the gate.
    """
    foundational_math = {
        "name": "Phase 1: Foundational Math (SFT)",
        "dataset_name": "openai/gsm8k",
        "dataset_config": "main",
        "data_slice": slice(0, 100),
        "epochs": 1,
        "learning_rate": 5e-5,
        "mastery_threshold": 16.0,
    }
    return [foundational_math]

def get_rl_curriculum():
    """Return the settings dict for the single RL stage.

    The dict names the dataset, the slice of it used for RL prompts and the
    maximum number of PPO steps.
    """
    problem_solving = {
        "name": "Phase 2: Problem Solving (RL)",
        "dataset_name": "openai/gsm8k",
        "dataset_config": "main",
        "data_slice": slice(1000, 1100),
        "num_steps": 100,
    }
    return problem_solving

def get_folder_files(folder_path):
    """List the entries directly inside ``folder_path``, joined with the folder.

    Returns an empty list when ``folder_path`` is not an existing directory.
    The listing is not recursive and keeps the order given by ``os.listdir``.
    """
    if os.path.isdir(folder_path):
        return [os.path.join(folder_path, entry) for entry in os.listdir(folder_path)]
    return []

def extract_answer(text):
    """Extract the final numeric answer from free-form text.

    Resolution order:
      1. If the text contains ``\\boxed{...}``, the boxed content (with commas
         removed) is parsed as a float; unparseable content yields ``None``.
      2. Otherwise the last number appearing in the text is returned.

    Numbers may carry a leading minus sign and comma thousands separators
    ("-1,234.5"). A minus directly after a word character, ".", ")" or "]"
    is treated as a binary operator, so "5-3" yields 3.0, not -3.0.

    Args:
        text: Any object; it is converted with ``str`` first.

    Returns:
        The answer as a float, or ``None`` when no number is found.
    """
    text = str(text)
    boxed = re.search(r'\\boxed\{([^}]*)\}', text)
    if boxed:
        candidate = boxed.group(1).strip().replace(",", "")
        try:
            return float(candidate)
        except ValueError:
            return None

    number_pattern = (
        r'(?:(?<![\w.)\]])-)?'                    # optional unary minus
        r'(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d*)?'  # 1,234 or 1,234.56
        r'|\d+\.?\d*'                             # 42, 42. or 3.14
        r'|\.\d+)'                                # .5
    )
    matches = re.findall(number_pattern, text)
    if not matches:
        return None
    try:
        # Strip thousands separators only from the matched number itself, so
        # comma-separated lists such as "1, 2, 3" are not merged.
        return float(matches[-1].replace(",", ""))
    except ValueError:
        return None

# --- Master Training Process ---
def run_sft_phase(artifact_files):
    """Run supervised fine-tuning as a generator of UI updates.

    Loads the stage's dataset, trains a tokenizer on it, builds a small
    GPT-2 style model and trains it on "question answer" strings. The model
    is saved to ``SFT_MODEL_PATH`` only when the final epoch's mean loss is
    below the stage's ``mastery_threshold``.

    Args:
        artifact_files: List of generated file paths; it is appended to and
            passed along with every update.

    Yields:
        ``(log_text, plot_or_None, gr.Group(visible=False), artifact_files)``
        tuples.

    Returns:
        Via ``yield from``: ``(model, tokenizer, log_text, artifact_files)``
        when the gate passes, ``(None, None, log_text, artifact_files)``
        otherwise.
    """
    global full_dataset_for_rag
    log_text = "--- Starting Phase 1: Supervised Fine-Tuning (SFT) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    stage = get_sft_curriculum()[0]
    full_dataset_for_rag = load_dataset(stage['dataset_name'], name=stage['dataset_config'], split='train', cache_dir=CACHE_DIR)
    tokenizer_file = create_tokenizer_file_from_dataset(full_dataset_for_rag)
    artifact_files.append(tokenizer_file)

    wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    wrapped_tokenizer.pad_token = "[PAD]"

    log_text += "✓ Tokenizer created.\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    config = GPT2Config(vocab_size=wrapped_tokenizer.vocab_size, n_positions=512, n_layer=4, n_head=4, n_embd=256, pad_token_id=wrapped_tokenizer.pad_token_id, eos_token_id=wrapped_tokenizer.eos_token_id)

    model = create_seed_model(config).to(DEVICE)
    log_text += "✓ Seed Model (GPT-2 Style) Initialized.\n\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    data_collator = DataCollatorForLanguageModeling(tokenizer=wrapped_tokenizer, mlm=False)
    # Clamp the configured slice to the dataset length, as the RL stage does.
    stage_indices = range(*stage['data_slice'].indices(len(full_dataset_for_rag)))
    stage_ds = full_dataset_for_rag.select(stage_indices)

    def tokenize_pairs(examples):
        # One training string per example: "<question> <answer>".
        joined = [q + " " + a for q, a in zip(examples['question'], examples['answer'])]
        return wrapped_tokenizer(joined, truncation=True, padding="max_length", max_length=512)

    tokenized_dataset = stage_ds.map(tokenize_pairs, batched=True, remove_columns=stage_ds.column_names)

    dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
    optimizer = AdamW(model.parameters(), lr=stage['learning_rate'])

    loss_history = []
    # Stays infinite if no epoch completes, which makes the gate below fail.
    avg_epoch_loss = float('inf')

    model.train()
    for epoch in range(stage['epochs']):
        base_log_for_epoch = log_text + f"  Starting SFT Epoch {epoch+1}/{stage['epochs']}...\n"
        yield base_log_for_epoch, None, gr.Group(visible=False), artifact_files
        running_loss = 0.0
        batches_seen = 0
        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batches_seen += 1
            if (batch_idx + 1) % 20 == 0:
                yield base_log_for_epoch + f"    - Batch {batch_idx+1}/{len(dataloader)}\n", None, gr.Group(visible=False), artifact_files
        if batches_seen == 0:
            # Nothing to train on: there is no loss to report.
            break
        # Report the mean over the epoch, not just the last batch's loss.
        avg_epoch_loss = running_loss / batches_seen
        loss_history.append({"Phase": "SFT", "Epoch": epoch, "Loss": avg_epoch_loss})
        loss_df = pd.DataFrame(loss_history)
        log_text += f"  Epoch {epoch+1}/{stage['epochs']} complete. Loss: {avg_epoch_loss:.4f}\n"
        yield log_text, gr.LinePlot(loss_df, x="Epoch", y="Loss", color="Phase"), gr.Group(visible=False), artifact_files

    log_text += f"✓ SFT Phase Complete. Final Loss: {avg_epoch_loss:.4f}\n"
    if avg_epoch_loss < stage['mastery_threshold']:
        log_text += "✓ SFT Mastery Gate PASSED! Saving model...\n\n"
        model.save_pretrained(SFT_MODEL_PATH)
        wrapped_tokenizer.save_pretrained(SFT_MODEL_PATH)
        # Drop stale entries for the model folder before re-listing it.
        artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
        artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
        yield log_text, None, gr.Group(visible=False), artifact_files
        return model, wrapped_tokenizer, log_text, artifact_files
    else:
        log_text += "✗ SFT Mastery Gate FAILED. Stopping.\n"
        yield log_text, None, gr.Group(visible=False), artifact_files
        return None, None, log_text, artifact_files

def run_rl_phase(sft_model, tokenizer, initial_log_text, rl_dataset_slice, artifact_files):
    """Run PPO fine-tuning as a generator of UI updates.

    The policy is reloaded from ``SFT_MODEL_PATH`` with a value head. Each
    step generates answers for a batch of questions and rewards a response
    with 1.0 when its extracted number matches the reference answer (within
    1e-2), 0.0 otherwise.

    Args:
        sft_model: Model from the SFT phase. It is read only for its context
            window size; the policy weights come from ``SFT_MODEL_PATH``.
        tokenizer: Tokenizer matching the SFT model.
        initial_log_text: Log accumulated so far; new lines are appended.
        rl_dataset_slice: Dataset with ``question`` and ``answer`` columns.
        artifact_files: List of generated file paths.

    Yields:
        ``(log_text, None, gr.Group(visible=False), artifact_files)`` tuples.

    Returns:
        Via ``yield from``: ``(log_text, artifact_files)``.
    """
    global SAVANT_MODEL, SAVANT_TOKENIZER
    log_text = initial_log_text
    rl_curriculum = get_rl_curriculum()

    log_text += "--- Starting Phase 2: Reinforcement Learning (RL) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    max_new_tokens = 100
    # Query plus generated tokens must fit in the model's position table,
    # so reserve room for the response when truncating the prompt.
    max_positions = getattr(getattr(sft_model, "config", None), "n_positions", 512)
    max_query_tokens = max(1, max_positions - max_new_tokens)

    def tokenize_query(examples):
        return tokenizer(examples["question"], truncation=True, max_length=max_query_tokens)

    rl_dataset = rl_dataset_slice.map(tokenize_query, batched=True)

    ppo_config = PPOConfig(
        learning_rate=1.41e-5,
        batch_size=4,
        mini_batch_size=4,
        ppo_epochs=4,
    )

    model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_MODEL_PATH).to(DEVICE)

    data_collator = DataCollatorWithPadding(tokenizer)
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        model=model_with_value_head,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=rl_dataset,
        data_collator=data_collator,
    )

    def collate_queries(features):
        # Keep each query as its own unpadded 1-D tensor and pass the
        # reference answers through as plain strings. A padding collator
        # cannot be used here: it would try to tensorize the text columns.
        return {
            "input_ids": [torch.tensor(f["input_ids"], dtype=torch.long) for f in features],
            "answer": [f["answer"] for f in features],
        }

    # Own DataLoader so the 'answer' column reaches the reward computation.
    # drop_last: PPOTrainer.step expects exactly config.batch_size samples.
    dataloader = DataLoader(
        rl_dataset,
        batch_size=ppo_config.batch_size,
        collate_fn=collate_queries,
        drop_last=True,
    )

    generation_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": tokenizer.pad_token_id, "do_sample": True, "temperature": 0.7}
    log_text += "✓ PPOTrainer instantiated. Starting PPO training loop...\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    for step, batch in enumerate(dataloader):
        if step >= rl_curriculum["num_steps"]:
            break

        # generate() and step() take lists of 1-D tensors, one per sample.
        query_tensors = [ids.to(DEVICE) for ids in batch["input_ids"]]
        ground_truth_answers = batch["answer"]

        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
        response_texts = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors]

        rewards = []
        for resp_text, gt_answer in zip(response_texts, ground_truth_answers):
            pred = extract_answer(resp_text)
            true = extract_answer(gt_answer)
            reward_value = 1.0 if pred is not None and true is not None and abs(pred - true) < 1e-2 else 0.0
            rewards.append(torch.tensor(reward_value, device=DEVICE))

        ppo_trainer.step(query_tensors, response_tensors, rewards)

        if step % 10 == 0:
            # Log the mean of the rewards actually assigned in this batch.
            mean_reward = sum(r.item() for r in rewards) / len(rewards)
            log_text += f"Step {step}/{rl_curriculum['num_steps']}, Mean Reward: {mean_reward:.2f}\n"
            yield log_text, None, gr.Group(visible=False), artifact_files

    log_text += "\n✓ Reinforcement Learning Phase Complete! Saving final policy model...\n"

    # The step()/generate()-style PPOTrainer saves through save_pretrained.
    ppo_trainer.save_pretrained(SFT_MODEL_PATH)

    SAVANT_MODEL = ppo_trainer.model
    SAVANT_TOKENIZER = tokenizer

    # Drop stale entries for the model folder before re-listing it.
    artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
    artifact_files.extend(get_folder_files(SFT_MODEL_PATH))

    log_text += f"✓ Final RL-tuned model saved to {SFT_MODEL_PATH}\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    return log_text, artifact_files

def run_full_training():
    """Run the whole pipeline (SFT, RAG index build, RL) as one generator.

    Every yielded tuple is ``(log_text, plot_or_None, interaction_group,
    artifact_files)``, matching the outputs wired to the start button. The
    interaction group is made visible only in the final update. If the SFT
    gate fails, the pipeline stops after reporting the SFT log.
    """
    global RAG_DATABASE, SAVANT_MODEL, SAVANT_TOKENIZER, full_dataset_for_rag
    artifact_files = []

    sft_model, tokenizer, log_text, artifact_files = yield from run_sft_phase(artifact_files)

    if sft_model is None:
        yield log_text, None, gr.Group(visible=False), artifact_files
        return

    log_text += "--- Building RAG Knowledge Base ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    # Index every question/answer pair of the dataset loaded during SFT.
    rag_texts = [f"Question: {q}\nAnswer: {a}" for q, a in zip(full_dataset_for_rag['question'], full_dataset_for_rag['answer'])]
    RAG_DATABASE = VectorDatabase()
    RAG_DATABASE.build_index(rag_texts)
    RAG_DATABASE.save_index(RAG_INDEX_PATH)
    artifact_files.append(RAG_INDEX_PATH)

    log_text += "✓ RAG Knowledge Base Built.\n\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    rl_curriculum = get_rl_curriculum()
    # slice.indices() clamps the configured slice to the dataset length.
    rl_dataset_slice = full_dataset_for_rag.select(range(*rl_curriculum['data_slice'].indices(len(full_dataset_for_rag))))

    final_log_text, final_artifact_files = yield from run_rl_phase(sft_model, tokenizer, log_text, rl_dataset_slice, artifact_files)

    final_log_text += "\n--- Curriculum Completed: RAG-Powered Savant Ready! ---\n"

    # Reload the saved policy as a plain causal LM for inference.
    SAVANT_MODEL = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
    SAVANT_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(SFT_MODEL_PATH)

    yield final_log_text, None, gr.Group(visible=True), final_artifact_files

def run_savant_inference(user_question):
    """Answer ``user_question`` with the trained model, using retrieved context.

    Up to three documents are retrieved from ``RAG_DATABASE`` and prepended
    to the question. Only the newly generated tokens are decoded as the
    answer.

    Args:
        user_question: Question text entered by the user.

    Returns:
        A Markdown string with the answer and the retrieved context, or a
        "not ready" message when training has not populated the globals.
    """
    if SAVANT_MODEL is None or SAVANT_TOKENIZER is None or RAG_DATABASE is None:
        return "Model or RAG Database not ready. Please run the training first."

    retrieved_context = RAG_DATABASE.search(user_question, k=3)
    context_str = "\n".join(retrieved_context)
    augmented_prompt = f"Context:\n{context_str}\n\nQuestion:\n{user_question}\n\nAnswer:"

    max_new_tokens = 150
    # Prompt plus generated tokens must fit in the model's position table.
    max_positions = getattr(SAVANT_MODEL.config, "n_positions", 512)
    max_prompt_tokens = max(1, max_positions - max_new_tokens)

    encoded = SAVANT_TOKENIZER(augmented_prompt, return_tensors="pt")
    # Keep the tail of an over-long prompt: the question and the trailing
    # "Answer:" cue matter more than the oldest context.
    inputs = {name: tensor[:, -max_prompt_tokens:].to(DEVICE) for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        response_ids = SAVANT_MODEL.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=SAVANT_TOKENIZER.pad_token_id, do_sample=True, temperature=0.7, top_p=0.9)

    # Decode only what was generated after the prompt. Splitting the decoded
    # text on "Answer:" is fragile, because the retrieved context contains
    # that marker too and decoding may not reproduce it verbatim.
    final_answer = SAVANT_TOKENIZER.decode(response_ids[0][prompt_length:], skip_special_tokens=True).strip()

    # One bullet per retrieved document; there may be fewer than three.
    context_bullets = "\n".join(f"- {doc}" for doc in retrieved_context)

    return f"**Savant's Answer:**\n{final_answer}\n\n**Context Used by RAG:**\n{context_bullets}"

# --- GRADIO UI DEFINITION ---
with gr.Blocks(theme=gr.themes.Soft(), title="Savant-Garde Dashboard") as demo:
    gr.Markdown("# The Savant-Garde: RAG + SFT + RL Savant Factory")

    with gr.Row():
        # Left column: training trigger and live log.
        with gr.Column(scale=1):
            start_button = gr.Button("🚀 Begin Full Training", variant="primary")
            log_output = gr.Textbox(label="Live Training Log", interactive=False, lines=25, max_lines=25)

        # Right column: loss plot and generated files.
        with gr.Column(scale=1):
            gr.Markdown("### Dashboard Visualizations")
            loss_plot = gr.LinePlot(show_label=False)
            with gr.Accordion("Downloadable Artifacts", open=True):
                file_output = gr.File(label="Generated Files", file_count="multiple", interactive=False)

    # Hidden at start; run_full_training makes it visible in its last update.
    with gr.Group(visible=False) as interaction_group:
        gr.Markdown("---")
        gr.Markdown("### 🧠 Interact with the RAG-Powered Savant")
        with gr.Row():
            question_input = gr.Textbox(label="Ask a Math Question", placeholder="e.g., Janet has 2 apples...")
            ask_button = gr.Button("Ask Savant")
        savant_answer = gr.Markdown()

    # The output order must match the tuples yielded by run_full_training.
    start_button.click(fn=run_full_training, inputs=None, outputs=[log_output, loss_plot, interaction_group, file_output])
    ask_button.click(fn=run_savant_inference, inputs=question_input, outputs=savant_answer)

if __name__ == "__main__":
    demo.launch(debug=True)