import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, SwinForImageClassification
from torchvision import transforms

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Swin Transformer model with its original classifier (1000 classes).
# The image processor is loaded for reference; preprocessing below uses torchvision transforms.
swin_processor = AutoImageProcessor.from_pretrained("microsoft/swin-large-patch4-window12-384")
model = SwinForImageClassification.from_pretrained("microsoft/swin-large-patch4-window12-384")

# Modify the patch-embedding input channels from 3 to 4 (RGB + mask)
original_conv = model.swin.embeddings.patch_embeddings.projection
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,
)
with torch.no_grad():
    # Copy the pretrained RGB weights; initialize the mask channel with their mean
    new_conv.weight[:, :3] = original_conv.weight.clone()
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
model.swin.embeddings.patch_embeddings.projection = new_conv

# Load the trained state dict from best_model.pth
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
model.eval()

# Define transformations for Swin Transformer input
swin_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define label mapping for the first 7 classes
label_to_idx = {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 6}
idx_to_label = {v: k for k, v in label_to_idx.items()}

# Prediction function
def predict(image):
    # Convert numpy array to PIL Image if necessary
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Ensure 3 channels (handles grayscale or RGBA uploads)
    image = image.convert("RGB")
    # Preprocess the image for the Swin Transformer
    swin_image = swin_transform(image).to(device)
    # Generate a dummy mask channel (all zeros)
    mask = torch.zeros(1, 384, 384, device=device)
    # Combine image and dummy mask, then add a batch dimension
    combined = torch.cat([swin_image, mask], dim=0).unsqueeze(0)
    # Get prediction using only the first 7 logits
    with torch.no_grad():
        outputs = model(combined).logits[:, :7]  # take only the first 7 classes
        _, pred = torch.max(outputs, 1)
    pred_label = idx_to_label[pred.item()]
    return pred_label

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title="Skin Cancer Classification",
    description="Upload an image to classify the type of skin cancer. Supported classes: akiec, bcc, bkl, df, mel, nv, vasc.",
)

# Launch the interface
iface.launch()
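
# Optional smoke test: predict() can be called directly before relying on the UI.
# The image path below is only a placeholder; substitute any local lesion photo.
#
#   from PIL import Image
#   print(predict(Image.open("example_lesion.jpg")))  # prints one of: akiec, bcc, bkl, df, mel, nv, vasc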