Tanusree88 commited on
Commit
edd91e9
·
verified ·
1 Parent(s): 807c3f8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import numpy as np
4
+ import torch
5
+ from transformers import SegformerForImageSegmentation, ResNetForImageClassification, AdamW
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import streamlit as st
9
+ import gradio as gr
10
+
11
+ # Load the Segformer model using Gradio (Optional)
12
+ gr.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
13
+
14
+ # Function to extract zip files
15
+ def extract_zip(zip_file, extract_to):
16
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
17
+ zip_ref.extractall(extract_to)
18
+
19
+ # Preprocess images
20
+ def preprocess_image(image_path):
21
+ ext = os.path.splitext(image_path)[-1].lower()
22
+
23
+ if ext == '.npy':
24
+ image_data = np.load(image_path)
25
+ image_tensor = torch.tensor(image_data).float()
26
+ if len(image_tensor.shape) == 3:
27
+ image_tensor = image_tensor.unsqueeze(0)
28
+
29
+ elif ext in ['.jpg', '.jpeg']:
30
+ img = Image.open(image_path).convert('RGB').resize((224, 224))
31
+ img_np = np.array(img)
32
+ image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()
33
+
34
+ else:
35
+ raise ValueError(f"Unsupported format: {ext}")
36
+
37
+ image_tensor /= 255.0 # Normalize to [0, 1]
38
+ return image_tensor
39
+
40
+ # Prepare dataset
41
+ def prepare_dataset(extracted_folder):
42
+ neuronii_path = os.path.join(extracted_folder, "neuroniiimages")
43
+
44
+ if not os.path.exists(neuronii_path):
45
+ raise FileNotFoundError(f"The folder neuroniiimages does not exist in the extracted folder: {neuronii_path}")
46
+
47
+ image_paths = []
48
+ labels = []
49
+
50
+ for disease_folder in ['alzheimers_dataset', 'parkinsons_dataset', 'MSjpg']:
51
+ folder_path = os.path.join(neuronii_path, disease_folder)
52
+
53
+ if not os.path.exists(folder_path):
54
+ print(f"Folder not found: {folder_path}")
55
+ continue
56
+ label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'MSjpg': 2}[disease_folder]
57
+
58
+ for img_file in os.listdir(folder_path):
59
+ if img_file.endswith(('.npy', '.jpg', '.jpeg')):
60
+ image_paths.append(os.path.join(folder_path, img_file))
61
+ labels.append(label)
62
+ else:
63
+ print(f"Unsupported file: {img_file}")
64
+ print(f"Total images loaded: {len(image_paths)}")
65
+ return image_paths, labels
66
+
67
+ # Custom Dataset class
68
+ class CustomImageDataset(Dataset):
69
+ def __init__(self, image_paths, labels):
70
+ self.image_paths = image_paths
71
+ self.labels = labels
72
+
73
+ def __len__(self):
74
+ return len(self.image_paths)
75
+
76
+ def __getitem__(self, idx):
77
+ image = preprocess_image(self.image_paths[idx])
78
+ label = self.labels[idx]
79
+ return image, label
80
+
81
+ # Training function for classification
82
+ def fine_tune_classification_model(train_loader):
83
+ model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50', num_labels=3)
84
+ model.train()
85
+ optimizer = AdamW(model.parameters(), lr=1e-4)
86
+ criterion = torch.nn.CrossEntropyLoss()
87
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88
+ model.to(device)
89
+
90
+ for epoch in range(10):
91
+ running_loss = 0.0
92
+ for images, labels in train_loader:
93
+ images, labels = images.to(device), labels.to(device)
94
+ optimizer.zero_grad()
95
+ outputs = model(pixel_values=images).logits
96
+ loss = criterion(outputs, labels)
97
+ loss.backward()
98
+ optimizer.step()
99
+ running_loss += loss.item()
100
+ return running_loss / len(train_loader)
101
+
102
+ # Streamlit UI for Fine-tuning
103
+ st.title("Fine-tune ResNet for MRI/CT Scans Classification")
104
+
105
+ zip_file_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"
106
+
107
+ if st.button("Start Training"):
108
+ extraction_dir = "extracted_files"
109
+ os.makedirs(extraction_dir, exist_ok=True)
110
+
111
+ # Download the zip file (placeholder)
112
+ zip_file = "neuroniiimages.zip" # Assuming you downloaded it with this name
113
+
114
+ # Extract zip file
115
+ extract_zip(zip_file, extraction_dir)
116
+
117
+ # Prepare dataset
118
+ image_paths, labels = prepare_dataset(extraction_dir)
119
+ dataset = CustomImageDataset(image_paths, labels)
120
+ train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
121
+
122
+ # Fine-tune the classification model
123
+ final_loss = fine_tune_classification_model(train_loader)
124
+ st.write(f"Training Complete with Final Loss: {final_loss}")
125
+
126
+ # Segmentation function (using SegFormer)
127
+ def fine_tune_segmentation_model(train_loader):
128
+ model = SegformerForImageSegmentation.from_pretrained('nvidia/segformer-b0', num_labels=3)
129
+ model.train()
130
+ optimizer = AdamW(model.parameters(), lr=1e-4)
131
+ criterion = torch.nn.CrossEntropyLoss()
132
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
133
+ model.to(device)
134
+
135
+ for epoch in range(10):
136
+ running_loss = 0.0
137
+ for images, labels in train_loader:
138
+ images, labels = images.to(device), labels.to(device)
139
+ optimizer.zero_grad()
140
+ outputs = model(pixel_values=images).logits
141
+ loss = criterion(outputs, labels)
142
+ loss.backward()
143
+ optimizer.step()
144
+ running_loss += loss.item()
145
+ return running_loss / len(train_loader)
146
+
147
+ # Add a button for segmentation training
148
+ if st.button("Start Segmentation Training"):
149
+ # Assuming the dataset for segmentation is prepared similarly
150
+ seg_train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
151
+
152
+ # Fine-tune the segmentation model
153
+ final_loss_seg = fine_tune_segmentation_model(seg_train_loader)
154
+ st.write(f"Segmentation Training Complete with Final Loss: {final_loss_seg}")