tdnathmlenthusiast committed
Commit 3d93357 · verified · 1 Parent(s): 2b660a0

added all required files for model

Files changed (3)
  1. app.py +82 -0
  2. best_model.pth +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,82 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ import numpy as np
+ from transformers import AutoImageProcessor, SwinForImageClassification
+ from torchvision import transforms
+
+ # Define device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load Swin Transformer model with original classifier (1000 classes)
+ swin_processor = AutoImageProcessor.from_pretrained("microsoft/swin-large-patch4-window12-384")
+ model = SwinForImageClassification.from_pretrained("microsoft/swin-large-patch4-window12-384")
+
+ # Modify input channels to 4 (RGB + mask)
+ original_conv = model.swin.embeddings.patch_embeddings.projection
+ new_conv = torch.nn.Conv2d(
+     in_channels=4,
+     out_channels=original_conv.out_channels,
+     kernel_size=original_conv.kernel_size,
+     stride=original_conv.stride,
+     padding=original_conv.padding,
+     bias=original_conv.bias is not None
+ )
+ with torch.no_grad():
+     new_conv.weight[:, :3] = original_conv.weight.clone()
+     new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
+ model.swin.embeddings.patch_embeddings.projection = new_conv
+
+ # Load the trained state dict from best_model.pth
+ model.load_state_dict(torch.load("best_model.pth", map_location=device))
+ model.to(device)
+ model.eval()
+
+ # Define transformations for Swin Transformer input
+ swin_transform = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ # Define label mapping for the first 7 classes
+ label_to_idx = {
+     'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3,
+     'mel': 4, 'nv': 5, 'vasc': 6
+ }
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
+
+ # Prediction function
+ def predict(image):
+     # Convert numpy array to PIL Image if necessary
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     # Process image for Swin Transformer (force RGB so ToTensor yields 3 channels)
+     swin_image = swin_transform(image.convert("RGB")).to(device)
+
+     # Generate a dummy mask channel (all zeros)
+     mask = torch.zeros(1, 384, 384).to(device)
+
+     # Combine image and dummy mask
+     combined = torch.cat([swin_image, mask], dim=0).unsqueeze(0)  # Add batch dimension
+
+     # Get prediction using only the first 7 logits
+     with torch.no_grad():
+         outputs = model(combined).logits[:, :7]  # Take only the first 7 classes
+         _, pred = torch.max(outputs, 1)
+         pred_label = idx_to_label[pred.item()]
+
+     return pred_label
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Text(),
+     title="Skin Cancer Classification",
+     description="Upload an image to classify the type of skin cancer. Supported classes: akiec, bcc, bkl, df, mel, nv, vasc."
+ )
+
+ # Launch the interface
+ iface.launch()
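The core trick in app.py is the patch-embedding surgery: the pretrained 3-channel projection is widened to 4 input channels so a segmentation mask can accompany the RGB image, and the new channel's filters are seeded from the mean of the RGB filters. Below is a minimal stand-alone sketch of the same idea; the small Conv2d is illustrative only, not the actual Swin layer:

```python
import torch

# Stand-in for a pretrained 3-channel patch embedding (96 filters, 4x4 patches).
original_conv = torch.nn.Conv2d(3, 96, kernel_size=4, stride=4)

# Widen to 4 input channels while reusing the pretrained geometry.
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,
)
with torch.no_grad():
    new_conv.weight[:, :3] = original_conv.weight.clone()     # keep RGB filters
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)  # seed mask channel

# RGB filters are preserved and a 4-channel input is accepted.
assert torch.equal(new_conv.weight[:, :3], original_conv.weight)
print(new_conv(torch.randn(1, 4, 384, 384)).shape)  # torch.Size([1, 96, 96, 96])
```

The new bias is left at its random initialization, which is harmless here: the trained state dict loaded from best_model.pth overwrites both weights and bias anyway.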
best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fce599d2bca7e9e9d7e4eeb0020787f429df5584f4595cf3376407b4117e0490
+ size 791125887
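best_model.pth is stored via Git LFS, so the diff above shows only the 3-line pointer file; the actual checkpoint is about 791 MB. If the repo is cloned without LFS support, torch.load will hit the pointer text and fail with an unpickling error. A small illustrative guard (the helper name is made up for this sketch):

```python
import os

def looks_like_lfs_pointer(path: str) -> bool:
    # An unfetched LFS pointer is a tiny text file starting with the spec line.
    if os.path.getsize(path) > 1024:
        return False
    with open(path, "rb") as f:
        return f.read(7) == b"version"

if looks_like_lfs_pointer("best_model.pth"):
    raise RuntimeError("best_model.pth is a Git LFS pointer; run `git lfs pull` first.")
```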
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ transformers
+ segmentation-models-pytorch
+ pillow
+ numpy
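With the dependencies installed and the checkpoint fetched, predict() can be smoke-tested before launching the UI. This hypothetical snippet (not part of the commit) would go in app.py just before iface.launch(); the random image merely exercises the pipeline:

```python
from PIL import Image
import numpy as np

# Hypothetical smoke test: push one random RGB image through predict().
dummy = Image.fromarray(np.random.randint(0, 256, (450, 600, 3), dtype=np.uint8))
print(predict(dummy))  # one of: akiec, bcc, bkl, df, mel, nv, vasc
```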