|
import gradio as gr
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
from transformers import AutoImageProcessor, SwinForImageClassification
|
|
from torchvision import transforms
|
|
|
|
|
|
# Run on the current CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
# Download the pretrained Swin checkpoint (processor + classification model)
# from the Hugging Face Hub; the model is modified and re-weighted below.
# NOTE(review): swin_processor is not referenced anywhere in this file --
# preprocessing is done by swin_transform instead. Kept in case other code
# imports it; consider dropping it to avoid the extra download at startup.
_SWIN_CHECKPOINT = "microsoft/swin-large-patch4-window12-384"
swin_processor = AutoImageProcessor.from_pretrained(_SWIN_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(_SWIN_CHECKPOINT)
|
|
|
|
|
|
# Widen the Swin patch-embedding convolution from 3 input channels to 4 so the
# model accepts an extra channel concatenated onto the RGB image in predict().
# The RGB filters are copied over and the fourth channel's filters start as the
# mean of the RGB filters. The strict load_state_dict call further down
# replaces these values with the fine-tuned checkpoint, so this initialisation
# mainly has to produce the right parameter shapes.
original_conv = model.swin.embeddings.patch_embeddings.projection
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,
)
with torch.no_grad():
    new_conv.weight[:, :3] = original_conv.weight
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
    if original_conv.bias is not None:
        # Reuse the pretrained bias rather than Conv2d's random default init.
        new_conv.bias.copy_(original_conv.bias)
model.swin.embeddings.patch_embeddings.projection = new_conv
# Keep the recorded channel count in sync with the widened projection.
# NOTE(review): some transformers releases compare the input's channel count
# against patch_embeddings.num_channels before projecting -- confirm for the
# pinned version; keeping both values consistent is harmless either way.
model.swin.embeddings.patch_embeddings.num_channels = new_conv.in_channels
model.config.num_channels = new_conv.in_channels
|
|
|
|
|
|
# Load the fine-tuned weights from best_model.pth (strict: every key must match
# the 4-channel model), then move the model to `device` and switch it to
# evaluation mode for inference.
# weights_only=True restricts unpickling to tensors and plain containers --
# all a state_dict needs -- instead of running arbitrary pickled code, and
# matches the default of recent torch releases.
model.load_state_dict(
    torch.load("best_model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()
|
|
|
|
|
|
# Preprocessing for the RGB part of the model input: resize to 384x384, convert
# the PIL image to a float tensor, then normalise each channel with the
# standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
swin_transform = transforms.Compose(
    [
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
|
# Short class codes and their output indices. The enumeration order defines
# the index of each class, which predict() maps back via idx_to_label.
label_to_idx = {
    name: index
    for index, name in enumerate(("akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"))
}
# Inverse lookup: index -> class code (dict iteration follows insertion order).
idx_to_label = dict(enumerate(label_to_idx))
|
|
|
|
|
|
def predict(image):
    """Classify an image and return its predicted short class code.

    Args:
        image: A PIL image or a NumPy array, or ``None`` when no image was
            provided (e.g. the Gradio form was submitted without an upload).

    Returns:
        One of the class codes in ``idx_to_label`` (e.g. ``"mel"``), or a
        short prompt string when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to classify; report it instead of crashing in the transform.
        return "Please upload an image."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # swin_transform normalises exactly 3 channels; RGBA, grayscale and
    # palette images would otherwise fail or produce the wrong channel count.
    image = image.convert("RGB")

    rgb = swin_transform(image).to(device)
    # Extra all-zero input channel for the 4-channel patch embedding, sized
    # from the transformed image so it always matches the resize target.
    mask = torch.zeros_like(rgb[:1])
    batch = torch.cat([rgb, mask], dim=0).unsqueeze(0)

    with torch.no_grad():
        # Only the leading len(idx_to_label) logits are used as class scores;
        # presumably the classifier head kept its pretrained size.
        logits = model(batch).logits[:, : len(idx_to_label)]
    pred_index = logits.argmax(dim=1).item()
    return idx_to_label[pred_index]
|
|
|
|
|
|
# Web UI: one image upload in, the predicted class code out as text.
# launch() runs at import time, so importing this module starts the server.
_image_input = gr.Image(type="pil")
_label_output = gr.Text()
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Skin Cancer Classification",
    description=(
        "Upload an image to classify the type of skin cancer. "
        "Supported classes: akiec, bcc, bkl, df, mel, nv, vasc."
    ),
)
iface.launch()