import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    GPT2Config,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from datasets import load_dataset
import os
import re
import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
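
# NOTE (environment assumption): this script targets the legacy trl PPOTrainer
# API (trl < 0.8, where PPOTrainer takes config/model/tokenizer/dataset
# directly) and Gradio 4.x, where yielding a component instance such as
# gr.Group(visible=True) acts as an update. Newer trl releases restructured
# the PPO interface and would require changes here.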
# --- Configuration & Global State ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SFT_MODEL_PATH = "./sft_model"
RAG_INDEX_PATH = "faiss_index.bin"
TOKENIZER_FILE_PATH = "savant_tokenizer.json"
CACHE_DIR = "./hf_cache"

SAVANT_MODEL = None
SAVANT_TOKENIZER = None
RAG_DATABASE = None
full_dataset_for_rag = None
# --- RAG Database Class ---
class VectorDatabase:
    def __init__(self, embedder_model_name='all-MiniLM-L6-v2'):
        self.embedder = SentenceTransformer(embedder_model_name, device=str(DEVICE), cache_folder=CACHE_DIR)
        self.index = None
        self.documents = []

    def build_index(self, texts):
        print("Building RAG vector index...")
        self.documents = texts
        embeddings = self.embedder.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings.cpu().numpy())
        print(f"RAG Index built with {len(self.documents)} documents.")

    def save_index(self, path):
        if self.index is not None:
            faiss.write_index(self.index, path)
            print(f"RAG Index saved to {path}")

    def search(self, query, k=3):
        if self.index is None:
            return []
        query_embedding = self.embedder.encode([query], convert_to_tensor=True)
        distances, indices = self.index.search(query_embedding.cpu().numpy(), k)
        return [self.documents[i] for i in indices[0]]
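
# A minimal (hypothetical) usage sketch of VectorDatabase:
#   db = VectorDatabase()
#   db.build_index(["Question: ... Answer: ...", "Question: ... Answer: ..."])
#   hits = db.search("How many apples does Janet have?", k=2)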

# --- Core Logic ---
def create_tokenizer_file_from_dataset(dataset, save_path=TOKENIZER_FILE_PATH):
    corpus_path = "temp_corpus.txt"
    with open(corpus_path, "w", encoding="utf-8") as f:
        for item in dataset:
            if item and item.get('question') and item.get('answer'):
                f.write(str(item['question']) + " " + str(item['answer']) + "\n")
    raw_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    raw_tokenizer.pre_tokenizer = Whitespace()
    special_tokens = ["[UNK]", "[PAD]", "<|startoftext|>", "<|endoftext|>"]
    trainer = BpeTrainer(vocab_size=8192, special_tokens=special_tokens)
    raw_tokenizer.train(files=[corpus_path], trainer=trainer)
    os.remove(corpus_path)
    raw_tokenizer.save(save_path)
    return save_path
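
# With BpeTrainer, the special tokens above receive the first vocabulary ids
# in the order given ([UNK]=0, [PAD]=1, ...), so the wrapped tokenizer can
# resolve them by string later.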

def create_seed_model(config):
    return AutoModelForCausalLM.from_config(config)

def get_sft_curriculum():
    return [{
        "name": "Phase 1: Foundational Math (SFT)",
        "dataset_name": "openai/gsm8k",
        "dataset_config": "main",
        "data_slice": slice(0, 100),
        "epochs": 1,
        "learning_rate": 5e-5,
        "mastery_threshold": 16.0,
    }]

def get_rl_curriculum():
    return {
        "name": "Phase 2: Problem Solving (RL)",
        "dataset_name": "openai/gsm8k",
        "dataset_config": "main",
        "data_slice": slice(1000, 1100),
        "num_steps": 100,
    }

def get_folder_files(folder_path):
    if not os.path.isdir(folder_path):
        return []
    return [os.path.join(folder_path, f) for f in os.listdir(folder_path)]

def extract_answer(text):
    # Prefer an explicit \boxed{...} answer; otherwise fall back to the last
    # number in the text (GSM8K references end with "#### <answer>").
    # Commas are stripped first so "1,000" parses as 1000.
    text = str(text).replace(",", "")
    match = re.search(r'\\boxed\{([^}]*)\}', text)
    if match:
        ans = match.group(1).strip()
        try:
            return float(ans)
        except ValueError:
            return None
    matches = re.findall(r'(\d+\.?\d*|\.\d+)', text)
    if matches:
        try:
            return float(matches[-1])
        except ValueError:
            return None
    return None
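
# Examples: extract_answer(r"so the total is \boxed{42}.") -> 42.0
#           extract_answer("... #### 1,000")               -> 1000.0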

# --- Master Training Process ---
def run_sft_phase(artifact_files):
    global full_dataset_for_rag
    log_text = "--- Starting Phase 1: Supervised Fine-Tuning (SFT) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    stage = get_sft_curriculum()[0]
    full_dataset_for_rag = load_dataset(stage['dataset_name'], name=stage['dataset_config'], split='train', cache_dir=CACHE_DIR)
    tokenizer_file = create_tokenizer_file_from_dataset(full_dataset_for_rag)
    artifact_files.append(tokenizer_file)
    wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    wrapped_tokenizer.pad_token = "[PAD]"
    # Register BOS/EOS explicitly; without this, eos_token_id is None and the
    # GPT2Config below would have no end-of-sequence id for generation.
    wrapped_tokenizer.bos_token = "<|startoftext|>"
    wrapped_tokenizer.eos_token = "<|endoftext|>"
    log_text += "✓ Tokenizer created.\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    config = GPT2Config(
        vocab_size=wrapped_tokenizer.vocab_size,
        n_positions=512,
        n_layer=4,
        n_head=4,
        n_embd=256,
        pad_token_id=wrapped_tokenizer.pad_token_id,
        eos_token_id=wrapped_tokenizer.eos_token_id,
    )
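    # A tiny GPT-2-style model (~5M parameters, rough estimate);
    # n_positions=512 matches the max_length used for tokenization below.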
    model = create_seed_model(config).to(DEVICE)
    log_text += "✓ Seed Model (GPT-2 Style) Initialized.\n\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    data_collator = DataCollatorForLanguageModeling(tokenizer=wrapped_tokenizer, mlm=False)
    stage_ds = full_dataset_for_rag.select(range(stage['data_slice'].start, stage['data_slice'].stop))
    tokenized_dataset = stage_ds.map(
        lambda ex: wrapped_tokenizer(
            [q + " " + a for q, a in zip(ex['question'], ex['answer'])],
            truncation=True, padding="max_length", max_length=512,
        ),
        batched=True,
        remove_columns=stage_ds.column_names,
    )
    dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
    optimizer = AdamW(model.parameters(), lr=stage['learning_rate'])
    loss_history = []
    avg_epoch_loss = float('inf')
    for epoch in range(stage['epochs']):
        base_log_for_epoch = log_text + f"  Starting SFT Epoch {epoch+1}/{stage['epochs']}...\n"
        yield base_log_for_epoch, None, gr.Group(visible=False), artifact_files
        epoch_losses = []
        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            if (batch_idx + 1) % 20 == 0:
                yield base_log_for_epoch + f"    - Batch {batch_idx+1}/{len(dataloader)}\n", None, gr.Group(visible=False), artifact_files
        # Average over all batches rather than reporting only the last batch's loss.
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
        loss_history.append({"Phase": "SFT", "Epoch": epoch, "Loss": avg_epoch_loss})
        loss_df = pd.DataFrame(loss_history)
        log_text += f"  Epoch {epoch+1}/{stage['epochs']} complete. Loss: {avg_epoch_loss:.4f}\n"
        yield log_text, gr.LinePlot(loss_df, x="Epoch", y="Loss", color="Phase"), gr.Group(visible=False), artifact_files
log_text += f"✓ SFT Phase Complete. Final Loss: {avg_epoch_loss:.4f}\n" | |
if avg_epoch_loss < stage['mastery_threshold']: | |
log_text += f"✓ SFT Mastery Gate PASSED! Saving model...\n\n" | |
model.save_pretrained(SFT_MODEL_PATH) | |
wrapped_tokenizer.save_pretrained(SFT_MODEL_PATH) | |
artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)] | |
artifact_files.extend(get_folder_files(SFT_MODEL_PATH)) | |
yield log_text, None, gr.Group(visible=False), artifact_files | |
return model, wrapped_tokenizer, log_text, artifact_files | |
else: | |
log_text += f"✗ SFT Mastery Gate FAILED. Stopping.\n" | |
yield log_text, None, gr.Group(visible=False), artifact_files | |
return None, None, log_text, artifact_files | |

def run_rl_phase(sft_model, tokenizer, initial_log_text, rl_dataset_slice, artifact_files):
    global SAVANT_MODEL, SAVANT_TOKENIZER
    log_text = initial_log_text
    rl_curriculum = get_rl_curriculum()
    log_text += "--- Starting Phase 2: Reinforcement Learning (RL) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    def tokenize_query(examples):
        return tokenizer(examples["question"], truncation=True, max_length=512)

    rl_dataset = rl_dataset_slice.map(tokenize_query, batched=True)
    ppo_config = PPOConfig(
        learning_rate=1.41e-5,
        batch_size=4,
        mini_batch_size=4,
        ppo_epochs=4,
    )
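    # In the legacy trl API, batch_size must match the DataLoader batch size
    # used below, and mini_batch_size must divide batch_size evenly.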
    model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
    data_collator = DataCollatorWithPadding(tokenizer)
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        model=model_with_value_head,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=rl_dataset,
        data_collator=data_collator,
    )
    # Build our own DataLoader so the 'answer' column survives collation.
    # PPOTrainer's internal dataloader drops all non-model-input columns, and
    # DataCollatorWithPadding cannot tensorize string columns, so pop the
    # strings before padding and re-attach the answers afterwards.
    def collate_with_answers(features):
        answers = [f.pop("answer") for f in features]
        for f in features:
            f.pop("question", None)  # drop the raw question string too
        batch = data_collator(features)
        batch["answer"] = answers
        return batch

    dataloader = DataLoader(rl_dataset, batch_size=ppo_config.batch_size, collate_fn=collate_with_answers)
    generation_kwargs = {"max_new_tokens": 100, "pad_token_id": tokenizer.pad_token_id, "do_sample": True, "temperature": 0.7}
    log_text += "✓ PPOTrainer instantiated. Starting PPO training loop...\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    # Iterate over our custom dataloader, not ppo_trainer.dataloader.
    for step, batch in enumerate(dataloader):
        if step >= rl_curriculum["num_steps"]:
            break
        # PPOTrainer.generate/step expect lists of 1-D tensors, so split the
        # padded batch and trim each query back to its true length.
        query_tensors = [
            ids[mask.bool()].to(DEVICE)
            for ids, mask in zip(batch["input_ids"], batch["attention_mask"])
        ]
        # Works because our collate_fn preserves the 'answer' column.
        ground_truth_answers = batch["answer"]
        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
        response_texts = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
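        # Sparse binary reward: 1.0 when the generated final answer matches
        # the ground-truth number within a small tolerance, 0.0 otherwise.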
        rewards = []
        for resp_text, gt_answer in zip(response_texts, ground_truth_answers):
            pred = extract_answer(resp_text)
            true = extract_answer(gt_answer)
            reward_value = 1.0 if pred is not None and true is not None and abs(pred - true) < 1e-2 else 0.0
            rewards.append(torch.tensor(reward_value, device=DEVICE))
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        if step % 10 == 0:
            mean_reward = stats.get("ppo/returns/mean", torch.tensor(0.0)).item()
            log_text += f"Step {step}/{rl_curriculum['num_steps']}, Mean Reward: {mean_reward:.2f}\n"
            yield log_text, None, gr.Group(visible=False), artifact_files
log_text += "\n✓ Reinforcement Learning Phase Complete! Saving final policy model...\n" | |
ppo_trainer.save_model(SFT_MODEL_PATH) | |
SAVANT_MODEL = ppo_trainer.model | |
SAVANT_TOKENIZER = tokenizer | |
artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)] | |
artifact_files.extend(get_folder_files(SFT_MODEL_PATH)) | |
log_text += f"✓ Final RL-tuned model saved to {SFT_MODEL_PATH}\n" | |
yield log_text, None, gr.Group(visible=False), artifact_files | |
return log_text, artifact_files | |

def run_full_training():
    global RAG_DATABASE, SAVANT_MODEL, SAVANT_TOKENIZER, full_dataset_for_rag
    artifact_files = []
    sft_model, tokenizer, log_text, artifact_files = yield from run_sft_phase(artifact_files)
    if sft_model is None:
        yield log_text, None, gr.Group(visible=False), artifact_files
        return
    log_text += "--- Building RAG Knowledge Base ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    rag_texts = [f"Question: {q}\nAnswer: {a}" for q, a in zip(full_dataset_for_rag['question'], full_dataset_for_rag['answer'])]
    RAG_DATABASE = VectorDatabase()
    RAG_DATABASE.build_index(rag_texts)
    RAG_DATABASE.save_index(RAG_INDEX_PATH)
    artifact_files.append(RAG_INDEX_PATH)
    log_text += "✓ RAG Knowledge Base Built.\n\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    rl_curriculum = get_rl_curriculum()
    rl_dataset_slice = full_dataset_for_rag.select(range(*rl_curriculum['data_slice'].indices(len(full_dataset_for_rag))))
    final_log_text, final_artifact_files = yield from run_rl_phase(sft_model, tokenizer, log_text, rl_dataset_slice, artifact_files)
    final_log_text += "\n--- Curriculum Completed: RAG-Powered Savant Ready! ---\n"
    SAVANT_MODEL = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
    SAVANT_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(SFT_MODEL_PATH)
    yield final_log_text, None, gr.Group(visible=True), final_artifact_files

def run_savant_inference(user_question):
    if SAVANT_MODEL is None or SAVANT_TOKENIZER is None or RAG_DATABASE is None:
        return "Model or RAG Database not ready. Please run the training first."
    retrieved_context = RAG_DATABASE.search(user_question, k=3)
    context_str = "\n".join(retrieved_context)
    augmented_prompt = f"Context:\n{context_str}\n\nQuestion:\n{user_question}\n\nAnswer:"
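    # Caveat: the SFT corpus was plain "question answer" text, so this
    # Context/Question/Answer scaffold is unseen during training; retrieval
    # still surfaces relevant worked examples in the prompt.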
    inputs = SAVANT_TOKENIZER(augmented_prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        response_ids = SAVANT_MODEL.generate(**inputs, max_new_tokens=150, pad_token_id=SAVANT_TOKENIZER.pad_token_id, do_sample=True, temperature=0.7, top_p=0.9)
    response_text = SAVANT_TOKENIZER.decode(response_ids[0], skip_special_tokens=True)
    if "Answer:" in response_text:
        final_answer = response_text.split("Answer:")[-1].strip()
    else:
        final_answer = response_text
    # List however many documents were actually retrieved instead of assuming
    # exactly three, which would raise IndexError on a smaller index.
    context_md = "\n".join(f"- {doc}" for doc in retrieved_context)
    return f"**Savant's Answer:**\n{final_answer}\n\n**Context Used by RAG:**\n{context_md}"

# --- GRADIO UI DEFINITION ---
with gr.Blocks(theme=gr.themes.Soft(), title="Savant-Garde Dashboard") as demo:
    gr.Markdown("# The Savant-Garde: RAG + SFT + RL Savant Factory")
    with gr.Row():
        with gr.Column(scale=1):
            start_button = gr.Button("🚀 Begin Full Training", variant="primary")
            log_output = gr.Textbox(label="Live Training Log", interactive=False, lines=25, max_lines=25)
        with gr.Column(scale=1):
            gr.Markdown("### Dashboard Visualizations")
            loss_plot = gr.LinePlot(show_label=False)
            with gr.Accordion("Downloadable Artifacts", open=True):
                file_output = gr.File(label="Generated Files", file_count="multiple", interactive=False)
    with gr.Group(visible=False) as interaction_group:
        gr.Markdown("---")
        gr.Markdown("### 🧠 Interact with the RAG-Powered Savant")
        with gr.Row():
            question_input = gr.Textbox(label="Ask a Math Question", placeholder="e.g., Janet has 2 apples...")
            ask_button = gr.Button("Ask Savant")
        savant_answer = gr.Markdown()
    start_button.click(fn=run_full_training, inputs=None, outputs=[log_output, loss_plot, interaction_group, file_output])
    ask_button.click(fn=run_savant_inference, inputs=question_input, outputs=savant_answer)

if __name__ == "__main__":
    demo.launch(debug=True)