pavi156 commited on
Commit
c5e43e4
·
1 Parent(s): f4545ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -42
app.py CHANGED
@@ -5,47 +5,32 @@ from PIL import Image
5
 
6
  # Load the trained model
7
  model_path = "cifar_net.pth"
8
- model = torch.load(model_path, map_location=torch.device('cpu'))
9
- model = YourModelClass() # Replace YourModelClass with the appropriate model class
10
- model.load_state_dict(state_dict)
11
  model.eval()
12
 
13
- # Define class labels for CIFAR-10
14
- classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
15
-
16
- def classify_image(image):
17
- transform = transforms.Compose([
18
- transforms.ToTensor(),
19
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
20
- ])
21
-
22
- # Preprocess the input image
23
- image = transform(image).unsqueeze(0)
24
-
25
- # Perform inference with the model
26
- outputs = model(image)
27
- _, predicted = torch.max(outputs, 1)
28
- predicted_class = classes[predicted.item()]
29
-
30
- return predicted_class
31
-
32
- def classify_images(images):
33
- return [classify_image(image) for image in images]
34
-
35
- inputs_image = gr.inputs.Image(label="Input Image", type="pil")
36
- outputs_image = gr.outputs.Label(label="Predicted Class")
37
- interface_image = gr.Interface(
38
- fn=classify_images,
39
- inputs=inputs_image,
40
- outputs=outputs_image,
41
- title="CIFAR-10 Image Classifier",
42
- description="Classify images into one of the CIFAR-10 classes.",
43
- examples=[
44
- ['image_0.jpg'],
45
- ['image_1.jpg']
46
- ],
47
- allow_flagging=False
48
- )
49
-
50
- if __name__ == "__main__":
51
- interface_image.launch()
 
5
 
6
  # Load the trained model
7
  model_path = "cifar_net.pth"
8
+
9
+ model = torch.load(model_path)
 
10
  model.eval()
11
 
12
+ # Prepare the image for prediction
13
+ image_path = 'download.jpg'
14
+ image = Image.open(image_path)
15
+
16
+ # Transform the image to match CIFAR-10 format
17
+ transform = transforms.Compose([
18
+ transforms.Resize((32, 32)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize with CIFAR-10 mean and std
21
+ ])
22
+ input_image = transform(image).unsqueeze(0)
23
+
24
+ # Make predictions
25
+ with torch.no_grad():
26
+ outputs = model(input_image)
27
+
28
+ # Retrieve the predicted class label
29
+ _, predicted = torch.max(outputs, 1)
30
+ class_index = predicted.item()
31
+
32
+ # Load the CIFAR-10 class labels
33
+ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
34
+
35
+ # Print the predicted class label
36
+ print('Predicted class label:', classes[class_index])