|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer |
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
|
|
# Load the classifier once at import time so every request reuses the same
# instance.
# NOTE(review): `load_model` is neither defined nor imported in the visible
# part of this file -- confirm where it is supposed to come from.
# The checkpoint path must be a string literal; the bare expression
# `model_weights.pth` would raise NameError before anything is loaded.
model = load_model("model_weights.pth")
# Presumably a torch.nn.Module: eval() switches it to evaluation mode so that
# layers such as dropout behave deterministically -- TODO confirm model type.
model.eval()

# Tokenizer used to turn the user's text into model inputs.
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
|
|
|
|
# Preprocessing applied to every uploaded image before it is fed to the model.
image_transform = transforms.Compose(
    [
        # Force a fixed 224x224 size regardless of the upload's dimensions.
        transforms.Resize(size=(224, 224)),
        # Convert the PIL image into a tensor.
        transforms.ToTensor(),
        # Standardise each of the three channels with fixed statistics
        # (these values match the ImageNet mean/std quoted in the
        # torchvision documentation).
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
|
|
|
|
|
def predict(image: Image.Image, text: str) -> str:
    """Classify an image/text pair as ``"Biased"`` or ``"Unbiased"``.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when the user submits without uploading anything.
        text: The accompanying text. ``None`` is treated as an empty string.

    Returns:
        ``"Biased"`` when the rounded sigmoid of the model output equals 1,
        otherwise ``"Unbiased"``.

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable
            message instead of an opaque failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    text_inputs = text_tokenizer(
        text or "",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    # The normalisation step is configured for exactly three channels, so
    # grayscale, palette and RGBA uploads must be converted to RGB first.
    # unsqueeze(0) adds the leading batch dimension.
    image_input = image_transform(image.convert("RGB")).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        classification_output = model(
            pixel_values=image_input,
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )
        # NOTE(review): assumes the model returns a single-element tensor of
        # logits; .item() raises for anything larger -- confirm.
        predicted_class = torch.sigmoid(classification_output).round().item()

    return "Biased" if predicted_class == 1 else "Unbiased"
|
|
|
|
|
# Web UI: one image upload and one text box in, a single label out.
interface = gr.Interface(
    title="Multimodal Bias Classifier",
    description="Upload an image and provide a text to classify it as 'Biased' or 'Unbiased'.",
    fn=predict,
    inputs=[
        # type="pil" makes Gradio hand predict() a PIL image.
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            lines=2,
            placeholder="Enter text for classification...",
            label="Input Text",
        ),
    ],
    outputs=gr.Label(label="Prediction"),
)

# Start the Gradio server as soon as the module is executed.
interface.launch()
|
|