wassemgtk committed
Commit e911248 · verified
1 Parent(s): eafbdfb

Create app.py

Files changed (1)
app.py +321 -0
app.py ADDED
@@ -0,0 +1,321 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import numpy as np
+ from typing import List, Dict, Tuple
+ import json
+ import os
+ from datetime import datetime
+
+ class GRPOTrainer:
+     def __init__(self):
+         self.model = None
+         self.ref_model = None
+         self.tokenizer = None
+         self.optimizer = None
+         self.training_history = []
+
+     def load_model(self, model_name: str) -> str:
+         """Load the model and tokenizer"""
+         try:
+             # Use fp16 only when a GPU is available; half-precision matmuls are not supported on CPU
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             dtype = torch.float16 if device == "cuda" else torch.float32
+
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+             self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
+             self.ref_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
+
+             # Set padding token
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+             # Left-padding so batched generation appends new tokens directly after each prompt
+             self.tokenizer.padding_side = "left"
+
+             # Freeze reference model
+             for param in self.ref_model.parameters():
+                 param.requires_grad = False
+
+             return f"✅ Successfully loaded model: {model_name}"
+         except Exception as e:
+             return f"❌ Error loading model: {str(e)}"
+
+     def compute_rewards(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
+         """Compute rewards for responses (simplified reward function)"""
+         rewards = []
+         for response in responses:
+             # Simple reward based on response length and diversity
+             length_reward = min(len(response.split()) / 50, 1.0)
+             unique_words = len(set(response.lower().split()))
+             diversity_reward = min(unique_words / 20, 1.0)
+             reward = (length_reward + diversity_reward) / 2
+             rewards.append(reward)
+         return torch.tensor(rewards)
+
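+     # The heuristic above is a placeholder reward; the note at the bottom of the app points
+     # to swapping in task-specific reward functions (e.g., correctness checks) for real use.
+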
+     def compute_kl_penalty(self, logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
+         """Compute KL divergence penalty KL(policy || reference) over the vocabulary"""
+         # Work in log-space for numerical stability instead of dividing probabilities directly
+         log_probs = F.log_softmax(logits, dim=-1)
+         ref_log_probs = F.log_softmax(ref_logits, dim=-1)
+         kl = (log_probs.exp() * (log_probs - ref_log_probs)).sum(-1)
+         return kl.mean()
+
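+     # Note: grpo_step below applies this KL term to the prompt tokens from one extra forward
+     # pass. Full GRPO implementations typically estimate a per-token KL on the sampled
+     # completion tokens instead; the prompt-only variant keeps this demo cheap.
+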
+     def grpo_step(self, prompts: List[str], beta: float = 0.1) -> Dict:
+         """Perform one GRPO training step"""
+         if not self.model or not self.tokenizer:
+             return {"error": "Model not loaded"}
+
+         # Tokenize prompts and move them to the model's device
+         inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(self.model.device)
+
+         # Generate responses (pass the attention mask so padded prompts are handled correctly)
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_length=inputs.input_ids.shape[1] + 50,
+                 do_sample=True,
+                 temperature=0.8,
+                 pad_token_id=self.tokenizer.pad_token_id
+             )
+
+         # Decode only the newly generated tokens
+         responses = []
+         for output in outputs:
+             response = self.tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)
+             responses.append(response)
+
+         # Compute rewards
+         rewards = self.compute_rewards(prompts, responses)
+
+         # Forward pass through both models (the reference model is frozen, so skip its graph)
+         self.model.train()
+         model_outputs = self.model(inputs.input_ids, attention_mask=inputs.attention_mask)
+         with torch.no_grad():
+             ref_outputs = self.ref_model(inputs.input_ids, attention_mask=inputs.attention_mask)
+
+         # Compute KL penalty
+         kl_penalty = self.compute_kl_penalty(model_outputs.logits, ref_outputs.logits)
+
+         # Compute loss (simplified GRPO loss)
+         loss = -rewards.mean() + beta * kl_penalty
+
+         # Backward pass
+         if self.optimizer:
+             self.optimizer.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+
+         return {
+             "loss": loss.item(),
+             "reward": rewards.mean().item(),
+             "kl_penalty": kl_penalty.item(),
+             "responses": responses
+         }
+
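+     # Illustrative sketch (not part of the training step above): GRPO as usually described
+     # samples a group of G responses per prompt and normalizes each reward against its group
+     # before the policy update. A minimal version of that normalization might look like:
+     @staticmethod
+     def compute_group_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+         """Group-relative advantages: A_i = (r_i - mean(r)) / (std(r) + eps)."""
+         return (rewards - rewards.mean()) / (rewards.std() + eps)
+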
+     def train(self, prompts: List[str], num_steps: int, lr: float, beta: float) -> str:
+         """Run GRPO training"""
+         if not self.model:
+             return "❌ Please load a model first"
+
+         # Initialize optimizer
+         self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
+
+         results = []
+         for step in range(num_steps):
+             step_result = self.grpo_step(prompts, beta)
+
+             if "error" in step_result:
+                 return f"❌ Error: {step_result['error']}"
+
+             result_str = f"Step {step + 1}/{num_steps} - Loss: {step_result['loss']:.4f}, Reward: {step_result['reward']:.4f}, KL: {step_result['kl_penalty']:.4f}"
+             results.append(result_str)
+
+             # Store training history
+             self.training_history.append({
+                 "step": step + 1,
+                 "loss": step_result['loss'],
+                 "reward": step_result['reward'],
+                 "kl_penalty": step_result['kl_penalty']
+             })
+
+         return "\n".join(results)
+
+     def generate_response(self, prompt: str, max_length: int = 100, temperature: float = 0.8) -> str:
+         """Generate a response using the trained model"""
+         if not self.model or not self.tokenizer:
+             return "❌ Please load a model first"
+
+         # Tokenize the prompt and move it to the model's device
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 inputs.input_ids,
+                 max_length=inputs.input_ids.shape[1] + max_length,
+                 temperature=temperature,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.pad_token_id
+             )
+
+         response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+         return response
+
+     def save_model(self, save_path: str) -> str:
+         """Save the trained model"""
+         if not self.model:
+             return "❌ No model to save"
+
+         try:
+             self.model.save_pretrained(save_path)
+             self.tokenizer.save_pretrained(save_path)
+
+             # Save training history
+             with open(os.path.join(save_path, "training_history.json"), "w") as f:
+                 json.dump(self.training_history, f)
+
+             return f"✅ Model saved to {save_path}"
+         except Exception as e:
+             return f"❌ Error saving model: {str(e)}"
+
+ # Initialize trainer
+ trainer = GRPOTrainer()
+
+ # Gradio interface
+ def load_model_interface(model_name):
+     return trainer.load_model(model_name)
+
+ def train_interface(prompts_text, num_steps, learning_rate, beta):
+     prompts = [p.strip() for p in prompts_text.split("\n") if p.strip()]
+     if not prompts:
+         return "❌ Please provide at least one prompt"
+     return trainer.train(prompts, int(num_steps), float(learning_rate), float(beta))
+
+ def generate_interface(prompt, max_length, temperature):
+     return trainer.generate_response(prompt, int(max_length), float(temperature))
+
+ def save_model_interface(save_path):
+     return trainer.save_model(save_path)
+
+ def get_training_history():
+     if not trainer.training_history:
+         return "No training history available"
+
+     history_str = "Training History:\n"
+     history_str += "-" * 50 + "\n"
+     for entry in trainer.training_history[-10:]:  # Show last 10 entries
+         history_str += f"Step {entry['step']}: Loss={entry['loss']:.4f}, Reward={entry['reward']:.4f}, KL={entry['kl_penalty']:.4f}\n"
+     return history_str
+
+ # Create Gradio interface
+ with gr.Blocks(title="GRPO Model Training") as app:
+     gr.Markdown("# 🚀 GRPO (Group Relative Policy Optimization) Training App")
+     gr.Markdown("Train language models with the GRPO technique using this simple interface")
+
+     with gr.Tab("🔧 Model Setup"):
+         with gr.Row():
+             model_input = gr.Textbox(
+                 label="Model Name",
+                 value="Palmyra-56b",
+                 placeholder="Enter HuggingFace model name (e.g., Palmyra, Qwen, Llama)"
+             )
+             load_btn = gr.Button("Load Model", variant="primary")
+
+         model_status = gr.Textbox(label="Status", lines=2)
+         load_btn.click(load_model_interface, inputs=model_input, outputs=model_status)
+
+     with gr.Tab("🎯 Training"):
+         with gr.Row():
+             with gr.Column():
+                 prompts_input = gr.Textbox(
+                     label="Training Prompts (one per line)",
+                     lines=5,
+                     value="Tell me about artificial intelligence\nExplain quantum computing\nWhat is machine learning?",
+                     placeholder="Enter your prompts here..."
+                 )
+
+             with gr.Column():
+                 num_steps_input = gr.Slider(
+                     label="Number of Training Steps",
+                     minimum=1,
+                     maximum=100,
+                     value=10,
+                     step=1
+                 )
+                 lr_input = gr.Number(
+                     label="Learning Rate",
+                     value=1e-5,
+                     step=1e-6
+                 )
+                 beta_input = gr.Number(
+                     label="KL Penalty Weight (β)",
+                     value=0.1,
+                     step=0.01
+                 )
+
+         train_btn = gr.Button("Start Training", variant="primary")
+         training_output = gr.Textbox(label="Training Progress", lines=10)
+
+         train_btn.click(
+             train_interface,
+             inputs=[prompts_input, num_steps_input, lr_input, beta_input],
+             outputs=training_output
+         )
+
+     with gr.Tab("💬 Generation"):
+         with gr.Row():
+             with gr.Column():
+                 gen_prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your prompt here...",
+                     value="Tell me about"
+                 )
+                 max_length = gr.Slider(
+                     label="Max Length",
+                     minimum=10,
+                     maximum=500,
+                     value=100,
+                     step=10
+                 )
+                 temp_slider = gr.Slider(
+                     label="Temperature",
+                     minimum=0.1,
+                     maximum=2.0,
+                     value=0.8,
+                     step=0.1
+                 )
+
+             with gr.Column():
+                 gen_btn = gr.Button("Generate", variant="primary")
+                 gen_output = gr.Textbox(label="Generated Response", lines=10)
+
+         gen_btn.click(
+             generate_interface,
+             inputs=[gen_prompt, max_length, temp_slider],
+             outputs=gen_output
+         )
+
+     with gr.Tab("💾 Save Model"):
+         save_path_input = gr.Textbox(
+             label="Save Path",
+             value="./grpo_trained_model",
+             placeholder="Enter path to save the model"
+         )
+         save_btn = gr.Button("Save Model", variant="primary")
+         save_status = gr.Textbox(label="Save Status")
+
+         save_btn.click(save_model_interface, inputs=save_path_input, outputs=save_status)
+
+     with gr.Tab("📊 Training History"):
+         history_btn = gr.Button("Refresh History", variant="secondary")
+         history_output = gr.Textbox(label="Training History", lines=15)
+
+         history_btn.click(get_training_history, outputs=history_output)
+
+     gr.Markdown("""
+     ## 📝 Instructions:
+     1. **Load Model**: Start by loading a pre-trained model from HuggingFace
+     2. **Training**: Add your prompts and configure training parameters
+     3. **Generation**: Test your trained model with custom prompts
+     4. **Save**: Save your fine-tuned model for later use
+
+     ## ⚠️ Note:
+     - This is a simplified GRPO implementation for demonstration
+     - For production use, consider more sophisticated reward functions
+     - GPU recommended for larger models
+     """)
+
+ # Launch the app
+ if __name__ == "__main__":
+     app.launch(share=True)