mrs83 commited on
Commit
8889d9a
·
1 Parent(s): eda5549

add app and requirements

Browse files
Files changed (2) hide show
  1. app.py +153 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from datasets import load_dataset
5
+ import random
6
+
7
+ from skincancer_vit.model import SkinCancerViTModel
8
+
9
+ HF_MODEL_REPO = "ethicalabs/SkinCancerViT"
10
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+
13
+ print(f"Loading SkinCancerViT model from {HF_MODEL_REPO} to {DEVICE}...")
14
+
15
+ model = SkinCancerViTModel.from_pretrained(HF_MODEL_REPO)
16
+ model.to(DEVICE)
17
+ model.eval() # Set to evaluation mode
18
+ print("Model loaded successfully.")
19
+
20
+ print("Loading 'marmal88/skin_cancer' dataset for random samples...")
21
+ dataset = load_dataset("marmal88/skin_cancer", split="test")
22
+ print("Dataset loaded successfully.")
23
+
24
+
25
+ def predict_uploaded_image(image: Image.Image, age: int, localization: str) -> str:
26
+ """
27
+ Handles prediction for an uploaded image with user-provided tabular data.
28
+ """
29
+ if model is None:
30
+ return "Error: Model not loaded. Please check the console for details."
31
+ if image is None:
32
+ return "Please upload an image."
33
+ if age is None:
34
+ return "Please enter an age."
35
+ if not localization:
36
+ return "Please select a localization."
37
+
38
+ try:
39
+ # Call the model's full_predict method
40
+ predicted_dx, confidence = model.full_predict(
41
+ raw_image=image, raw_age=age, raw_localization=localization, device=DEVICE
42
+ )
43
+ return f"Predicted Diagnosis: **{predicted_dx}** (Confidence: {confidence:.4f})"
44
+ except Exception as e:
45
+ return f"Prediction Error: {e}"
46
+
47
+
48
+ # --- Prediction Function for Random Sample ---
49
+ def predict_random_sample() -> str:
50
+ """
51
+ Fetches a random sample from the dataset and performs prediction.
52
+ """
53
+ if model is None:
54
+ return "Error: Model not loaded. Please check the console for details."
55
+ if dataset is None:
56
+ return "Error: Dataset not loaded. Cannot select random sample."
57
+
58
+ try:
59
+ # Select a random sample from the dataset
60
+ random_idx = random.randint(0, len(dataset) - 1)
61
+ sample = dataset[random_idx]
62
+
63
+ sample_image = sample["image"]
64
+ sample_age = sample["age"]
65
+ sample_localization = sample["localization"]
66
+ sample_true_dx = sample["dx"]
67
+
68
+ # Call the model's full_predict method
69
+ predicted_dx, confidence = model.full_predict(
70
+ raw_image=sample_image,
71
+ raw_age=sample_age,
72
+ raw_localization=sample_localization,
73
+ device=DEVICE,
74
+ )
75
+
76
+ # Return a formatted string with all information
77
+ result_str = (
78
+ f"**Random Sample Details:**\n"
79
+ f"- Age: {sample_age}\n"
80
+ f"- Localization: {sample_localization}\n"
81
+ f"- True Diagnosis: **{sample_true_dx}**\n\n"
82
+ f"**Model Prediction:**\n"
83
+ f"- Predicted Diagnosis: **{predicted_dx}**\n"
84
+ f"- Confidence: {confidence:.4f}\n"
85
+ f"- Correct Prediction: {'✅ Yes' if predicted_dx == sample_true_dx else '❌ No'}"
86
+ )
87
+ return sample_image, result_str
88
+ except Exception as e:
89
+ return None, f"Prediction Error on Random Sample: {e}"
90
+
91
+
92
+ # --- Gradio Interface ---
93
+ with gr.Blocks(title="Skin Cancer ViT Predictor") as demo:
94
+ gr.Markdown(
95
+ """
96
+ # Skin Cancer ViT Predictor
97
+ This application demonstrates the `SkinCancerViT` multimodal model for skin cancer diagnosis.
98
+ It can take an uploaded image with patient metadata or predict on a random sample from the dataset.
99
+ **Disclaimer:** This tool is for demonstration and research purposes only and should not be used for medical diagnosis.
100
+ """
101
+ )
102
+
103
+ with gr.Tab("Predict on Random Sample"):
104
+ gr.Markdown("## Get a Prediction from a Random Sample in the Test Set")
105
+ random_sample_button = gr.Button("Get Random Sample Prediction")
106
+
107
+ # Modified output components for random sample tab
108
+ with gr.Row():
109
+ output_random_image = gr.Image(
110
+ type="pil", label="Random Sample Image", height=250, width=250
111
+ )
112
+ output_random_details = gr.Markdown(
113
+ "Random sample details and prediction will appear here."
114
+ )
115
+
116
+ random_sample_button.click(
117
+ fn=predict_random_sample,
118
+ inputs=[],
119
+ outputs=[
120
+ output_random_image,
121
+ output_random_details,
122
+ ], # Map to both image and markdown outputs
123
+ )
124
+
125
+ with gr.Tab("Upload Image & Predict"):
126
+ gr.Markdown("## Upload Your Image and Provide Patient Data")
127
+ with gr.Row():
128
+ image_input = gr.Image(
129
+ type="pil", label="Upload Skin Lesion Image (224x224 preferred)"
130
+ )
131
+ with gr.Column():
132
+ age_input = gr.Number(
133
+ label="Patient Age", minimum=0, maximum=120, step=1
134
+ )
135
+ # Ensure these localizations match your training data categories
136
+ localization_input = gr.Dropdown(
137
+ model.config.localization_to_id.keys(),
138
+ label="Lesion Localization",
139
+ value="unknown", # Default value
140
+ )
141
+ predict_button = gr.Button("Get Prediction")
142
+
143
+ output_upload = gr.Markdown("Prediction will appear here.")
144
+
145
+ predict_button.click(
146
+ fn=predict_uploaded_image,
147
+ inputs=[image_input, age_input, localization_input],
148
+ outputs=output_upload,
149
+ )
150
+
151
+ if __name__ == "__main__":
152
+ demo.launch(share=False)
153
+
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ skincancervit @ git+https://github.com/ethicalabs-ai/SkinCancerViT.git