Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch | |
def load_model(model_name):
    """Fetch a tokenizer/model pair for sequence classification.

    Args:
        model_name: Hugging Face Hub repo ID (or local path) passed unchanged
            to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(tokenizer, model)`` built by ``AutoTokenizer`` and
        ``AutoModelForSequenceClassification`` respectively.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is still
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )
# Hub checkpoints being compared; swap these IDs for your own models.
original_model_name = "bert-base-uncased"  # baseline (not fine-tuned) checkpoint
fine_tuned_model_name = "Vishwas1/bert-base-imdb"  # repo ID of the fine-tuned model

# NOTE(review): if the baseline checkpoint ships without a trained
# classification head, transformers initialises one randomly, so the
# "original" predictions are not meaningful and may differ between restarts
# -- confirm this is the intended comparison.

# Load each model and immediately switch it to evaluation (inference) mode.
original_tokenizer, original_model = load_model(original_model_name)
original_model.eval()

fine_tuned_tokenizer, fine_tuned_model = load_model(fine_tuned_model_name)
fine_tuned_model.eval()
def _predict(tokenizer, model, text):
    """Classify *text* with one model and return ``(class_index, confidence)``.

    Confidence is the softmax probability of the arg-max class.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)
    # A single input string gives a batch of one, so .item() is well-defined.
    pred = torch.argmax(probs, dim=1).item()
    return pred, probs[0][pred].item()


def _label_name(model, index, labels):
    """Return the display label for class *index*.

    Looks in *labels* first, then in the model config's ``id2label``, and
    finally falls back to the bare index -- instead of raising KeyError for
    models whose class count exceeds the entries in *labels*.
    """
    if index in labels:
        return labels[index]
    id2label = getattr(getattr(model, "config", None), "id2label", None) or {}
    return id2label.get(index, str(index))


def compare_models(text):
    """Classify *text* with both the original and the fine-tuned model.

    Args:
        text: Input string from the Gradio textbox; ``None`` is treated as an
            empty string.

    Returns:
        dict mapping ``"Original Model Prediction"`` and
        ``"Fine-Tuned Model Prediction"`` (in that order) to strings of the
        form ``"<label> (<confidence>)"``, confidence shown to two decimals.
    """
    if text is None:
        text = ""
    # Map predictions to labels (adjust based on your model's labels).
    labels = {0: "Negative", 1: "Positive"}
    result = {}
    for key, tokenizer, model in (
        ("Original Model Prediction", original_tokenizer, original_model),
        ("Fine-Tuned Model Prediction", fine_tuned_tokenizer, fine_tuned_model),
    ):
        pred, confidence = _predict(tokenizer, model, text)
        result[key] = f"{_label_name(model, pred, labels)} ({confidence:.2f})"
    return result
# Gradio Interface: one multi-line text box in, a JSON panel with both
# models' predictions out.
iface = gr.Interface(
    fn=compare_models,
    inputs=gr.Textbox(lines=5, placeholder="Enter text here...", label="Input Text"),
    outputs=gr.JSON(label="Model Predictions"),
    title="Compare Original and Fine-Tuned Models",
    description="Enter text to see predictions from the original and fine-tuned models.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module elsewhere does not block on launch().
if __name__ == "__main__":
    iface.launch()