# app.py — Hugging Face Space by SivaMallikarjun (commit 6281c34, "Update app.py").
import gradio as gr
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW
# Load the pretrained multilingual checkpoint twice over: once as a sequence
# classifier with two output classes (num_labels=2), and once as its tokenizer.
model_name = "xlm-roberta-base"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Tiny hand-labelled training set. Each statement sits next to its label so the
# two can never drift out of alignment: 1 = correct, 0 = incorrect.
_labelled_statements = [
    ("Water freezes at 0 degrees Celsius.", 1),
    ("The sun rises in the west.", 0),
    ("Dogs can fly in the sky.", 0),
    ("Birds lay eggs.", 1),
    ("The earth is flat.", 0),
    ("Fish can swim in water.", 1),
    ("Humans can live without oxygen.", 0),
    ("Plants need sunlight to grow.", 1),
    ("Cars run on milk.", 0),
    ("The moon orbits the earth.", 1),
]
# Split the pairs back into the parallel lists the dataset class expects.
train_texts = [statement for statement, _ in _labelled_statements]
train_labels = [label for _, label in _labelled_statements]
# Create Dataset class
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one (text, label) pair per item.

    Each item is tokenized on access, truncated/padded to ``max_length``
    tokens, and returned as a dict of 1-D tensors plus a ``labels`` tensor.
    The training loop passes that dict straight to ``model(**batch)``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """Store the raw data and the tokenizer.

        Args:
            texts: sequence of input strings.
            labels: sequence of integer class labels, one per text.
            tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer)
                returning a mapping of batched tensors when called with
                ``return_tensors="pt"``.
            max_length: token length every item is truncated/padded to.

        Raises:
            ValueError: if ``texts`` and ``labels`` differ in length.
        """
        # Fail fast: a mismatch would otherwise surface later as an IndexError
        # in __getitem__, or silently ignore the surplus labels.
        if len(texts) != len(labels):
            raise ValueError(
                "texts and labels must have the same length "
                f"(got {len(texts)} texts and {len(labels)} labels)"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and attach its label under ``labels``."""
        encodings = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dimension of size 1; drop
        # it so the DataLoader's collate step can stack items into a batch.
        item = {key: val.squeeze(0) for key, val in encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item
# Wrap the parallel lists in a Dataset, then batch it two examples at a time,
# reshuffling the order on every pass over the data.
train_dataset = TextDataset(
    texts=train_texts,
    labels=train_labels,
    tokenizer=tokenizer,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
# Define optimizer.
# torch.optim.AdamW replaces transformers.AdamW, which is deprecated and absent
# from newer transformers releases. eps and weight_decay are pinned to the old
# class's defaults (1e-6, 0.0) because torch's own defaults differ
# (1e-8, 1e-2); this keeps the optimisation behaviour unchanged.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)
# Fine-tuning loop: plain supervised training of the sequence classifier.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 5
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        # The model computes the loss itself because the batch has "labels".
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()
    # max(..., 1) guards the division if the loader were ever empty.
    mean_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} completed (mean loss: {mean_loss:.4f})")
# model.train() above enabled dropout; switch back to eval mode so
# predictions made after training are deterministic.
model.eval()
# Now model is fine-tuned!
# Define prediction function
def classify_text(text):
    """Classify a statement as "Correct" or "Incorrect" with the fine-tuned model.

    Args:
        text: the statement entered in the UI.

    Returns:
        "Correct" if the model predicts class 1, "Incorrect" for class 0,
        or a prompt asking for input when ``text`` is empty, None, or
        whitespace-only.
    """
    # Empty input carries nothing to classify; answer without running the model.
    if not text or not text.strip():
        return "Please enter some text."
    # Fine-tuning switches the model to train mode (dropout active), which makes
    # predictions random; ensure eval mode before inference.
    model.eval()
    # A single input needs no padding, so tokenize only to its real length
    # (still capped at the 128 tokens used in training).
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)
    label = torch.argmax(output.logits, dim=1).item()
    return "Correct" if label == 1 else "Incorrect"
# Gradio UI: one textbox in, the predicted verdict out as plain text.
_text_input = gr.Textbox(label="Enter Text")
gradio_app = gr.Interface(
    fn=classify_text,
    inputs=_text_input,
    outputs="text",
    title="Multi-Language RL Model (Trained)",
)
gradio_app.launch()