import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
import datasets
import torch
import json
import os
# transformers' Trainer (and device_map="auto") needs accelerate; install it if missing.
try:
    import accelerate
except ImportError:
    os.system('pip install "accelerate>=0.26.0"')
    import accelerate
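
# Expected upload format (illustrative, derived from the loader below): either a JSON object
# with a "training_pairs" list, or a flat JSON list, where each record carries the "input"
# and "output" text fields, e.g.
#   {"training_pairs": [{"input": "...", "output": "..."}, ...]}
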
# Model setup
MODEL_ID = "facebook/opt-350m" # Smaller, open access model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
# Process the uploaded JSON file and fine-tune the model
def train_ui_tars(file):
    try:
        # Step 1: Load and preprocess the uploaded JSON file
        with open(file.name, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        # Use the "training_pairs" list if present, otherwise treat the file as a flat list of records
        training_data = raw_data.get("training_pairs", raw_data)

        # Re-save the normalized JSON so datasets.load_dataset can read it directly
        fixed_json_path = "fixed_fraud_data.json"
        with open(fixed_json_path, "w", encoding="utf-8") as f:
            json.dump(training_data, f, indent=4)

        # Load dataset
        dataset = datasets.load_dataset("json", data_files=fixed_json_path)
        # Step 2: Tokenize dataset
        def tokenize_data(example):
            inputs = tokenizer(example["input"], padding="max_length", truncation=True, max_length=512)
            targets = tokenizer(example["output"], padding="max_length", truncation=True, max_length=512)
            inputs["labels"] = targets["input_ids"]
            return inputs

        tokenized_dataset = dataset.map(tokenize_data, batched=True)
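
        # Note: the labels produced above keep the tokenizer's pad token id, so padded
        # positions contribute to the loss; a common refinement is to replace those ids
        # with -100, which the Trainer's cross-entropy loss ignores.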

        # Step 3: Training setup
        training_args = TrainingArguments(
            output_dir="./fine_tuned_opt350m",
            per_device_train_batch_size=2,
            evaluation_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=3,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs"
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
        )

        # Step 4: Start training
        trainer.train()

        # Step 5: Save the fine-tuned model and tokenizer
        model.save_pretrained("./fine_tuned_opt350m")
        tokenizer.save_pretrained("./fine_tuned_opt350m")
        return "Training completed successfully! Model saved to ./fine_tuned_opt350m"
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI
with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
    gr.Markdown("# OPT-350M Fine-Tuning UI")
    gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the model on your fraud dataset.")

    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
    train_button = gr.Button("Start Fine-Tuning")
    output = gr.Textbox(label="Training Status")

    train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)

# Launch the app
demo.launch()
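
# Illustrative follow-up (assumes the output directory above): the saved checkpoint can be
# reloaded for inference with the standard transformers API, e.g.
#   model = AutoModelForCausalLM.from_pretrained("./fine_tuned_opt350m")
#   tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_opt350m")
#   ids = tokenizer("your prompt here", return_tensors="pt").input_ids
#   print(tokenizer.decode(model.generate(ids, max_new_tokens=50)[0], skip_special_tokens=True))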