Spaces:
Runtime error
import inspect

import datasets
from transformers import (
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    XLMRobertaForSequenceClassification,
    XLMRobertaTokenizer,
)
# Starting point for fine-tuning: the XLM-RoBERTa large checkpoint's
# tokenizer, plus the same checkpoint loaded with a two-class
# sequence-classification head.
model_name = 'xlm-roberta-large'
tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
model = XLMRobertaForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)
# Sentiment-analysis corpus: the 'imdb' dataset from the Hugging Face Hub.
# The result is indexed by split name ('train' / 'test') further down.
dataset = datasets.load_dataset(path='imdb')
# Tokenize the dataset
def tokenize(batch):
    """Tokenize one batch of examples for XLM-RoBERTa.

    Args:
        batch: Mapping with a 'text' key holding a list of strings
            (``Dataset.map(..., batched=True)`` passes one such mapping
            per call).

    Returns:
        The tokenizer's encoding for the batch, with every row truncated or
        padded to the tokenizer's maximum length.
    """
    # padding=True would pad only to the longest text in *this* map batch,
    # so rows from different map batches would end up with different
    # lengths. The Trainer's default collator cannot stack those into one
    # tensor, so pad every row to the same fixed length instead.
    return tokenizer(batch['text'], padding='max_length', truncation=True)


dataset = dataset.map(tokenize, batched=True)
# Fine-tune the model on the dataset
def compute_metrics(eval_pred):
    """Return the model's accuracy on the evaluation set.

    The Trainer calls this after each evaluation pass and reports each key
    with an ``eval_`` prefix. That is what ``metric_for_best_model='accuracy'``
    below looks up. Without this function only the loss is reported, and
    selecting the best checkpoint fails.

    Args:
        eval_pred: ``transformers.EvalPrediction`` whose ``predictions`` are
            the per-class logits and whose ``label_ids`` are the gold labels.

    Returns:
        ``{'accuracy': float}``.
    """
    predicted_labels = eval_pred.predictions.argmax(axis=-1)
    return {'accuracy': float((predicted_labels == eval_pred.label_ids).mean())}


# ``evaluation_strategy`` was renamed ``eval_strategy`` in newer transformers
# releases, and the old spelling was later removed. Use whichever name this
# installation accepts.
eval_strategy_kwarg = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./results',
    # load_best_model_at_end requires the save strategy to match the eval
    # strategy. The default save strategy is 'steps', so set it explicitly.
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    **{eval_strategy_kwarg: 'epoch'},
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    # Pads each batch to its longest sequence, so batching works whether or
    # not the tokenized rows already share one length.
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)
trainer.train()
import torch

# Load the fine-tuned XLM-Roberta-Large model.
# Do not guess a "checkpoint-<step>" directory: its step number depends on
# dataset size, batch size and save strategy, so a hard-coded name like
# 'checkpoint-1000' may not exist. Instead, save the Trainer's current model
# (the best checkpoint when load_best_model_at_end=True) to a fixed directory
# and reload it from there.
model_path = './results/final'
trainer.save_model(model_path)
model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
# Inference only: make sure dropout is disabled.
model.eval()
# Create a function that takes a text input and returns the predicted sentiment label
def predict_sentiment(text):
    """Classify a single piece of text with the fine-tuned model.

    Args:
        text: The input string to classify.

    Returns:
        'positive' when the model's highest-scoring class is 1, otherwise
        'negative'.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    # Put the encoded tensors on the model's device so this works whether
    # the model sits on CPU or GPU.
    inputs = inputs.to(model.device)
    # No gradients are needed at inference time; skip building an autograd
    # graph on every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return 'positive' if predicted_class == 1 else 'negative'
import gradio as gr

# Create a Gradio interface for the predict_sentiment function.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4.
# gr.Textbox serves as both the input and the output component and, unlike
# the old gr.outputs.Textbox, accepts a placeholder.
iface = gr.Interface(
    fn=predict_sentiment,
    inputs=gr.Textbox(placeholder='Enter text here...'),
    outputs=gr.Textbox(placeholder='Sentiment prediction...'),
)
# Launch the interface
iface.launch()