MRasheq commited on
Commit
a9281cb
·
1 Parent(s): 7ed8a15
Files changed (1) hide show
  1. app.py +126 -64
app.py CHANGED
@@ -22,17 +22,32 @@ MODEL_NAME = "deepseek-ai/DeepSeek-R1"
22
  OUTPUT_DIR = "finetuned_models"
23
  LOGS_DIR = "training_logs"
24
 
25
- def save_uploaded_file(file):
26
  """Save uploaded file and return its path"""
27
- os.makedirs('uploads', exist_ok=True)
28
- import tempfile
29
-
30
- # Create a temporary file with .csv extension
31
- temp = tempfile.NamedTemporaryFile(delete=False, suffix='.csv', dir='uploads')
32
- temp.write(file)
33
- temp.close()
34
-
35
- return temp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def prepare_training_data(df):
38
  """Convert DataFrame into Q&A format"""
@@ -133,6 +148,49 @@ def train_model(
133
  progress=gr.Progress()
134
  ):
135
  """Training function for Gradio interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  try:
137
  # Save uploaded file
138
  file_path = save_uploaded_file(file)
@@ -172,60 +230,64 @@ def train_model(
172
 
173
  # Create Gradio interface
174
  def create_interface():
175
- with gr.Blocks() as demo:
176
- gr.Markdown("# DeepSeek-R1 Model Finetuning Interface")
177
-
178
- with gr.Row():
179
- with gr.Column():
180
- file_input = gr.File(
181
- label="Upload Training Data (CSV)",
182
- type="binary",
183
- file_types=[".csv"]
184
- )
185
-
186
- learning_rate = gr.Slider(
187
- minimum=1e-5,
188
- maximum=1e-3,
189
- value=2e-4,
190
- label="Learning Rate"
191
- )
192
-
193
- num_epochs = gr.Slider(
194
- minimum=1,
195
- maximum=10,
196
- value=3,
197
- step=1,
198
- label="Number of Epochs"
199
- )
200
-
201
- batch_size = gr.Slider(
202
- minimum=1,
203
- maximum=8,
204
- value=4,
205
- step=1,
206
- label="Batch Size"
207
- )
208
-
209
- train_button = gr.Button("Start Training")
210
-
211
- with gr.Column():
212
- output = gr.Textbox(label="Training Status")
213
-
214
- train_button.click(
215
- fn=train_model,
216
- inputs=[file_input, learning_rate, num_epochs, batch_size],
217
- outputs=output
218
- )
219
-
220
- gr.Markdown("""
221
- ## Instructions
222
- 1. Upload your training data in CSV format with columns:
223
- - chunk_id (questions)
224
- - text (answers)
225
- 2. Adjust training parameters if needed
226
- 3. Click 'Start Training'
227
- 4. Wait for training to complete
228
- """)
 
 
 
 
229
 
230
  return demo
231
 
 
22
  OUTPUT_DIR = "finetuned_models"
23
  LOGS_DIR = "training_logs"
24
 
25
+ def save_uploaded_file(file_obj):
26
  """Save uploaded file and return its path"""
27
+ try:
28
+ os.makedirs('uploads', exist_ok=True)
29
+
30
+ if hasattr(file_obj, 'name'):
31
+ # If it's a FileUpload object
32
+ file_path = os.path.join('uploads', os.path.basename(file_obj.name))
33
+ if isinstance(file_obj, (bytes, bytearray)):
34
+ with open(file_path, 'wb') as f:
35
+ f.write(file_obj)
36
+ else:
37
+ file_obj.save(file_path)
38
+ else:
39
+ # If it's raw bytes
40
+ import tempfile
41
+ fd, file_path = tempfile.mkstemp(suffix='.csv', dir='uploads')
42
+ with os.fdopen(fd, 'wb') as temp:
43
+ if isinstance(file_obj, (bytes, bytearray)):
44
+ temp.write(file_obj)
45
+ else:
46
+ temp.write(file_obj.read())
47
+
48
+ return file_path
49
+ except Exception as e:
50
+ raise Exception(f"Error saving file: {str(e)}")
51
 
52
  def prepare_training_data(df):
53
  """Convert DataFrame into Q&A format"""
 
148
  progress=gr.Progress()
149
  ):
150
  """Training function for Gradio interface"""
151
+ if file is None:
152
+ return "Please upload a file first."
153
+
154
+ try:
155
+ # File validation
156
+ progress(0.1, desc="Validating file...")
157
+ file_path = save_uploaded_file(file)
158
+
159
+ # Prepare components
160
+ progress(0.2, desc="Preparing training components...")
161
+ components = prepare_training_components(
162
+ file_path,
163
+ learning_rate,
164
+ num_epochs,
165
+ batch_size
166
+ )
167
+
168
+ # Initialize trainer
169
+ progress(0.4, desc="Initializing trainer...")
170
+ trainer = Trainer(
171
+ model=components['model'],
172
+ args=components['training_args'],
173
+ train_dataset=components['dataset'],
174
+ data_collator=components['data_collator'],
175
+ )
176
+
177
+ # Train
178
+ progress(0.5, desc="Training model...")
179
+ trainer.train()
180
+
181
+ # Save model and tokenizer
182
+ progress(0.9, desc="Saving model...")
183
+ trainer.save_model()
184
+ components['tokenizer'].save_pretrained(components['output_dir'])
185
+
186
+ progress(1.0, desc="Training complete!")
187
+ return f"Training completed! Model saved in {components['output_dir']}"
188
+
189
+ except Exception as e:
190
+ error_msg = f"Error during training: {str(e)}"
191
+ print(error_msg) # Log the error
192
+ return error_msg
193
+ """Training function for Gradio interface"""
194
  try:
195
  # Save uploaded file
196
  file_path = save_uploaded_file(file)
 
230
 
231
  # Create Gradio interface
232
  def create_interface():
233
+ # Configure Gradio to handle larger file uploads
234
+ demo = gr.Interface(
235
+ title="Model Fine-tuning Interface"
236
+ )
237
+
238
+ gr.Config(upload_size_limit=100)
239
+
240
+ with gr.Row():
241
+ with gr.Column():
242
+ file_input = gr.File(
243
+ label="Upload Training Data (CSV)",
244
+ type="binary",
245
+ file_types=[".csv"]
246
+ )
247
+
248
+ learning_rate = gr.Slider(
249
+ minimum=1e-5,
250
+ maximum=1e-3,
251
+ value=2e-4,
252
+ label="Learning Rate"
253
+ )
254
+
255
+ num_epochs = gr.Slider(
256
+ minimum=1,
257
+ maximum=10,
258
+ value=3,
259
+ step=1,
260
+ label="Number of Epochs"
261
+ )
262
+
263
+ batch_size = gr.Slider(
264
+ minimum=1,
265
+ maximum=8,
266
+ value=4,
267
+ step=1,
268
+ label="Batch Size"
269
+ )
270
+
271
+ train_button = gr.Button("Start Training")
272
+
273
+ with gr.Column():
274
+ output = gr.Textbox(label="Training Status")
275
+
276
+ train_button.click(
277
+ fn=train_model,
278
+ inputs=[file_input, learning_rate, num_epochs, batch_size],
279
+ outputs=output
280
+ )
281
+
282
+ gr.Markdown("""
283
+ ## Instructions
284
+ 1. Upload your training data in CSV format with columns:
285
+ - chunk_id (questions)
286
+ - text (answers)
287
+ 2. Adjust training parameters if needed
288
+ 3. Click 'Start Training'
289
+ 4. Wait for training to complete
290
+ """)
291
 
292
  return demo
293