# Embedding-Atlas / app.py
# Hugging Face Space by broadfield-dev — commit 58b8f23 ("Update app.py", verified), 15.3 kB.
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
GPT2Config,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
PreTrainedTokenizerFast,
DataCollatorWithPadding,
GenerationConfig,
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from datasets import Dataset, load_dataset
import os
import torch.nn as nn
import gradio as gr
import pandas as pd
import time
import re
from sentence_transformers import SentenceTransformer
import faiss
# --- Configuration & Global State ---
# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Filesystem locations for artifacts produced by the pipeline.
SFT_MODEL_PATH = "./sft_model"  # model + tokenizer saved after SFT, then overwritten after RL
RAG_INDEX_PATH = "faiss_index.bin"  # serialized FAISS index for retrieval
TOKENIZER_FILE_PATH = "savant_tokenizer.json"  # raw BPE tokenizer JSON
CACHE_DIR = "./hf_cache"  # cache folder handed to load_dataset / SentenceTransformer

# Mutable module state: filled in by the training generators, read by the
# inference handler. Everything starts out unset.
SAVANT_MODEL = SAVANT_TOKENIZER = RAG_DATABASE = full_dataset_for_rag = None
# --- RAG Database Class ---
class VectorDatabase:
    """Retrieval store: sentence-transformer embeddings in a flat L2 FAISS index.

    Row ``i`` of the FAISS index corresponds to ``self.documents[i]``.
    """

    def __init__(self, embedder_model_name='all-MiniLM-L6-v2'):
        """Load the embedding model on ``DEVICE``; the index starts empty (``None``)."""
        self.embedder = SentenceTransformer(embedder_model_name, device=str(DEVICE), cache_folder=CACHE_DIR)
        self.index = None
        self.documents = []

    def build_index(self, texts):
        """Embed *texts* and (re)build the index over them, replacing any previous one."""
        print("Building RAG vector index...")
        # Take a private copy so later mutation of the caller's list cannot
        # desynchronise documents from index rows.
        self.documents = list(texts)
        embeddings = self.embedder.encode(self.documents, convert_to_tensor=True, show_progress_bar=True)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        # FAISS requires a float32 host (CPU) array.
        self.index.add(embeddings.cpu().numpy().astype("float32"))
        print(f"RAG Index built with {len(self.documents)} documents.")

    def save_index(self, path):
        """Write the FAISS index to *path*; does nothing if no index has been built."""
        if self.index is not None:
            faiss.write_index(self.index, path)
            print(f"RAG Index saved to {path}")

    def search(self, query, k=3):
        """Return up to *k* stored documents closest to *query*, nearest first.

        Returns an empty list when no index has been built, when no documents
        are stored, or when ``k <= 0``.
        """
        if self.index is None:
            return []
        # FAISS pads the result with id -1 when k exceeds the number of stored
        # vectors; clamp k and drop negative ids so we never silently return
        # self.documents[-1] as a "match".
        k = min(k, len(self.documents))
        if k <= 0:
            return []
        query_embedding = self.embedder.encode([query], convert_to_tensor=True)
        _, indices = self.index.search(query_embedding.cpu().numpy().astype("float32"), k)
        return [self.documents[i] for i in indices[0] if i >= 0]
# --- Core Logic ---
def create_tokenizer_file_from_dataset(dataset, save_path=TOKENIZER_FILE_PATH):
    """Train a BPE tokenizer on the question/answer pairs in *dataset* and save it as JSON.

    Each item with a non-empty ``question`` and ``answer`` contributes one corpus
    line ``"<question> <answer>"``; other items are skipped. The tokenizer uses
    whitespace pre-tokenization, a vocabulary size of 8192 and the special tokens
    ``[UNK]``, ``[PAD]``, ``<|startoftext|>`` and ``<|endoftext|>``.

    Returns *save_path*.
    """
    import tempfile  # function-local: only this helper needs a scratch file

    special_tokens = ["[UNK]", "[PAD]", "<|startoftext|>", "<|endoftext|>"]
    raw_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    raw_tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(vocab_size=8192, special_tokens=special_tokens)

    # Use a unique temp file instead of a fixed name in the working directory,
    # and always remove it, even if writing or training raises.
    fd, corpus_path = tempfile.mkstemp(prefix="savant_corpus_", suffix=".txt")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            for item in dataset:
                if item and item.get('question') and item.get('answer'):
                    f.write(str(item['question']) + " " + str(item['answer']) + "\n")
        raw_tokenizer.train(files=[corpus_path], trainer=trainer)
    finally:
        os.remove(corpus_path)

    raw_tokenizer.save(save_path)
    return save_path
def create_seed_model(config):
    """Build the starting causal LM from *config* via ``AutoModelForCausalLM.from_config``.

    No pretrained checkpoint is loaded; only the architecture described by
    *config* is instantiated.
    """
    seed_model = AutoModelForCausalLM.from_config(config)
    return seed_model
def get_sft_curriculum():
    """Return the supervised fine-tuning curriculum as a list of stage dicts.

    The single stage trains on rows 0-99 of GSM8K for one epoch. It passes its
    mastery gate when the reported loss is below ``mastery_threshold``.
    """
    foundational_math = dict(
        name="Phase 1: Foundational Math (SFT)",
        dataset_name="openai/gsm8k",
        dataset_config="main",
        data_slice=slice(0, 100),
        epochs=1,
        learning_rate=5e-5,
        mastery_threshold=16.0,
    )
    return [foundational_math]
def get_rl_curriculum():
    """Return the PPO stage description: GSM8K rows 1000-1099 and a budget of 100 steps."""
    return dict(
        name="Phase 2: Problem Solving (RL)",
        dataset_name="openai/gsm8k",
        dataset_config="main",
        data_slice=slice(1000, 1100),
        num_steps=100,
    )
def get_folder_files(folder_path):
    """List the entries directly inside *folder_path* as joined paths.

    Returns ``[]`` if *folder_path* is not an existing directory. Entries are in
    ``os.listdir`` order and are not filtered by type.
    """
    if os.path.isdir(folder_path):
        return [os.path.join(folder_path, entry) for entry in os.listdir(folder_path)]
    return []
# A number with optional leading minus, optional thousands separators ("1,234")
# and optional fractional part, or a bare decimal like ".5". The (?<!\d) guard
# stops "5-3" from producing "-3".
_NUMBER_RE = re.compile(r"(?<!\d)-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|(?<!\d)-?\.\d+")
_BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}")


def _to_float(token):
    """Parse *token* as a float after removing thousands separators; None if not numeric."""
    try:
        return float(token.replace(",", ""))
    except ValueError:
        return None


def extract_answer(text):
    """Extract the final numeric answer from a solution string.

    Precedence:
      1. the contents of the first ``\\boxed{...}`` (None if it is not numeric);
      2. the first number after the last GSM8K ``####`` marker;
      3. the last number anywhere in the text.

    Commas used as thousands separators and a leading minus sign are honoured.
    *text* is converted with ``str()`` first. Returns a float, or None if no
    number is found.
    """
    text = str(text)
    boxed = _BOXED_RE.search(text)
    if boxed:
        return _to_float(boxed.group(1).strip())
    marker = text.rfind("####")
    if marker != -1:
        after_marker = _NUMBER_RE.search(text[marker + len("####"):])
        if after_marker:
            return _to_float(after_marker.group(0))
    numbers = _NUMBER_RE.findall(text)
    return _to_float(numbers[-1]) if numbers else None
# --- Master Training Process ---
def run_sft_phase(artifact_files):
    """Phase 1: build a tokenizer, initialise a small GPT-2 style model and fine-tune it on GSM8K.

    Generator driving the Gradio UI. Every yield is the 4-tuple
    ``(log_text, plot_or_None, interaction_group_update, artifact_files)``.
    The generator's return value (received via ``yield from``) is
    ``(model, tokenizer, log_text, artifact_files)`` when the mastery gate passes,
    otherwise ``(None, None, log_text, artifact_files)``.

    Side effects: sets the global ``full_dataset_for_rag`` to the GSM8K train
    split, appends the tokenizer file to *artifact_files*, and on success saves
    the model and tokenizer to ``SFT_MODEL_PATH``.
    """
    global full_dataset_for_rag
    log_text = "--- Starting Phase 1: Supervised Fine-Tuning (SFT) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    stage = get_sft_curriculum()[0]
    full_dataset_for_rag = load_dataset(stage['dataset_name'], name=stage['dataset_config'], split='train', cache_dir=CACHE_DIR)
    tokenizer_file = create_tokenizer_file_from_dataset(full_dataset_for_rag)
    artifact_files.append(tokenizer_file)
    # Register every special token the BPE trainer created. With only pad_token
    # set, bos/eos ids are None and the model config below would carry no
    # end-of-sequence id for later generation.
    wrapped_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="[UNK]",
        pad_token="[PAD]",
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
    )
    log_text += "✓ Tokenizer created.\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    config = GPT2Config(
        vocab_size=wrapped_tokenizer.vocab_size,
        n_positions=512,
        n_layer=4,
        n_head=4,
        n_embd=256,
        pad_token_id=wrapped_tokenizer.pad_token_id,
        bos_token_id=wrapped_tokenizer.bos_token_id,
        eos_token_id=wrapped_tokenizer.eos_token_id,
    )
    model = create_seed_model(config).to(DEVICE)
    model.train()  # make training mode explicit (dropout on)
    log_text += "✓ Seed Model (GPT-2 Style) Initialized.\n\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    # mlm=False: causal language-modelling objective.
    data_collator = DataCollatorForLanguageModeling(tokenizer=wrapped_tokenizer, mlm=False)
    # Resolve the slice against the dataset length (as the RL phase does) so an
    # out-of-range stop cannot make select() fail.
    sft_rows = range(*stage['data_slice'].indices(len(full_dataset_for_rag)))
    stage_ds = full_dataset_for_rag.select(sft_rows)
    tokenized_dataset = stage_ds.map(
        lambda ex: wrapped_tokenizer(
            [q + " " + a for q, a in zip(ex['question'], ex['answer'])],
            truncation=True,
            padding="max_length",
            max_length=512,
        ),
        batched=True,
        remove_columns=stage_ds.column_names,
    )
    dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
    optimizer = AdamW(model.parameters(), lr=stage['learning_rate'])

    loss_history = []
    avg_epoch_loss = float('inf')
    for epoch in range(stage['epochs']):
        base_log_for_epoch = log_text + f" Starting SFT Epoch {epoch+1}/{stage['epochs']}...\n"
        yield base_log_for_epoch, None, gr.Group(visible=False), artifact_files
        epoch_loss_sum = 0.0
        num_batches = 0
        for batch_idx, batch in enumerate(dataloader):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            epoch_loss_sum += loss.item()
            num_batches += 1
            if (batch_idx + 1) % 20 == 0:
                yield base_log_for_epoch + f" - Batch {batch_idx+1}/{len(dataloader)}\n", None, gr.Group(visible=False), artifact_files
        # Mean loss over every batch of the epoch; an epoch with no batches
        # reports inf so it can never pass the mastery gate.
        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches else float('inf')
        loss_history.append({"Phase": "SFT", "Epoch": epoch, "Loss": avg_epoch_loss})
        loss_df = pd.DataFrame(loss_history)
        log_text += f" Epoch {epoch+1}/{stage['epochs']} complete. Loss: {avg_epoch_loss:.4f}\n"
        yield log_text, gr.LinePlot(loss_df, x="Epoch", y="Loss", color="Phase"), gr.Group(visible=False), artifact_files

    log_text += f"✓ SFT Phase Complete. Final Loss: {avg_epoch_loss:.4f}\n"
    if avg_epoch_loss < stage['mastery_threshold']:
        log_text += "✓ SFT Mastery Gate PASSED! Saving model...\n\n"
        model.save_pretrained(SFT_MODEL_PATH)
        wrapped_tokenizer.save_pretrained(SFT_MODEL_PATH)
        # Replace any stale entries from SFT_MODEL_PATH with the freshly saved files.
        artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
        artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
        yield log_text, None, gr.Group(visible=False), artifact_files
        return model, wrapped_tokenizer, log_text, artifact_files
    log_text += "✗ SFT Mastery Gate FAILED. Stopping.\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    return None, None, log_text, artifact_files
def run_rl_phase(sft_model, tokenizer, initial_log_text, rl_dataset_slice, artifact_files):
    """Phase 2: PPO fine-tuning of the saved SFT policy with an exact-answer reward.

    Generator driving the Gradio UI. Every yield is
    ``(log_text, None, interaction_group_update, artifact_files)``, and the
    return value (received via ``yield from``) is ``(log_text, artifact_files)``.

    ``sft_model`` is not used directly: the policy is reloaded from
    ``SFT_MODEL_PATH`` with a value head attached. ``rl_dataset_slice`` must
    provide ``question`` and ``answer`` columns. The reward is 1.0 when
    ``extract_answer`` yields the same number (within 1e-2) for the response and
    the reference answer, else 0.0.

    Side effects: overwrites ``SFT_MODEL_PATH`` with the PPO-tuned policy and
    tokenizer, and sets the globals ``SAVANT_MODEL`` / ``SAVANT_TOKENIZER``.
    """
    global SAVANT_MODEL, SAVANT_TOKENIZER
    log_text = initial_log_text
    rl_curriculum = get_rl_curriculum()
    log_text += "--- Starting Phase 2: Reinforcement Learning (RL) ---\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    generation_kwargs = {"max_new_tokens": 100, "pad_token_id": tokenizer.pad_token_id, "do_sample": True, "temperature": 0.7}
    model_with_value_head = AutoModelForCausalLMWithValueHead.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
    # Truncate queries so query + generated tokens still fit in the model's
    # position embeddings (512 for the GPT-2 config built during SFT).
    max_positions = getattr(model_with_value_head.pretrained_model.config, "max_position_embeddings", 512)
    max_query_length = max(1, max_positions - generation_kwargs["max_new_tokens"])

    def tokenize_query(examples):
        return tokenizer(examples["question"], truncation=True, max_length=max_query_length)

    rl_dataset = rl_dataset_slice.map(tokenize_query, batched=True)
    ppo_config = PPOConfig(
        learning_rate=1.41e-5,
        batch_size=4,
        mini_batch_size=4,
        ppo_epochs=4,
    )
    data_collator = DataCollatorWithPadding(tokenizer)
    ppo_trainer = PPOTrainer(
        config=ppo_config,
        model=model_with_value_head,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=rl_dataset,
        data_collator=data_collator,
    )

    def collate_queries(features):
        # Keep each query as an unpadded 1-D LongTensor (PPOTrainer.generate and
        # .step take a list of such tensors) and carry the raw answer strings
        # through untouched. DataCollatorWithPadding would try to tensorise the
        # string columns and fail.
        return {
            "input_ids": [torch.tensor(f["input_ids"], dtype=torch.long) for f in features],
            "answer": [f["answer"] for f in features],
        }

    # Our own loader (not ppo_trainer.dataloader) so the 'answer' column survives.
    # drop_last: ppo_trainer.step() requires exactly ppo_config.batch_size samples.
    dataloader = DataLoader(
        rl_dataset,
        batch_size=ppo_config.batch_size,
        collate_fn=collate_queries,
        drop_last=True,
    )
    log_text += "✓ PPOTrainer instantiated. Starting PPO training loop...\n"
    yield log_text, None, gr.Group(visible=False), artifact_files

    for step, batch in enumerate(dataloader):
        if step >= rl_curriculum["num_steps"]:
            break
        query_tensors = [q.to(DEVICE) for q in batch["input_ids"]]
        ground_truth_answers = batch["answer"]
        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
        response_texts = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
        rewards = []
        for resp_text, gt_answer in zip(response_texts, ground_truth_answers):
            pred = extract_answer(resp_text)
            true = extract_answer(gt_answer)
            reward_value = 1.0 if pred is not None and true is not None and abs(pred - true) < 1e-2 else 0.0
            rewards.append(torch.tensor(reward_value, device=DEVICE))
        ppo_trainer.step(query_tensors, response_tensors, rewards)
        if step % 10 == 0:
            # Mean of the rewards actually handed to PPO for this batch.
            mean_reward = torch.stack(rewards).mean().item()
            log_text += f"Step {step}/{rl_curriculum['num_steps']}, Mean Reward: {mean_reward:.2f}\n"
            yield log_text, None, gr.Group(visible=False), artifact_files

    log_text += "\n✓ Reinforcement Learning Phase Complete! Saving final policy model...\n"
    # NOTE(review): the generate()/step() PPOTrainer API used here provides no
    # save_model(); save the unwrapped policy and the tokenizer directly so that
    # from_pretrained(SFT_MODEL_PATH) can reload them.
    ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).save_pretrained(SFT_MODEL_PATH)
    tokenizer.save_pretrained(SFT_MODEL_PATH)
    SAVANT_MODEL = ppo_trainer.model
    SAVANT_TOKENIZER = tokenizer
    artifact_files = [f for f in artifact_files if SFT_MODEL_PATH not in os.path.dirname(f)]
    artifact_files.extend(get_folder_files(SFT_MODEL_PATH))
    log_text += f"✓ Final RL-tuned model saved to {SFT_MODEL_PATH}\n"
    yield log_text, None, gr.Group(visible=False), artifact_files
    return log_text, artifact_files
def run_full_training():
    """Drive the whole curriculum: SFT, then RAG index construction, then PPO.

    Gradio streaming handler. Every yield is
    ``(log_text, plot_or_None, interaction_group_update, artifact_files)``.
    If the SFT mastery gate fails, it yields the log once more and stops.
    Otherwise it builds the RAG index over the full training split and runs PPO.
    It then reloads the final model/tokenizer from ``SFT_MODEL_PATH`` into the
    globals and, on the last yield, makes the interaction group visible.
    """
    global RAG_DATABASE, SAVANT_MODEL, SAVANT_TOKENIZER, full_dataset_for_rag

    sft_model, tokenizer, log, files = yield from run_sft_phase([])
    if sft_model is None:
        yield log, None, gr.Group(visible=False), files
        return

    # Retrieval corpus: every question/answer pair of the split loaded during SFT.
    log += "--- Building RAG Knowledge Base ---\n"
    yield log, None, gr.Group(visible=False), files
    questions = full_dataset_for_rag['question']
    answers = full_dataset_for_rag['answer']
    documents = [f"Question: {q}\nAnswer: {a}" for q, a in zip(questions, answers)]
    RAG_DATABASE = VectorDatabase()
    RAG_DATABASE.build_index(documents)
    RAG_DATABASE.save_index(RAG_INDEX_PATH)
    files.append(RAG_INDEX_PATH)
    log += "✓ RAG Knowledge Base Built.\n\n"
    yield log, None, gr.Group(visible=False), files

    # Resolve the RL slice against the dataset length before selecting rows.
    rl_slice = get_rl_curriculum()['data_slice']
    rl_rows = full_dataset_for_rag.select(range(*rl_slice.indices(len(full_dataset_for_rag))))
    log, files = yield from run_rl_phase(sft_model, tokenizer, log, rl_rows, files)
    log += "\n--- Curriculum Completed: RAG-Powered Savant Ready! ---\n"

    # Reload the final policy as a plain causal LM for the inference handler.
    SAVANT_MODEL = AutoModelForCausalLM.from_pretrained(SFT_MODEL_PATH).to(DEVICE)
    SAVANT_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(SFT_MODEL_PATH)
    yield log, None, gr.Group(visible=True), files
def run_savant_inference(user_question):
    """Answer *user_question* with the trained model, grounded on retrieved GSM8K examples.

    The 3 nearest RAG documents are prepended as context. The prompt is
    left-truncated so that prompt + 150 new tokens fit the model's position
    limit, and only the newly generated tokens are decoded as the answer.

    Returns a Markdown string with the answer and the retrieved context, or a
    "not ready" message if training has not completed.
    """
    if SAVANT_MODEL is None or SAVANT_TOKENIZER is None or RAG_DATABASE is None:
        return "Model or RAG Database not ready. Please run the training first."
    max_new_tokens = 150
    retrieved_context = RAG_DATABASE.search(user_question, k=3)
    context_str = "\n".join(retrieved_context)
    augmented_prompt = f"Context:\n{context_str}\n\nQuestion:\n{user_question}\n\nAnswer:"
    inputs = SAVANT_TOKENIZER(augmented_prompt, return_tensors="pt")
    # Three retrieved Q&A pairs may exceed the model's position embeddings.
    # Keep the most recent tokens (the question and the "Answer:" cue) and
    # leave room for the generated reply.
    max_positions = getattr(SAVANT_MODEL.config, "max_position_embeddings", 512)
    max_prompt_len = max(1, max_positions - max_new_tokens)
    inputs = {name: tensor[:, -max_prompt_len:].to(DEVICE) for name, tensor in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        response_ids = SAVANT_MODEL.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=SAVANT_TOKENIZER.pad_token_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    # Decode only the continuation. The prompt itself contains several
    # "Answer:" markers (from the retrieved documents), so splitting the full
    # decoded text on that marker is unreliable.
    final_answer = SAVANT_TOKENIZER.decode(response_ids[0, prompt_len:], skip_special_tokens=True).strip()
    # Works for any number of retrieved documents (search may return fewer than 3).
    context_lines = "\n".join(f"- {doc}" for doc in retrieved_context)
    return f"**Savant's Answer:**\n{final_answer}\n\n**Context Used by RAG:**\n{context_lines}"
# --- GRADIO UI DEFINITION ---
# Dashboard: a training column (start button + live log) beside a visualisation
# column (loss plot + downloadable artifacts), followed by a Q&A group that is
# created hidden (visible=False).
# NOTE(review): nesting below is reconstructed from a whitespace-stripped copy;
# confirm the Accordion / Markdown placement against the original layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Savant-Garde Dashboard") as demo:
    gr.Markdown("# The Savant-Garde: RAG + SFT + RL Savant Factory")
    with gr.Row():
        with gr.Column(scale=1):
            start_button = gr.Button("🚀 Begin Full Training", variant="primary")
            log_output = gr.Textbox(label="Live Training Log", interactive=False, lines=25, max_lines=25)
        with gr.Column(scale=1):
            gr.Markdown("### Dashboard Visualizations")
            loss_plot = gr.LinePlot(show_label=False)
            with gr.Accordion("Downloadable Artifacts", open=True):
                file_output = gr.File(label="Generated Files", file_count="multiple", interactive=False)
    # Shown only when run_full_training yields gr.Group(visible=True) at the end.
    with gr.Group(visible=False) as interaction_group:
        gr.Markdown("---")
        gr.Markdown("### 🧠 Interact with the RAG-Powered Savant")
        with gr.Row():
            question_input = gr.Textbox(label="Ask a Math Question", placeholder="e.g., Janet has 2 apples...")
            ask_button = gr.Button("Ask Savant")
        savant_answer = gr.Markdown()
    # run_full_training is a generator: each yielded 4-tuple streams into these
    # outputs in order (log text, loss plot, Q&A group, artifact file list).
    start_button.click(fn=run_full_training, inputs=None, outputs=[log_output, loss_plot, interaction_group, file_output])
    # The returned Markdown string is rendered in savant_answer.
    ask_button.click(fn=run_savant_inference, inputs=question_input, outputs=savant_answer)
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)