MRasheq commited on
Commit
00941a5
·
1 Parent(s): 4b49da8

First commit

Browse files
Files changed (1) hide show
  1. app.py +279 -59
app.py CHANGED
@@ -1,64 +1,284 @@
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import torch
5
+ import pandas as pd
6
  import gradio as gr
7
+ from sqlalchemy import create_engine, text
8
+ from transformers import (
9
+ TrainingArguments,
10
+ Trainer,
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ DataCollatorForLanguageModeling
14
+ )
15
+ from datasets import Dataset
16
+ from peft import (
17
+ prepare_model_for_kbit_training,
18
+ LoraConfig,
19
+ get_peft_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
+ from datetime import datetime
22
+
23
+ # Constants - Modified for HF Spaces
24
+ MODEL_NAME = "deepseek-ai/DeepSeek-R1"
25
+ OUTPUT_DIR = "/tmp/finetuned_models" # Using /tmp for HF Spaces
26
+ LOGS_DIR = "/tmp/training_logs" # Using /tmp for HF Spaces
27
+
28
+ class TrainingInterface:
29
+ def __init__(self):
30
+ self.current_status = "Idle"
31
+ self.progress = 0
32
+ self.is_training = False
33
+
34
+ def get_database_url(self):
35
+ """Get database URL from HF Space secrets"""
36
+ database_url = os.environ.get('DATABASE_URL')
37
+ if not database_url:
38
+ raise Exception("DATABASE_URL not found in environment variables")
39
+ return database_url
40
+
41
+ def fetch_training_data(self, progress=gr.Progress()):
42
+ """Fetch training data from database"""
43
+ try:
44
+ database_url = self.get_database_url()
45
+ engine = create_engine(database_url)
46
+
47
+ progress(0, desc="Connecting to database...")
48
+
49
+ with engine.connect() as conn:
50
+ result = conn.execute(text("SELECT COUNT(*) FROM bents"))
51
+ total_rows = result.scalar()
52
+
53
+ query = text("SELECT chunk_id, text FROM bents")
54
+ df = pd.read_sql_query(query, conn)
55
+
56
+ progress(0.5, desc="Data fetched successfully")
57
+ return df
58
+
59
+ except Exception as e:
60
+ raise gr.Error(f"Database error: {str(e)}")
61
+
62
+ def prepare_training_data(self, df, progress=gr.Progress()):
63
+ """Convert DataFrame into training format"""
64
+ formatted_data = []
65
+ try:
66
+ total_rows = len(df)
67
+ for idx, row in enumerate(df.iterrows()):
68
+ progress(idx/total_rows, desc="Preparing training data...")
69
+ _, row_data = row
70
+ chunk_id = str(row_data['chunk_id']).strip()
71
+ text = str(row_data['text']).strip()
72
+
73
+ if chunk_id and text:
74
+ formatted_text = f"User: {chunk_id}\nAssistant: {text}"
75
+ formatted_data.append({"text": formatted_text})
76
+
77
+ if not formatted_data:
78
+ raise ValueError("No valid training data found")
79
+
80
+ return formatted_data
81
+ except Exception as e:
82
+ raise gr.Error(f"Data preparation error: {str(e)}")
83
+
84
+ def stop_training(self):
85
+ """Stop the training process"""
86
+ self.is_training = False
87
+ return "Training stopped by user."
88
+
89
+ def train_model(
90
+ self,
91
+ learning_rate=2e-4,
92
+ num_epochs=3,
93
+ batch_size=4,
94
+ progress=gr.Progress()
95
+ ):
96
+ """Main training function"""
97
+ try:
98
+ self.is_training = True
99
+
100
+ # Create directories in /tmp for HF Spaces
101
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
102
+ specific_output_dir = os.path.join(OUTPUT_DIR, f"run_{timestamp}")
103
+ os.makedirs(specific_output_dir, exist_ok=True)
104
+ os.makedirs(LOGS_DIR, exist_ok=True)
105
+
106
+ # Data preparation
107
+ progress(0.1, desc="Fetching data...")
108
+ if not self.is_training:
109
+ return "Training cancelled."
110
+
111
+ df = self.fetch_training_data()
112
+ formatted_data = self.prepare_training_data(df)
113
+
114
+ # Model initialization
115
+ progress(0.2, desc="Loading model...")
116
+ if not self.is_training:
117
+ return "Training cancelled."
118
+
119
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
120
+ model = AutoModelForCausalLM.from_pretrained(
121
+ MODEL_NAME,
122
+ trust_remote_code=True,
123
+ torch_dtype=torch.float16,
124
+ load_in_8bit=True,
125
+ device_map="auto" # Important for HF Spaces GPU allocation
126
+ )
127
+
128
+ # LoRA configuration
129
+ progress(0.3, desc="Setting up LoRA...")
130
+ if not self.is_training:
131
+ return "Training cancelled."
132
+
133
+ lora_config = LoraConfig(
134
+ r=16,
135
+ lora_alpha=32,
136
+ target_modules=[
137
+ "q_proj", "k_proj", "v_proj", "o_proj",
138
+ "gate_proj", "up_proj", "down_proj"
139
+ ],
140
+ lora_dropout=0.05,
141
+ bias="none",
142
+ task_type="CAUSAL_LM"
143
+ )
144
+
145
+ model = prepare_model_for_kbit_training(model)
146
+ model = get_peft_model(model, lora_config)
147
+
148
+ # Training setup
149
+ progress(0.4, desc="Configuring training...")
150
+ if not self.is_training:
151
+ return "Training cancelled."
152
+
153
+ training_args = TrainingArguments(
154
+ output_dir=specific_output_dir,
155
+ num_train_epochs=num_epochs,
156
+ per_device_train_batch_size=batch_size,
157
+ learning_rate=learning_rate,
158
+ fp16=True,
159
+ gradient_accumulation_steps=8,
160
+ gradient_checkpointing=True,
161
+ logging_dir=os.path.join(LOGS_DIR, f"run_{timestamp}"),
162
+ logging_steps=10,
163
+ save_strategy="epoch",
164
+ evaluation_strategy="epoch",
165
+ save_total_limit=2,
166
+ remove_unused_columns=False, # Important for HF Spaces
167
+ )
168
+
169
+ dataset = Dataset.from_dict({
170
+ 'text': [item['text'] for item in formatted_data]
171
+ })
172
+
173
+ data_collator = DataCollatorForLanguageModeling(
174
+ tokenizer=tokenizer,
175
+ mlm=False
176
+ )
177
+
178
+ # Custom progress callback
179
+ class ProgressCallback(gr.Progress):
180
+ def __init__(self, progress_callback, training_interface):
181
+ self.progress_callback = progress_callback
182
+ self.training_interface = training_interface
183
+
184
+ def on_train_begin(self, args, state, control, **kwargs):
185
+ if not self.training_interface.is_training:
186
+ control.should_training_stop = True
187
+ self.progress_callback(0.5, desc="Training started...")
188
+
189
+ def on_epoch_begin(self, args, state, control, **kwargs):
190
+ if not self.training_interface.is_training:
191
+ control.should_training_stop = True
192
+ epoch_progress = (state.epoch / args.num_train_epochs)
193
+ total_progress = 0.5 + (epoch_progress * 0.4)
194
+ self.progress_callback(total_progress,
195
+ desc=f"Training epoch {state.epoch + 1}/{args.num_train_epochs}...")
196
+
197
+ trainer = Trainer(
198
+ model=model,
199
+ args=training_args,
200
+ train_dataset=dataset,
201
+ data_collator=data_collator,
202
+ callbacks=[ProgressCallback(progress, self)]
203
+ )
204
+
205
+ if not self.is_training:
206
+ return "Training cancelled."
207
+
208
+ trainer.train()
209
+
210
+ if not self.is_training:
211
+ return "Training cancelled."
212
+
213
+ # Save model
214
+ progress(0.9, desc="Saving model...")
215
+ trainer.save_model()
216
+ tokenizer.save_pretrained(specific_output_dir)
217
+
218
+ progress(1.0, desc="Training completed!")
219
+ return f"Training completed! Model saved in {specific_output_dir}"
220
+
221
+ except Exception as e:
222
+ self.is_training = False
223
+ raise gr.Error(f"Training error: {str(e)}")
224
+
225
+ def create_training_interface():
226
+ """Create Gradio interface"""
227
+ interface = TrainingInterface()
228
+
229
+ with gr.Blocks(title="DeepSeek Model Training Interface") as app:
230
+ gr.Markdown("# DeepSeek Model Fine-tuning Interface")
231
+
232
+ with gr.Row():
233
+ with gr.Column():
234
+ learning_rate = gr.Slider(
235
+ minimum=1e-5,
236
+ maximum=1e-3,
237
+ value=2e-4,
238
+ label="Learning Rate"
239
+ )
240
+ num_epochs = gr.Slider(
241
+ minimum=1,
242
+ maximum=10,
243
+ value=3,
244
+ step=1,
245
+ label="Number of Epochs"
246
+ )
247
+ batch_size = gr.Slider(
248
+ minimum=1,
249
+ maximum=8,
250
+ value=4,
251
+ step=1,
252
+ label="Batch Size"
253
+ )
254
+
255
+ with gr.Row():
256
+ train_button = gr.Button("Start Training", variant="primary")
257
+ stop_button = gr.Button("Stop Training", variant="secondary")
258
+
259
+ output_text = gr.Textbox(
260
+ label="Training Status",
261
+ placeholder="Training status will appear here...",
262
+ lines=10
263
+ )
264
+
265
+ train_button.click(
266
+ fn=interface.train_model,
267
+ inputs=[learning_rate, num_epochs, batch_size],
268
+ outputs=output_text
269
+ )
270
+
271
+ stop_button.click(
272
+ fn=interface.stop_training,
273
+ inputs=[],
274
+ outputs=output_text
275
+ )
276
 
277
+ return app
278
 
279
  if __name__ == "__main__":
280
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
281
+ os.makedirs(LOGS_DIR, exist_ok=True)
282
+
283
+ app = create_training_interface()
284
+ app.launch()