# Uploaded by tdnathmlenthusiast
# Commit 3d93357 (verified): "added all required files for model"
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, SwinForImageClassification
from torchvision import transforms
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained Swin-L (384px) checkpoint, loaded with its original classifier
# head (1000 classes); predict() only reads the first few logits.
_SWIN_CHECKPOINT = "microsoft/swin-large-patch4-window12-384"
# NOTE(review): swin_processor is not referenced elsewhere in this script
# (preprocessing uses swin_transform); kept to preserve the module's names.
swin_processor = AutoImageProcessor.from_pretrained(_SWIN_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(_SWIN_CHECKPOINT)

# Widen the patch-embedding projection from 3 input channels (RGB) to 4
# (RGB + mask) so the fine-tuned checkpoint below fits.
original_conv = model.swin.embeddings.patch_embeddings.projection
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,
)
with torch.no_grad():
    # Start from the pretrained RGB filters and seed the extra channel with
    # their per-filter average over RGB. These values are replaced by the
    # strict load_state_dict() call below; they only make the module a
    # sensible initialization on its own.
    new_conv.weight[:, :3] = original_conv.weight
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
    if original_conv.bias is not None:
        new_conv.bias.copy_(original_conv.bias)
model.swin.embeddings.patch_embeddings.projection = new_conv
# Keep the recorded channel count in sync with the new projection, in case
# the installed transformers version validates input channels against it.
# NOTE(review): confirm against the transformers version in use.
model.swin.embeddings.patch_embeddings.num_channels = 4
model.config.num_channels = 4

# Load the fine-tuned weights. weights_only=True limits unpickling to
# tensors and plain containers (all a state dict needs), so a tampered
# checkpoint cannot execute arbitrary code on load.
state_dict = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
# Preprocessing for the Swin input: resize to a fixed 384x384, convert the
# PIL image to a (C, H, W) float tensor, then normalize each channel with
# the ImageNet mean/std.
_RGB_MEAN = [0.485, 0.456, 0.406]
_RGB_STD = [0.229, 0.224, 0.225]
_swin_steps = [
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(_RGB_MEAN, _RGB_STD),
]
swin_transform = transforms.Compose(_swin_steps)
# Class codes in logit order: position i is the label for model logit i
# (predict() only looks at the first len(_CLASS_CODES) logits).
_CLASS_CODES = ("akiec", "bcc", "bkl", "df", "mel", "nv", "vasc")
label_to_idx = {code: position for position, code in enumerate(_CLASS_CODES)}
idx_to_label = dict(enumerate(_CLASS_CODES))
# Prediction function
def predict(image):
    """Classify one image and return the predicted class code.

    Args:
        image: A PIL image, or a NumPy array accepted by
            ``PIL.Image.fromarray`` (the Gradio input is configured with
            ``type="pil"``). Any PIL mode is accepted; it is converted to RGB.

    Returns:
        The class code from ``idx_to_label`` whose logit is highest among
        the first ``len(idx_to_label)`` logits (e.g. ``"mel"``).

    Raises:
        gr.Error: If ``image`` is None (e.g. the form was submitted without
            an upload).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # swin_transform normalizes exactly 3 channels; grayscale, palette or
    # RGBA images would otherwise fail inside Normalize.
    image = image.convert("RGB")

    swin_image = swin_transform(image).to(device)
    # The model takes a 4th (mask) channel. No mask is available at
    # inference, so feed zeros sized to match the transformed image.
    mask = torch.zeros(1, *swin_image.shape[1:], device=device)
    combined = torch.cat([swin_image, mask], dim=0).unsqueeze(0)  # add batch dim

    # Only the first len(idx_to_label) logits of the 1000-way head map to
    # the labels used here.
    num_classes = len(idx_to_label)
    with torch.no_grad():
        logits = model(combined).logits[:, :num_classes]
    pred_idx = logits.argmax(dim=1).item()
    return idx_to_label[pred_idx]
# Gradio UI: a single image upload in, the predicted class code out as text.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title="Skin Cancer Classification",
    description="Upload an image to classify the type of skin cancer. Supported classes: akiec, bcc, bkl, df, mel, nv, vasc.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another script) does not block in launch().
if __name__ == "__main__":
    iface.launch()