import json
import os

import datasets
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq

# Trainer needs a recent accelerate; install it at runtime if it is missing.
try:
    import accelerate  # noqa: F401
except ImportError:
    os.system('pip install "accelerate>=0.26.0"')
    import accelerate  # noqa: F401

# Base model: OPT-350M, loaded in half precision and placed on the available device(s) by accelerate.
MODEL_ID = "facebook/opt-350m"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
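
# Expected upload format (a sketch based on the keys read below; either a bare
# list of pairs or a dict that wraps them under "training_pairs"):
#
#     {
#         "training_pairs": [
#             {"input": "Transaction: ...", "output": "Label: fraudulent"},
#             {"input": "Transaction: ...", "output": "Label: legitimate"}
#         ]
#     }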


def train_ui_tars(file):
    """Fine-tune the base model on an uploaded JSON file of input/output pairs."""
    try:
        # The upload may be either a bare list of pairs or a dict that wraps
        # them under "training_pairs".
        with open(file.name, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        training_data = raw_data.get("training_pairs", raw_data) if isinstance(raw_data, dict) else raw_data

        # Re-serialize the normalized pairs so they can be loaded as a "json" dataset.
        fixed_json_path = "fixed_fraud_data.json"
        with open(fixed_json_path, "w", encoding="utf-8") as f:
            json.dump(training_data, f, indent=4)

        dataset = datasets.load_dataset("json", data_files=fixed_json_path)

        def tokenize_data(example):
            # Tokenize prompts and targets to a fixed length; the target ids
            # become the labels for causal-LM training.
            inputs = tokenizer(example["input"], padding="max_length", truncation=True, max_length=512)
            targets = tokenizer(example["output"], padding="max_length", truncation=True, max_length=512)
            # Mask padding in the labels so it does not contribute to the loss.
            labels = [
                [(token if token != tokenizer.pad_token_id else -100) for token in seq]
                for seq in targets["input_ids"]
            ]
            inputs["labels"] = labels
            return inputs

        tokenized_dataset = dataset.map(tokenize_data, batched=True)

        training_args = TrainingArguments(
            output_dir="./fine_tuned_opt350m",
            per_device_train_batch_size=2,
            evaluation_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=3,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
        )
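
        # DataCollatorForSeq2Seq pads the "labels" field (with -100) when batching;
        # with the fixed-length tokenization above it mostly just collates tensors.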
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        )

        trainer.train()

        # Persist the fine-tuned weights and tokenizer alongside the checkpoints.
        model.save_pretrained("./fine_tuned_opt350m")
        tokenizer.save_pretrained("./fine_tuned_opt350m")

        return "Training completed successfully! Model saved to ./fine_tuned_opt350m"

    except Exception as e:
        return f"Error: {str(e)}"


with gr.Blocks(title="Model Fine-Tuning Interface") as demo:
    gr.Markdown("# OPT-350M Fine-Tuning UI")
    gr.Markdown("Upload a JSON file with 'input' and 'output' pairs to fine-tune the model on your fraud dataset.")

    file_input = gr.File(label="Upload Fraud Dataset (JSON)")
    train_button = gr.Button("Start Fine-Tuning")
    output = gr.Textbox(label="Training Status")

    train_button.click(fn=train_ui_tars, inputs=file_input, outputs=output)

demo.launch()
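
# Example of reusing the saved model for inference afterwards (a sketch; the
# path matches the save_pretrained calls above):
#
#     from transformers import pipeline
#     generator = pipeline("text-generation", model="./fine_tuned_opt350m")
#     print(generator("Transaction: ...")[0]["generated_text"])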