# tumor-prediction / predict.py
# Source page residue: author Skym616, commit 93f30c6 ("done"), 2.24 kB.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# Model architecture.
class DeepCNN(nn.Module):
    """Convolutional image classifier with a two-layer fully connected head.

    Four conv blocks (Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2)) are
    registered, but ``forward`` only runs the first three. The fourth block is
    stored under the misspelled attribute ``lqyer4`` and is never called.

    NOTE(review): ``lqyer4`` is presumably kept (with its typo) because saved
    checkpoints contain ``lqyer4.*`` state-dict keys. A strict
    ``load_state_dict`` would reject a checkpoint if the attribute were renamed
    or removed, so confirm against the checkpoint before changing it.

    Args:
        num_classes: Number of output logits produced by the final Linear layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Attribute names and Sequential indices define the state-dict keys
        # (e.g. "layer1.0.weight"), so they must stay exactly as they are.
        self.layer1 = self._conv_block(3, 32, padding=1)
        self.layer2 = self._conv_block(32, 64, padding=1)
        self.layer3 = self._conv_block(64, 128, padding=0)
        self.lqyer4 = self._conv_block(128, 256, padding=0)  # unused in forward
        # Flattened size of layer3's output for a 3x128x128 input, which is the
        # size this file's preprocessing resizes to. Spatial size goes
        # 128 -> 64 -> 32 -> 30 -> 15, so 128 channels * 15 * 15 = 28800.
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 15 * 15, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, padding):
        """Return Conv2d(3x3) -> BatchNorm2d -> ReLU -> MaxPool2d(2) as a Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Map a batch of images to unnormalized class scores (logits)."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        flat = features.view(features.size(0), -1)  # keep batch dim, flatten rest
        return self.fc_layers(flat)
def load_model(model_path, num_classes=4):
    """Build a ``DeepCNN`` and fill it with the weights stored at *model_path*.

    Args:
        model_path: Path (or file-like object) of a saved state dict.
        num_classes: Output size of the network; must match the checkpoint.

    Returns:
        The loaded model, switched to evaluation mode.

    NOTE(review): ``torch.load`` unpickles its input, so only pass checkpoints
    from a trusted source.
    """
    net = DeepCNN(num_classes=num_classes)
    # map_location remaps every stored tensor onto the CPU, so a checkpoint
    # saved from a GPU run still loads on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint)
    # eval() switches BatchNorm to its running statistics and returns the
    # module itself.
    return net.eval()
# Load the trained weights once, at import time; `predict` reuses this instance.
# NOTE(review): the path is resolved against the current working directory,
# not this file's directory — confirm that is intended.
model = load_model(model_path='cnn_model1.pth')
# Preprocessing applied to every input image: resize to 128x128 (the size the
# DeepCNN classifier head is dimensioned for), convert to a float tensor, then
# normalize each channel. The mean/std values match the widely used ImageNet
# channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function.
def predict(image):
    """Classify a single image with the module-level ``model``.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary file
            object), or an already-opened ``PIL.Image.Image``.

    Returns:
        ``(class_index, confidence)``:
        - ``class_index`` is the arg-max over the model outputs, as an int.
        - ``confidence`` is that class's softmax probability as a percentage,
          rounded to 2 decimals.
        The meaning of each class index is not defined in this file.

    Side effect:
        Prints the class index and the unrounded confidence to stdout.
    """
    # Force three channels. Grayscale ('L'), palette ('P'), RGBA or CMYK
    # inputs would otherwise produce a tensor whose channel count does not
    # match the 3-channel Normalize / Conv2d stages, and inference would crash.
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    else:
        # The context manager closes the underlying file even if decoding fails.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")

    batch = transform(rgb).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    class_index = predicted.item()
    percent = confidence.item() * 100
    print(class_index, percent)
    return class_index, round(percent, 2)