# Hugging Face Space app (page status when captured: "Sleeping").
import inspect
import traceback

import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
def _to_positive_int(value, name):
    """Coerce a UI value (Gradio sliders may deliver floats) to an int >= 1.

    Raises:
        ValueError: if the rounded value is below 1.
    """
    as_int = int(round(float(value)))
    if as_int < 1:
        raise ValueError(f"{name} must be at least 1, got {value!r}")
    return as_int


def _pick_eval_split(dataset):
    """Return the split name to evaluate on: 'validation' if present, else 'test'.

    Raises:
        ValueError: if the dataset has neither split (e.g. a train-only dataset).
    """
    for split in ("validation", "test"):
        if split in dataset:
            return split
    raise ValueError(
        f"Dataset has no 'validation' or 'test' split (found: {sorted(dataset)})"
    )


def _eval_strategy_kwargs():
    """Return the per-epoch evaluation kwarg under the name this transformers accepts.

    Newer transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy``; inspecting the signature keeps both old and new
    versions working.
    """
    params = inspect.signature(TrainingArguments).parameters
    key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    return {key: "epoch"}


def _compute_accuracy(eval_pred):
    """Trainer ``compute_metrics`` hook: fraction of argmax predictions equal to the labels."""
    logits = eval_pred.predictions
    # When a model returns extra outputs besides the logits, Trainer hands them
    # over as a tuple whose first element is the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == eval_pred.label_ids).mean())}


def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
    """Fine-tune a sequence-classification model and push it to the Hugging Face Hub.

    Args:
        model_name: Hub id of the pretrained checkpoint to start from.
        dataset_name: Hub id of a dataset with ``text`` and ``label`` columns
            (as in ``imdb``), a ``train`` split, and a ``validation`` or ``test``
            split used for per-epoch evaluation.
        hub_id: Hub repo id the trained model is pushed to.
        num_epochs: Number of training epochs (must be > 0).
        batch_size: Per-device train/eval batch size; rounded to an int >= 1.
        lr: Learning rate in units of 1e-5, matching the UI slider label.
        grad: Gradient accumulation steps; rounded to an int >= 1.

    Returns:
        A status string for the Gradio text output: ``"DONE! ..."`` on success,
        otherwise ``"An error occurred: ..."`` with the traceback. Errors are
        reported rather than raised because this is the UI boundary.
    """
    try:
        # Validate every input before any (slow, networked) library call.
        model_name = (model_name or "").strip()
        dataset_name = (dataset_name or "").strip()
        hub_id = (hub_id or "").strip()
        if not model_name:
            raise ValueError("Model name is required.")
        if not dataset_name:
            raise ValueError("Dataset name is required.")
        if not hub_id:
            raise ValueError("Hub repo id is required.")

        epochs = float(num_epochs)
        if epochs <= 0:
            raise ValueError(f"Number of epochs must be positive, got {num_epochs!r}")
        batch_size = _to_positive_int(batch_size, "Batch size")
        grad_steps = _to_positive_int(grad, "Gradient accumulation steps")
        # The slider is expressed in units of 1e-5 ("Learning Rate (e-5)").
        learning_rate = float(lr) * 1e-5
        if learning_rate <= 0:
            raise ValueError(f"Learning rate must be positive, got {lr!r}")

        dataset = load_dataset(dataset_name)
        eval_split = _pick_eval_split(dataset)
        # Use the ClassLabel's class count when the dataset declares one.
        label_feature = dataset["train"].features["label"]
        num_labels = getattr(label_feature, "num_classes", 2)

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )

        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True)

        # Only tokenize the splits actually used (e.g. skip imdb's 'unsupervised').
        train_dataset = dataset["train"].map(tokenize_function, batched=True)
        eval_dataset = dataset[eval_split].map(tokenize_function, batched=True)

        training_args = TrainingArguments(
            output_dir="./results",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            weight_decay=0.01,
            gradient_accumulation_steps=grad_steps,
            # load_best_model_at_end requires checkpoints on the same schedule
            # as evaluation.
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            logging_dir="./logs",
            logging_steps=10,
            push_to_hub=True,
            hub_model_id=hub_id,
            **_eval_strategy_kwargs(),
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # Supplies the "accuracy" metric that metric_for_best_model selects on.
            compute_metrics=_compute_accuracy,
        )
        trainer.train()
        trainer.push_to_hub(commit_message="Training complete!")
    except Exception as e:
        return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
    return f"DONE! Model pushed to {hub_id}"
# Create the Gradio interface.
# gr.Textbox / gr.Slider replace the gr.inputs.* namespace (deprecated in
# Gradio 3, removed in Gradio 4). step=1 keeps every slider on whole numbers,
# since batch size and gradient accumulation must be integers.
# Construction is not wrapped in try/except: if it fails, the exception and
# traceback surface directly instead of a later NameError on `iface`.
iface = gr.Interface(
    fine_tune_model,
    inputs=[
        gr.Textbox(label="Model Name (e.g., 'google/t5-efficient-tiny-nh8')"),
        gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
        gr.Textbox(label="HF hub to push to after training"),
        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs"),
        gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Batch Size"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Learning Rate (e-5)"),
        # The value is used as a raw step count, not scaled by 1e-1.
        gr.Slider(minimum=1, maximum=100, value=1, step=1, label="Gradient Accumulation Steps"),
    ],
    outputs="text",
    title="Fine-Tune Hugging Face Model",
    description="This interface allows you to fine-tune a Hugging Face model on a specified dataset.",
)

# Launch the interface. queue() routes requests through Gradio's queue so a
# long training run is not cut off by a plain HTTP request timeout (Gradio 4
# enables the queue by default; calling it there is harmless).
iface.queue().launch()