import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.transforms.functional as TF
import urllib.request
import os
import random
import kagglehub
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None  # trained model, lazily loaded by load_attention_model()
# Download the sample dataset used by the "Load Sample" button; degrade
# gracefully if the download fails so the rest of the app still works.
try:
    dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
    image_path = os.path.join(dataset_path, 'images')
    mask_path = os.path.join(dataset_path, 'masks')
    test_imgs = sorted([f for f in os.listdir(image_path) if f.endswith(('.jpg', '.png'))])
    test_masks = sorted([f for f in os.listdir(mask_path) if f.endswith(('.jpg', '.png'))])
except Exception as e:
    print(f"Failed to download dataset: {e}")
    image_path = mask_path = None
    test_imgs, test_masks = [], []
# Attention U-Net architecture (must match the training code exactly so the saved weights load)
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.conv(x)
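# Additive attention gate: the gating signal g (coarser decoder features) and
# the skip connection x are each projected to F_int channels, summed, passed
# through ReLU, then reduced to a single-channel sigmoid weight:
#     alpha  = sigmoid(psi(ReLU(W_g(g) + W_x(x))))
#     output = x * alpha
# so the skip connection is suppressed wherever the decoder deems it irrelevant.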
class AttentionBlock(nn.Module):
def __init__(self, F_g, F_l, F_int):
super(AttentionBlock, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
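# Encoder/decoder layout: the encoder doubles channels at each level
# (1 -> 32 -> 64 -> 128 -> 256), the bottleneck doubles once more (512), and
# the decoder mirrors the encoder. self.ups alternates ConvTranspose2d
# (upsample) and DoubleConv (refine after concatenation), which is why
# forward() below steps through it two modules at a time.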
class AttentionUNET(nn.Module):
def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
super(AttentionUNET, self).__init__()
self.out_channels = out_channels
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.attentions = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Down part of UNET
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature
# Bottleneck
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
# Up part of UNET
for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
self.ups.append(DoubleConv(feature*2, feature))
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x):
skip_connections = []
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first, to match decoder order
        for idx in range(0, len(self.ups), 2):  # each decoder stage is an (upsample, DoubleConv) pair
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
if x.shape != skip_connection.shape:
x = TF.resize(x, size=skip_connection.shape[2:])
skip_connection = self.attentions[idx // 2](skip_connection, x)
concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx+1](concat_skip)
return self.final_conv(x)
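# Quick shape sanity check (illustrative sketch; run it separately rather than
# at import time):
#     net = AttentionUNET(in_channels=1, out_channels=1)
#     out = net(torch.randn(1, 1, 256, 256))
#     assert out.shape == (1, 1, 256, 256)  # raw logits; apply sigmoid for probabilities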
def download_model():
"""Download trained model from HuggingFace"""
model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
model_path = "best_attention_model.pth.tar"
if not os.path.exists(model_path):
print("Downloading trained model...")
try:
urllib.request.urlretrieve(model_url, model_path)
print("Model downloaded successfully!")
except Exception as e:
print(f"Failed to download model: {e}")
return None
else:
print("Model already exists!")
return model_path
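# The checkpoint is expected to be a dict with a "state_dict" key, i.e. saved
# during training as something like
#     torch.save({"state_dict": model.state_dict()}, "best_attention_model.pth.tar")
# (the exact training-side call is an assumption; only the "state_dict" key is
# required by the loading code below).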
def load_attention_model():
"""Load trained Attention U-Net model"""
global model
if model is None:
try:
print("Loading trained Attention U-Net model...")
# Download model if needed
model_path = download_model()
if model_path is None:
return None
# Initialize model architecture
model = AttentionUNET(in_channels=1, out_channels=1).to(device)
# Load trained weights
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Attention U-Net model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
model = None
return model
def preprocess_image(image):
"""Preprocessing for model input"""
# Convert to grayscale
if image.mode != 'L':
image = image.convert('L')
# Apply transforms
val_test_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
return val_test_transform(image).unsqueeze(0) # Add batch dimension
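# preprocess_image returns a float tensor of shape (1, 1, 256, 256) with values
# in [0, 1] (ToTensor divides uint8 pixel values by 255), matching the model's
# single-channel grayscale input.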
def predict_tumor(image, mask=None):
current_model = load_attention_model()
if current_model is None:
return None, "Failed to load trained model."
if image is None:
return None, "Please upload an image first."
try:
print("Processing with PerceptNet Attention U-Net...")
# Preprocess image
input_tensor = preprocess_image(image).to(device)
# Model prediction
with torch.no_grad():
pred_mask = torch.sigmoid(current_model(input_tensor))
pred_mask_binary = (pred_mask > 0.5).float()
# Convert to numpy
pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
prob_mask_np = pred_mask.cpu().squeeze().numpy() # Probability for heatmap
original_np = np.array(image.convert('L').resize((256, 256)))
# Create inverted mask for visualization
inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
# Create tumor-only image
tumor_only = np.where(pred_mask_np == 1, original_np, 255)
# Handle ground truth if provided
mask_np = None
dice_score = None
iou_score = None
if mask is not None:
mask_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
mask_tensor = mask_transform(mask).squeeze().numpy()
mask_np = (mask_tensor > 0.5).astype(float)
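            # Overlap metrics between prediction P and ground truth G:
            #     IoU  = |P ∩ G| / |P ∪ G|
            #     Dice = 2|P ∩ G| / (|P| + |G|)
            # The 1e-7 term guards against division by zero on empty masks.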
intersection = np.logical_and(pred_mask_np, mask_np).sum()
union = np.logical_or(pred_mask_np, mask_np).sum()
iou_score = intersection / (union + 1e-7)
dice_score = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
# Create visualization (5-panel layout)
fig, axes = plt.subplots(1, 5, figsize=(25, 5))
fig.suptitle('PerceptNet Analysis Results', fontsize=16, fontweight='bold')
titles = ["Original Image", "Ground Truth", "Predicted Mask", "Tumor Only", "Heatmap"]
images = [original_np, mask_np if mask_np is not None else np.zeros_like(original_np), inv_pred_mask_np, tumor_only, prob_mask_np]
cmaps = ['gray', 'gray', 'gray', 'gray', 'hot']
for i, ax in enumerate(axes):
ax.imshow(images[i], cmap=cmaps[i])
ax.set_title(titles[i], fontsize=12, fontweight='bold')
ax.axis('off')
plt.tight_layout()
# Save result
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
buf.seek(0)
plt.close()
result_image = Image.open(buf)
# Calculate statistics
        tumor_pixels = int(pred_mask_np.sum())  # count of predicted tumor pixels
total_pixels = pred_mask_np.size
tumor_percentage = (tumor_pixels / total_pixels) * 100
# Calculate confidence metrics
max_confidence = torch.max(pred_mask).item()
mean_confidence = torch.mean(pred_mask).item()
analysis_text = f"""
## PerceptNet Analysis Results
### Detection Summary:
- **Status**: {'TUMOR DETECTED' if tumor_pixels > 50 else 'NO SIGNIFICANT TUMOR'}
- **Tumor Area**: {tumor_percentage:.2f}% of brain region
- **Tumor Pixels**: {tumor_pixels:,} pixels
- **Max Confidence**: {max_confidence:.4f}
- **Mean Confidence**: {mean_confidence:.4f}
"""
if dice_score is not None and iou_score is not None:
analysis_text += f"""
- **Dice Score**: {dice_score:.4f}
- **IoU Score**: {iou_score:.4f}
"""
analysis_text += f"""
### Model Information:
- **Architecture**: PerceptNet Attention U-Net
- **Training Performance**: Dice: 0.8420, IoU: 0.7297
- **Input**: Grayscale (single channel)
- **Output**: Binary segmentation mask
- **Device**: {device.type.upper()}
### Processing Details:
- **Preprocessing**: Resize(256×256) + ToTensor
- **Threshold**: 0.5 (sigmoid > 0.5)
- **Architecture**: Attention gates + Skip connections
- **Features**: [32, 64, 128, 256] channels
### Medical Disclaimer:
This AI model is for **research and educational purposes only**.
Results should be validated by medical professionals. Not for clinical diagnosis.
"""
print(f"Model analysis completed! Tumor area: {tumor_percentage:.2f}%")
return result_image, analysis_text
except Exception as e:
error_msg = f"Error with model: {str(e)}"
print(error_msg)
return None, error_msg
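# NOTE: load_random_sample pairs images and masks by sorted index, which
# assumes the two sorted filename lists line up one-to-one (matching names in
# images/ and masks/); if the filenames ever diverge, pair by name instead.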
def load_random_sample():
if not test_imgs:
return None, None, "Dataset not available."
rand_idx = random.randint(0, len(test_imgs) - 1)
img_path = os.path.join(image_path, test_imgs[rand_idx])
msk_path = os.path.join(mask_path, test_masks[rand_idx])
image = Image.open(img_path).convert('L')
mask = Image.open(msk_path).convert('L')
return image, mask, "Loaded random sample from dataset."
def clear_all():
return None, None, "Upload a brain MRI image to test PerceptNet model", None
# Professional CSS styling
css = """
.gradio-container {
max-width: 1600px !important;
margin: auto !important;
background-color: #ffffff !important;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
.gr-markdown p, .gr-markdown div, .gr-markdown span, .gr-markdown li {
color: #1e293b !important;
}
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6 {
color: #1e293b !important;
}
.gr-markdown strong {
color: #374151 !important;
}
#analysis-results * {
color: #1e293b !important;
}
#title-header {
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
color: white;
padding: 40px 30px;
border-radius: 12px;
margin-bottom: 30px;
box-shadow: 0 4px 20px rgba(37, 99, 235, 0.15);
text-align: center;
}
.main-container {
background-color: #ffffff;
border-radius: 12px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
padding: 30px;
margin-bottom: 20px;
}
.input-section {
background-color: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 25px;
}
.info-panel {
background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
border: 1px solid #0ea5e9;
border-radius: 8px;
padding: 20px;
margin-top: 20px;
}
.footer-section {
background-color: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 12px;
padding: 30px;
margin-top: 30px;
}
.stat-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 30px;
margin: 20px 0;
}
.disclaimer-text {
color: #dc2626;
font-weight: 600;
line-height: 1.5;
background-color: #fef2f2;
padding: 15px;
border-radius: 6px;
border: 1px solid #fecaca;
}
h1, h2, h3, h4 {
color: #1e293b !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
border: none !important;
color: white !important;
font-weight: 600 !important;
padding: 12px 24px !important;
border-radius: 8px !important;
transition: all 0.2s ease !important;
}
.gr-button-primary:hover {
transform: translateY(-1px) !important;
box-shadow: 0 4px 12px rgba(37, 99, 235, 0.3) !important;
}
.gr-button-secondary {
background: #6b7280 !important;
border: none !important;
color: white !important;
font-weight: 600 !important;
padding: 12px 24px !important;
border-radius: 8px !important;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, title="PerceptNet - Brain Tumor Segmentation", theme=gr.themes.Default()) as app:
gr.HTML("""
<div id="title-header">
<h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">PerceptNet</h1>
<p style="font-size: 1.2rem; margin: 15px 0 5px 0; opacity: 0.95;">
Advanced Brain Tumor Segmentation System
</p>
<p style="font-size: 1rem; margin: 5px 0 0 0; opacity: 0.8;">
Attention U-Net Architecture • Dice: 0.8420 • IoU: 0.7297
</p>
</div>
""")
mask_state = gr.State(None)
with gr.Row(elem_classes="main-container"):
with gr.Column(scale=1, elem_classes="input-section"):
gr.Markdown("### Upload Brain MRI Scan", elem_classes="section-title")
image_input = gr.Image(
label="Brain MRI Image",
type="pil",
sources=["upload", "webcam"],
height=380
)
with gr.Row():
analyze_btn = gr.Button(
"Analyze Image",
variant="primary",
scale=2,
size="lg"
)
random_btn = gr.Button(
"Load Sample",
variant="secondary",
scale=1,
size="lg"
)
clear_btn = gr.Button(
"Clear",
variant="secondary",
scale=1
)
gr.HTML("""
<div class="info-panel">
<h4 style="color: #0ea5e9; margin-bottom: 15px; font-size: 1.1rem;">Model Specifications</h4>
<div style="line-height: 1.8; font-size: 0.95rem;">
<div><strong>Architecture:</strong> Attention U-Net with Skip Connections</div>
<div><strong>Performance:</strong> 84.2% Dice Score, 72.97% IoU</div>
<div><strong>Input Format:</strong> Grayscale MRI Scans (256×256)</div>
<div><strong>Output:</strong> Binary Segmentation + Confidence Heatmap</div>
<div><strong>Features:</strong> Attention Mechanisms, Multi-scale Analysis</div>
</div>
</div>
""")
with gr.Column(scale=2):
gr.Markdown("### Analysis Results", elem_classes="section-title")
output_image = gr.Image(
label="PerceptNet Analysis Output",
type="pil",
height=520
)
analysis_output = gr.Markdown(
value="Upload a brain MRI image to begin analysis with PerceptNet.",
elem_id="analysis-results"
)
# Footer section
gr.HTML("""
<div class="footer-section">
<div class="stat-grid">
<div>
<h4 style="color: #2563eb; margin-bottom: 15px;">Technical Specifications</h4>
<div style="line-height: 1.6;">
<p><strong>Model Architecture:</strong> Attention U-Net with Gating Mechanisms</p>
<p><strong>Training Dataset:</strong> Brain Tumor Segmentation Dataset</p>
<p><strong>Image Processing:</strong> 256×256 Grayscale Normalization</p>
<p><strong>Inference Speed:</strong> Real-time Processing on GPU/CPU</p>
<p><strong>Output Formats:</strong> Binary Masks, Probability Maps, Heatmaps</p>
</div>
</div>
<div>
<h4 style="color: #dc2626; margin-bottom: 15px;">Important Disclaimer</h4>
<div class="disclaimer-text">
PerceptNet is an AI research tool designed for <strong>educational and research purposes only</strong>.
This system is not intended for clinical diagnosis or medical decision-making.
All results must be validated by qualified medical professionals before any medical application.
</div>
</div>
</div>
<hr style="margin: 25px 0; border: none; border-top: 1px solid #e2e8f0;">
<p style="text-align: center; color: #64748b; margin: 15px 0; font-weight: 500;">
PerceptNet v1.0 • Advanced Medical Image Analysis • Research Grade Performance
</p>
</div>
""")
# Event handlers
analyze_btn.click(
fn=predict_tumor,
inputs=[image_input, mask_state],
outputs=[output_image, analysis_output],
show_progress=True
)
random_btn.click(
fn=load_random_sample,
inputs=[],
outputs=[image_input, mask_state, analysis_output]
)
clear_btn.click(
fn=clear_all,
inputs=[],
outputs=[image_input, output_image, analysis_output, mask_state]
)
if __name__ == "__main__":
print("Starting PerceptNet Brain Tumor Segmentation System...")
print("Loading Attention U-Net architecture...")
print("Auto-downloading model weights...")
print("Expected performance: Dice 0.8420, IoU 0.7297")
app.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
share=False
)