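# Spaces: Sleeping, Sleeping -- appears to be hosting-page status text captured along with the source; not part of the program.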
import gradio as gr
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW
# Load model and tokenizer
# Both objects are created at import time from the checkpoint named by
# `model_name`. `num_labels=2` requests a two-class classification model,
# matching the 0/1 labels used for training further down in this file.
model_name = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# Prepare a custom dataset
# Each entry pairs a statement with its label (1 = Correct, 0 = Incorrect).
# Keeping text and label side by side means the two derived lists below
# cannot drift out of alignment when an example is added or removed.
_TRAIN_EXAMPLES = [
    ("Water freezes at 0 degrees Celsius.", 1),  # Correct
    ("The sun rises in the west.", 0),  # Incorrect
    ("Dogs can fly in the sky.", 0),  # Incorrect
    ("Birds lay eggs.", 1),  # Correct
    ("The earth is flat.", 0),  # Incorrect
    ("Fish can swim in water.", 1),  # Correct
    ("Humans can live without oxygen.", 0),  # Incorrect
    ("Plants need sunlight to grow.", 1),  # Correct
    ("Cars run on milk.", 0),  # Incorrect
    ("The moon orbits the earth.", 1),  # Correct
]

# Parallel lists with the same names, order, and contents as before.
train_texts = [text for text, _ in _TRAIN_EXAMPLES]
train_labels = [label for _, label in _TRAIN_EXAMPLES]
# Create Dataset class
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and attaches its label.

    Args:
        texts: Sequence of input strings.
        labels: Sequence of labels aligned index-for-index with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding="max_length", max_length=..., return_tensors="pt")``. It
            must return a mapping whose values support ``.squeeze(0)``
            (presumably tensors with a leading batch dimension of 1).
        max_length: Length each example is padded/truncated to. Defaults to
            128, the value that used to be hard-coded.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        # Fail at construction time instead of with an IndexError (or a
        # silently ignored label) in the middle of a training epoch.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx`` as a dict including ``'labels'``."""
        encodings = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the leading dimension added by return_tensors="pt" so the
        # DataLoader can stack individual items into a batch itself.
        item = {key: val.squeeze(0) for key, val in encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
# Load Dataset
train_dataset = TextDataset(train_texts, train_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Pick the device and move the model first, so the optimizer is built over
# the parameters as they live on the training device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW.
# NOTE(review): eps=1e-6 and weight_decay=0.0 are set explicitly because they
# are believed to be the defaults of transformers.AdamW (torch's own defaults
# differ) -- confirm against the transformers version this was written for.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Fine-tuning loop
model.train()
for epoch in range(5):  # Train for 5 epochs
    for batch in train_loader:
        # Every tensor in the batch (inputs and 'labels') must be on the
        # same device as the model.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch+1} completed")

# Now model is fine-tuned!
# Leave training mode so train-only layers such as dropout are disabled for
# the predictions served below; otherwise the same input can yield
# different outputs from call to call.
model.eval()
# Define prediction function
def classify_text(text):
    """Classify ``text`` with the module-level model.

    Args:
        text: The statement to classify.

    Returns:
        ``"Correct"`` when the highest-scoring class index is 1, otherwise
        ``"Incorrect"``.
    """
    # Predictions must not depend on whatever mode the model was last left
    # in: in training mode, dropout makes the output non-deterministic.
    model.eval()
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    # Inputs have to live on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # no gradients are needed for inference
        output = model(**inputs)
    label = torch.argmax(output.logits, dim=1).item()
    return "Correct" if label == 1 else "Incorrect"
# Gradio UI
# A single text box is passed to classify_text and the returned string is
# shown as plain text. The app is launched as soon as this module runs.
gradio_app = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter Text"),
    outputs="text",
    title="Multi-Language RL Model (Trained)",
)
gradio_app.launch()