import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the VQA Model
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super(VQAModel, self).__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)  # Input size matches the concatenated image + text features
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is dynamically determined

    def forward(self, image, question):
        # Extract image features
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Extract text features
        with torch.no_grad():
            question_encoded = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_encoded), dim=1)  # Shape: (batch_size, 1536)
        # Pass through fully connected layer
        combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)
        # Classify
        output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
        return output
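# For reference: `image` is expected as a float tensor of shape (batch, 3, 224, 224)
# (see the transform below), and `question` as a tensor of BERT token ids of shape
# (batch, seq_len) produced by the tokenizer in predict().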
# Load the saved model checkpoint
checkpoint_path = 'vqa_vit_best_model.pth' # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)
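# The checkpoint is expected to contain 'model_state_dict', 'num_classes', and
# 'answer_to_label' entries; they are read below when rebuilding the model.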
# Load ViT and BERT models
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
# Initialize the VQA model with the correct number of classes
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)
# Load the model state dict
model.load_state_dict(checkpoint['model_state_dict'])
# Load the answer-to-label mapping
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()} # Reverse mapping for inference
# Set the model to evaluation mode
model.eval()
# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
])
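# (A per-channel mean/std of 0.5 is the normalization commonly paired with this ViT
# checkpoint's image processor; it should match whatever was used during training.)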
# Function to preprocess and predict
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Set a maximum sequence length
    ).to(device)
    # Perform inference
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'])
    # Get the predicted label
    _, predicted_label = torch.max(output, 1)
    predicted_label = predicted_label.item()
    # Map the label back to the answer
    predicted_answer = label_to_answer[predicted_label]
    return predicted_answer
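# Example usage (hypothetical file name, assuming the image exists locally):
#     answer = predict("example_diagram.png", "What is the overall complexity of this model?")
#     print(answer)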
# Fixed question asked about every uploaded image
question = "What is the overall complexity of this model?"
# Define the Gradio interface function
def vqa_interface(image):
    # Predict the answer using the provided image and the predefined question
    predicted_answer = predict(image, question)
    return predicted_answer
# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,                   # Function to call
    inputs=gr.Image(type="filepath"),   # Input type: image file path
    outputs="text",                     # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"],
        ["2ndIterationClassDiagram.png"],
        ["4-gameUML.png"]
    ]
)
# Launch the Gradio interface
iface.launch()