import os
import tempfile
import time

import gradio as gr
import torch
import pandas as pd
from transformers import (
    TrainingArguments,
    Trainer,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
# Constants
MODEL_NAME = "deepseek-ai/DeepSeek-R1"
OUTPUT_DIR = "finetuned_models"
LOGS_DIR = "training_logs"
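
# Note: DeepSeek-R1 is a very large mixture-of-experts model; fine-tuning it
# requires a multi-GPU cluster. On a single-GPU Space, a smaller checkpoint
# such as "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" (or any causal LM that
# fits your hardware) is a more realistic value for MODEL_NAME.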

def save_uploaded_file(file_obj):
    """Save the uploaded file to disk and return its path."""
    try:
        os.makedirs('uploads', exist_ok=True)
        # Create a temporary file that persists after closing
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv', dir='uploads')
        # Write the content, handling both raw bytes and file-like objects
        if isinstance(file_obj, (bytes, bytearray)):
            temp_file.write(file_obj)
        else:
            content = file_obj.read()
            if isinstance(content, str):
                temp_file.write(content.encode('utf-8'))
            else:
                temp_file.write(content)
        temp_file.close()
        return temp_file.name
    except Exception as e:
        raise Exception(f"Error saving file: {str(e)}")

def prepare_training_data(df):
    """Convert a DataFrame into prompt/response training records."""
    required_columns = {'chunk_id', 'text'}
    if not required_columns.issubset(df.columns):
        raise ValueError(f"CSV must contain the columns: {sorted(required_columns)}")
    formatted_data = []
    try:
        for _, row in df.iterrows():
            # Clean and validate the data
            chunk_id = str(row['chunk_id']).strip()
            text = str(row['text']).strip()
            if chunk_id and text:  # Only include non-empty pairs
                # Format each conversation in the required structure
                formatted_text = f"User: {chunk_id}\nAssistant: {text}"
                formatted_data.append({"text": formatted_text})
        if not formatted_data:
            raise ValueError("No valid training pairs found in the data")
        return formatted_data
    except Exception as e:
        raise Exception(f"Error preparing training data: {str(e)}")
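
# Illustrative example (hypothetical row): a CSV row with
#   chunk_id = "What does LoRA train?"  and  text = "Only small adapter matrices."
# becomes the single training record
#   {"text": "User: What does LoRA train?\nAssistant: Only small adapter matrices."}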

def prepare_training_components(
    data_path,
    learning_rate,
    num_epochs,
    batch_size,
    model_name=MODEL_NAME
):
    """Prepare the model, tokenizer, dataset, and training arguments."""
    print(f"Loading data from: {data_path}")  # Debug logging
    # Create output directory with timestamp
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    specific_output_dir = os.path.join(OUTPUT_DIR, f"run_{timestamp}")
    os.makedirs(specific_output_dir, exist_ok=True)
    os.makedirs(LOGS_DIR, exist_ok=True)
    # Load data and convert to Q&A format
    try:
        df = pd.read_csv(data_path, encoding='utf-8')
        print(f"Loaded CSV with {len(df)} rows")  # Debug logging
        formatted_data = prepare_training_data(df)
        print(f"Prepared {len(formatted_data)} training examples")  # Debug logging
    except Exception as e:
        print(f"Error loading CSV: {str(e)}")  # Debug logging
        raise
    # Load tokenizer and model (8-bit quantization via bitsandbytes)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Causal LMs often ship without a pad token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto"
    )
    # LoRA Configuration
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    # Prepare the quantized model for training and attach the LoRA adapters
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
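    # Optional sanity check (PEFT helper): reports how few parameters the
    # LoRA adapters actually train compared to the frozen base model
    model.print_trainable_parameters()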
    # Training Arguments (no eval dataset is built, so evaluation is disabled)
    training_args = TrainingArguments(
        output_dir=specific_output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        fp16=True,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        logging_dir=os.path.join(LOGS_DIR, f"run_{timestamp}"),
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
    )
    # Convert to a datasets.Dataset and tokenize: the Trainer needs token
    # IDs, not raw strings (max_length=512 is an assumption; adjust to fit
    # your data and memory budget)
    dataset = Dataset.from_dict({
        'text': [item['text'] for item in formatted_data]
    })
    dataset = dataset.map(
        lambda batch: tokenizer(batch['text'], truncation=True, max_length=512),
        batched=True,
        remove_columns=['text']
    )
    # Create data collator (mlm=False selects the causal LM objective)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )
    return {
        'model': model,
        'tokenizer': tokenizer,
        'training_args': training_args,
        'dataset': dataset,
        'data_collator': data_collator,
        'output_dir': specific_output_dir
    }
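
# Standalone usage sketch (assumes a local "data/train.csv" with the
# chunk_id/text columns described above; the path and hyperparameters are
# illustrative only):
#   components = prepare_training_components("data/train.csv", 2e-4, 3, 4)
#   print(components['output_dir'])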

def train_model(
    file,
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=4,
    progress=gr.Progress()
):
    """Training function for the Gradio interface."""
    if file is None:
        return "Please upload a file first."
    try:
        # File validation
        progress(0.1, desc="Validating file...")
        file_path = save_uploaded_file(file)
        # Prepare components
        progress(0.2, desc="Preparing training components...")
        components = prepare_training_components(
            file_path,
            learning_rate,
            num_epochs,
            batch_size
        )
        # Initialize trainer
        progress(0.4, desc="Initializing trainer...")
        trainer = Trainer(
            model=components['model'],
            args=components['training_args'],
            train_dataset=components['dataset'],
            data_collator=components['data_collator'],
        )
        # Train
        progress(0.5, desc="Training model...")
        trainer.train()
        # Save model and tokenizer
        progress(0.9, desc="Saving model...")
        trainer.save_model()
        components['tokenizer'].save_pretrained(components['output_dir'])
        progress(1.0, desc="Training complete!")
        return f"Training completed! Model saved in {components['output_dir']}"
    except Exception as e:
        error_msg = f"Error during training: {str(e)}"
        print(error_msg)  # Log the error
        return error_msg
"""Training function for Gradio interface""" | |
try: | |
# Save uploaded file | |
file_path = save_uploaded_file(file) | |
# Prepare components | |
progress(0.2, desc="Preparing training components...") | |
components = prepare_training_components( | |
file_path, | |
learning_rate, | |
num_epochs, | |
batch_size | |
) | |
# Initialize trainer | |
progress(0.4, desc="Initializing trainer...") | |
trainer = Trainer( | |
model=components['model'], | |
args=components['training_args'], | |
train_dataset=components['dataset'], | |
data_collator=components['data_collator'], | |
) | |
# Train | |
progress(0.5, desc="Training model...") | |
trainer.train() | |
# Save model and tokenizer | |
progress(0.9, desc="Saving model...") | |
trainer.save_model() | |
components['tokenizer'].save_pretrained(components['output_dir']) | |
progress(1.0, desc="Training complete!") | |
return f"Training completed! Model saved in {components['output_dir']}" | |
except Exception as e: | |
return f"Error during training: {str(e)}" | |

# Create Gradio interface (gr.Blocks, so Rows/Columns have a valid context)
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                file_input = gr.File(
                    label="Upload Training Data (CSV)",
                    type="binary",
                    file_types=[".csv"]
                )
                learning_rate = gr.Slider(
                    minimum=1e-5,
                    maximum=1e-3,
                    value=2e-4,
                    label="Learning Rate"
                )
                num_epochs = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Epochs"
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=8,
                    value=4,
                    step=1,
                    label="Batch Size"
                )
                train_button = gr.Button("Start Training")
            with gr.Column():
                output = gr.Textbox(label="Training Status")
        train_button.click(
            fn=train_model,
            inputs=[file_input, learning_rate, num_epochs, batch_size],
            outputs=output
        )
        gr.Markdown("""
        ## Instructions
        1. Upload your training data in CSV format with columns:
           - chunk_id (questions)
           - text (answers)
        2. Adjust training parameters if needed
        3. Click 'Start Training'
        4. Wait for training to complete
        """)
    return demo
if __name__ == "__main__": | |
# Create necessary directories | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
os.makedirs(LOGS_DIR, exist_ok=True) | |
# Launch Gradio interface | |
demo = create_interface() | |
demo.launch(share=True) |