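"""PerceptNet: a Gradio app for brain-tumor segmentation on MRI slices.

Loads a pretrained Attention U-Net (weights auto-downloaded from the Hugging
Face Hub) plus a sample dataset from Kaggle, and serves an interactive demo
with a five-panel visualization and summary metrics. Research use only.
"""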
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use('Agg')  # headless backend, so plotting works on a server without a display
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.transforms.functional as TF
import urllib.request
import os
import random
import kagglehub

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None

# Download the sample dataset; fall back to empty lists so the app still
# starts if the download fails ("Load Sample" then reports it as unavailable).
try:
    dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
    image_path = os.path.join(dataset_path, 'images')
    mask_path = os.path.join(dataset_path, 'masks')
    test_imgs = sorted([f for f in os.listdir(image_path) if f.endswith(('.jpg', '.png'))])
    test_masks = sorted([f for f in os.listdir(mask_path) if f.endswith(('.jpg', '.png'))])
except Exception as e:
    print(f"Dataset download failed: {e}")
    test_imgs, test_masks = [], []
# Attention U-Net architecture (must match the checkpoint used during training)
class DoubleConv(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> ReLU blocks: the basic U-Net building block."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate: combines a gating signal g with features x
    and produces per-pixel coefficients in [0, 1] that reweight x."""
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # attention coefficients in [0, 1]
        return x * psi
class AttentionUNET(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder (down) path
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder (up) path: transposed conv + attention gate + double conv per level
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        # Each decoder level occupies two slots in self.ups: upsample, then double conv
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            # NOTE: the gate is called with (g=skip_connection, x=decoder features),
            # exactly as in the training code; keep this ordering so the pretrained
            # checkpoint behaves as it did during training.
            skip_connection = self.attentions[idx // 2](skip_connection, x)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x)
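# Quick shape sanity check (illustrative sketch, not part of the app flow):
# the network maps a (1, 1, 256, 256) input to a (1, 1, 256, 256) logit map.
#
#   net = AttentionUNET(in_channels=1, out_channels=1)
#   out = net(torch.randn(1, 1, 256, 256))
#   assert out.shape == (1, 1, 256, 256)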
def download_model():
    """Download trained model weights from the Hugging Face Hub."""
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    if not os.path.exists(model_path):
        print("Downloading trained model...")
        try:
            urllib.request.urlretrieve(model_url, model_path)
            print("Model downloaded successfully!")
        except Exception as e:
            print(f"Failed to download model: {e}")
            return None
    else:
        print("Model already exists!")
    return model_path
def load_attention_model():
    """Lazily load the trained Attention U-Net (cached in the module-level global)."""
    global model
    if model is None:
        try:
            print("Loading trained Attention U-Net model...")
            # Download weights if they are not present locally
            model_path = download_model()
            if model_path is None:
                return None
            # Initialize the architecture, then load the trained weights
            model = AttentionUNET(in_channels=1, out_channels=1).to(device)
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            model.load_state_dict(checkpoint["state_dict"])
            model.eval()
            print("Attention U-Net model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            model = None
    return model
def preprocess_image(image):
    """Convert to grayscale, resize to 256x256, and add a batch dimension."""
    if image.mode != 'L':
        image = image.convert('L')
    val_test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    return val_test_transform(image).unsqueeze(0)  # shape: (1, 1, 256, 256)
def predict_tumor(image, mask=None):
    """Segment an MRI slice and build the result figure plus a markdown report."""
    current_model = load_attention_model()
    if current_model is None:
        return None, "Failed to load trained model."
    if image is None:
        return None, "Please upload an image first."
    try:
        print("Processing with PerceptNet Attention U-Net...")
        # Preprocess and predict
        input_tensor = preprocess_image(image).to(device)
        with torch.no_grad():
            pred_mask = torch.sigmoid(current_model(input_tensor))
            pred_mask_binary = (pred_mask > 0.5).float()

        # Convert tensors to numpy arrays for visualization
        pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
        prob_mask_np = pred_mask.cpu().squeeze().numpy()  # probabilities for the heatmap
        original_np = np.array(image.convert('L').resize((256, 256)))

        # Inverted mask (tumor shown dark on white) for display
        inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
        # Tumor region only, on a white background
        tumor_only = np.where(pred_mask_np == 1, original_np, 255)

        # Score against the ground-truth mask if one was provided
        mask_np = None
        dice_score = None
        iou_score = None
        if mask is not None:
            mask_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()
            ])
            mask_tensor = mask_transform(mask).squeeze().numpy()
            mask_np = (mask_tensor > 0.5).astype(float)
            intersection = np.logical_and(pred_mask_np, mask_np).sum()
            union = np.logical_or(pred_mask_np, mask_np).sum()
            iou_score = intersection / (union + 1e-7)
            dice_score = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)

        # Five-panel visualization
        fig, axes = plt.subplots(1, 5, figsize=(25, 5))
        fig.suptitle('PerceptNet Analysis Results', fontsize=16, fontweight='bold')
        titles = ["Original Image", "Ground Truth", "Predicted Mask", "Tumor Only", "Heatmap"]
        images = [original_np,
                  mask_np if mask_np is not None else np.zeros_like(original_np),
                  inv_pred_mask_np, tumor_only, prob_mask_np]
        cmaps = ['gray', 'gray', 'gray', 'gray', 'hot']
        for i, ax in enumerate(axes):
            ax.imshow(images[i], cmap=cmaps[i])
            ax.set_title(titles[i], fontsize=12, fontweight='bold')
            ax.axis('off')
        plt.tight_layout()

        # Render the figure to an in-memory PNG
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close(fig)
        result_image = Image.open(buf)

        # Summary statistics
        tumor_pixels = int(np.sum(pred_mask_np))
        total_pixels = pred_mask_np.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        max_confidence = torch.max(pred_mask).item()
        mean_confidence = torch.mean(pred_mask).item()

        analysis_text = f"""
## PerceptNet Analysis Results
### Detection Summary:
- **Status**: {'TUMOR DETECTED' if tumor_pixels > 50 else 'NO SIGNIFICANT TUMOR'}
- **Tumor Area**: {tumor_percentage:.2f}% of brain region
- **Tumor Pixels**: {tumor_pixels:,} pixels
- **Max Confidence**: {max_confidence:.4f}
- **Mean Confidence**: {mean_confidence:.4f}
"""
        if dice_score is not None and iou_score is not None:
            analysis_text += f"""
- **Dice Score**: {dice_score:.4f}
- **IoU Score**: {iou_score:.4f}
"""
        analysis_text += f"""
### Model Information:
- **Architecture**: PerceptNet Attention U-Net
- **Training Performance**: Dice: 0.8420, IoU: 0.7297
- **Input**: Grayscale (single channel)
- **Output**: Binary segmentation mask
- **Device**: {device.type.upper()}

### Processing Details:
- **Preprocessing**: Resize(256×256) + ToTensor
- **Threshold**: 0.5 (sigmoid > 0.5)
- **Architecture**: Attention gates + Skip connections
- **Features**: [32, 64, 128, 256] channels

### Medical Disclaimer:
This AI model is for **research and educational purposes only**.
Results should be validated by medical professionals. Not for clinical diagnosis.
"""
        print(f"Model analysis completed! Tumor area: {tumor_percentage:.2f}%")
        return result_image, analysis_text
    except Exception as e:
        error_msg = f"Error with model: {str(e)}"
        print(error_msg)
        return None, error_msg
def load_random_sample():
    """Pick a random image/mask pair from the downloaded dataset."""
    if not test_imgs:
        return None, None, "Dataset not available."
    rand_idx = random.randint(0, len(test_imgs) - 1)
    img_path = os.path.join(image_path, test_imgs[rand_idx])
    msk_path = os.path.join(mask_path, test_masks[rand_idx])
    image = Image.open(img_path).convert('L')
    mask = Image.open(msk_path).convert('L')
    return image, mask, "Loaded random sample from dataset."

def clear_all():
    """Reset the interface: input image, output image, status text, stored mask."""
    return None, None, "Upload a brain MRI image to test PerceptNet model", None
# Professional CSS styling
css = """
.gradio-container {
    max-width: 1600px !important;
    margin: auto !important;
    background-color: #ffffff !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
.gr-markdown p, .gr-markdown div, .gr-markdown span, .gr-markdown li {
    color: #1e293b !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6 {
    color: #1e293b !important;
}
.gr-markdown strong {
    color: #374151 !important;
}
#analysis-results * {
    color: #1e293b !important;
}
#title-header {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
    color: white;
    padding: 40px 30px;
    border-radius: 12px;
    margin-bottom: 30px;
    box-shadow: 0 4px 20px rgba(37, 99, 235, 0.15);
    text-align: center;
}
.main-container {
    background-color: #ffffff;
    border-radius: 12px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
    padding: 30px;
    margin-bottom: 20px;
}
.input-section {
    background-color: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 25px;
}
.info-panel {
    background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
    border: 1px solid #0ea5e9;
    border-radius: 8px;
    padding: 20px;
    margin-top: 20px;
}
.footer-section {
    background-color: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 12px;
    padding: 30px;
    margin-top: 30px;
}
.stat-grid {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 30px;
    margin: 20px 0;
}
.disclaimer-text {
    color: #dc2626;
    font-weight: 600;
    line-height: 1.5;
    background-color: #fef2f2;
    padding: 15px;
    border-radius: 6px;
    border: 1px solid #fecaca;
}
h1, h2, h3, h4 {
    color: #1e293b !important;
}
.gr-button-primary {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    transition: all 0.2s ease !important;
}
.gr-button-primary:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(37, 99, 235, 0.3) !important;
}
.gr-button-secondary {
    background: #6b7280 !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
}
"""
# Create the Gradio interface
with gr.Blocks(css=css, title="PerceptNet - Brain Tumor Segmentation", theme=gr.themes.Default()) as app:
    gr.HTML("""
    <div id="title-header">
        <h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">PerceptNet</h1>
        <p style="font-size: 1.2rem; margin: 15px 0 5px 0; opacity: 0.95;">
            Advanced Brain Tumor Segmentation System
        </p>
        <p style="font-size: 1rem; margin: 5px 0 0 0; opacity: 0.8;">
            Attention U-Net Architecture • Dice: 0.8420 • IoU: 0.7297
        </p>
    </div>
    """)

    # Holds the ground-truth mask when a dataset sample is loaded (None for user uploads)
    mask_state = gr.State(None)

    with gr.Row(elem_classes="main-container"):
        with gr.Column(scale=1, elem_classes="input-section"):
            gr.Markdown("### Upload Brain MRI Scan", elem_classes="section-title")
            image_input = gr.Image(
                label="Brain MRI Image",
                type="pil",
                sources=["upload", "webcam"],
                height=380
            )
            with gr.Row():
                analyze_btn = gr.Button(
                    "Analyze Image",
                    variant="primary",
                    scale=2,
                    size="lg"
                )
                random_btn = gr.Button(
                    "Load Sample",
                    variant="secondary",
                    scale=1,
                    size="lg"
                )
                clear_btn = gr.Button(
                    "Clear",
                    variant="secondary",
                    scale=1
                )
            gr.HTML("""
            <div class="info-panel">
                <h4 style="color: #0ea5e9; margin-bottom: 15px; font-size: 1.1rem;">Model Specifications</h4>
                <div style="line-height: 1.8; font-size: 0.95rem;">
                    <div><strong>Architecture:</strong> Attention U-Net with Skip Connections</div>
                    <div><strong>Performance:</strong> 84.2% Dice Score, 72.97% IoU</div>
                    <div><strong>Input Format:</strong> Grayscale MRI Scans (256×256)</div>
                    <div><strong>Output:</strong> Binary Segmentation + Confidence Heatmap</div>
                    <div><strong>Features:</strong> Attention Mechanisms, Multi-scale Analysis</div>
                </div>
            </div>
            """)
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results", elem_classes="section-title")
            output_image = gr.Image(
                label="PerceptNet Analysis Output",
                type="pil",
                height=520
            )
            analysis_output = gr.Markdown(
                value="Upload a brain MRI image to begin analysis with PerceptNet.",
                elem_id="analysis-results"
            )

    # Footer section
    gr.HTML("""
    <div class="footer-section">
        <div class="stat-grid">
            <div>
                <h4 style="color: #2563eb; margin-bottom: 15px;">Technical Specifications</h4>
                <div style="line-height: 1.6;">
                    <p><strong>Model Architecture:</strong> Attention U-Net with Gating Mechanisms</p>
                    <p><strong>Training Dataset:</strong> Brain Tumor Segmentation Dataset</p>
                    <p><strong>Image Processing:</strong> 256×256 Grayscale Normalization</p>
                    <p><strong>Inference Speed:</strong> Real-time Processing on GPU/CPU</p>
                    <p><strong>Output Formats:</strong> Binary Masks, Probability Maps, Heatmaps</p>
                </div>
            </div>
            <div>
                <h4 style="color: #dc2626; margin-bottom: 15px;">Important Disclaimer</h4>
                <div class="disclaimer-text">
                    PerceptNet is an AI research tool designed for <strong>educational and research purposes only</strong>.
                    This system is not intended for clinical diagnosis or medical decision-making.
                    All results must be validated by qualified medical professionals before any medical application.
                </div>
            </div>
        </div>
        <hr style="margin: 25px 0; border: none; border-top: 1px solid #e2e8f0;">
        <p style="text-align: center; color: #64748b; margin: 15px 0; font-weight: 500;">
            PerceptNet v1.0 • Advanced Medical Image Analysis • Research Grade Performance
        </p>
    </div>
    """)
    # Event handlers
    analyze_btn.click(
        fn=predict_tumor,
        inputs=[image_input, mask_state],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    random_btn.click(
        fn=load_random_sample,
        inputs=[],
        outputs=[image_input, mask_state, analysis_output]
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output, mask_state]
    )
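    # Safeguard (added sketch): without this, an image uploaded after "Load Sample"
    # would be scored against the stale ground-truth mask left in mask_state.
    # Assumes the `upload` event on gr.Image (present in recent Gradio releases),
    # which fires only on user uploads, not on programmatic updates.
    image_input.upload(
        fn=lambda: None,
        inputs=[],
        outputs=[mask_state]
    )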
if __name__ == "__main__":
    print("Starting PerceptNet Brain Tumor Segmentation System...")
    print("Loading Attention U-Net architecture...")
    print("Auto-downloading model weights...")
    print("Expected performance: Dice 0.8420, IoU 0.7297")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )