import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the VQA model
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super(VQAModel, self).__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)    # Input size matches the concatenated features
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is dynamically determined

    def forward(self, image, question):
        # Extract image features
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Extract text features
        with torch.no_grad():
            question_encoded = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_encoded), dim=1)  # Shape: (batch_size, 1536)
        # Pass through the fully connected layer
        combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)
        # Classify
        output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
        return output
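# For reference, the shape flow through the fusion head (batch size B is illustrative):
#   ViT  image     -> last_hidden_state[:, 0, :] -> (B, 768)
#   BERT input_ids -> last_hidden_state[:, 0, :] -> (B, 768)
#   concat -> (B, 1536) -> fc -> (B, 768) -> classifier -> (B, num_classes)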
# Load the saved model checkpoint
checkpoint_path = 'vqa_vit_best_model.pth'  # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the ViT and BERT backbones
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Initialize the VQA model with the correct number of classes
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)

# Load the model state dict
model.load_state_dict(checkpoint['model_state_dict'])

# Load the answer-to-label mapping
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()}  # Reverse mapping for inference

# Set the model to evaluation mode
model.eval()
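# Note: the checkpoint is assumed to have been written by the training script
# roughly as follows (illustrative sketch based on the keys read above, not part of this app):
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'num_classes': num_classes,
#       'answer_to_label': answer_to_label,
#   }, 'vqa_vit_best_model.pth')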
# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
])
# Preprocess the inputs and predict an answer
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Set a maximum sequence length
    ).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'])

    # Get the predicted label
    _, predicted_label = torch.max(output, 1)
    predicted_label = predicted_label.item()

    # Map the label back to the answer
    predicted_answer = label_to_answer[predicted_label]
    return predicted_answer
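# Example of calling predict() directly, outside Gradio (assumes one of the example
# images listed below is available next to the script):
#   predict("02_uml.png", "What is the overall complexity of this model?")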
# Define the question (fixed for this demo)
question = "What is the overall complexity of this model?"

# Define the Gradio interface function
def vqa_interface(image):
    # Predict the answer using the provided image and the predefined question
    predicted_answer = predict(image, question)
    return predicted_answer

# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,                  # Function to call
    inputs=gr.Image(type="filepath"),  # Input type: image file path
    outputs="text",                    # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"], ["2ndIterationClassDiagram.png"], ["4-gameUML.png"]
    ]
)

# Launch the Gradio interface
iface.launch()
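# Deployment note (assumption about the hosting setup, not part of the original code):
# when run as a Hugging Face Space, this script is typically saved as app.py, with
# torch, torchvision, transformers, gradio and pillow listed in requirements.txt,
# and the checkpoint 'vqa_vit_best_model.pth' uploaded alongside it.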