Đinh Ngọc Ân commited on
Commit
f6a046f
·
1 Parent(s): 56cacaf

first commit

Browse files
09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:232b9150391e812a5c6ecba4348eb35c649f8c8baa1a390ecea7f8c6f5def965
3
+ size 31307450
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from model import create_effnetb2_model
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, TypedDict
8
+
9
+ # Setup class names
10
+ class_names = ["pizza", "steak", "sushi"]
11
+
12
+ # Create EffNetB2 model instance and transform
13
+ effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))
14
+
15
+ # Load model weights
16
+ effnetb2.load_state_dict(
17
+ torch.load(
18
+ os.path.join("models", "09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth"),
19
+ map_location=torch.device("cpu")
20
+ )
21
+ )
22
+
23
+ # Predict function
24
+ def predict(img) -> Tuple[Dict, float]:
25
+ # Start a timer
26
+ start_time = timer()
27
+
28
+ # Transform the input image for use with EffNetB2
29
+ img = effnetb2_transforms(img).unsqueeze(0)
30
+
31
+ # put model into eval mode, make prediction
32
+ effnetb2.eval()
33
+ with torch.inference_mode():
34
+ pred_probs = torch.softmax(effnetb2(img), dim=-1)
35
+
36
+ # Create a prediction label and predcition probability
37
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(class_names)}
38
+
39
+ # Calculate pred time and pred dict
40
+ pred_time = round(timer() - start_time, 5)
41
+
42
+ return pred_labels_and_probs, pred_time
43
+
44
+ # Gradio app
45
+ # Create title, description and article strings
46
+ title = "FoodVision Mini 🍕🥩🍣"
47
+ description = "An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
48
+ article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
49
+
50
+ # Create an example list
51
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
52
+
53
+ # Create the Gradio demo
54
+ demo = gr.Interface(fn=predict,
55
+ inputs=gr.inputs.Image(type="pil"),
56
+ outputs=[gr.outputs.Label(num_top_classes=3, label="Predictions"),
57
+ gr.outputs.Number(label="Prediction time (s)")],
58
+ examples=example_list,
59
+ title=title,
60
+ description=description,
61
+ article=article)
62
+
63
+ # Launch the demo!
64
+ demo.launch(debug=False,
65
+ share=True)
examples/2582289.jpg ADDED
examples/3622237.jpg ADDED
examples/592799.jpg ADDED
model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+ from torch import nn
3
+
4
+ def create_effnetb2_model(num_classes: int = 3,
5
+ seed:int=42):
6
+ """Creates an EfficientNetB2 feature extractor model and transforms.
7
+
8
+ Args:
9
+ num_classes (int, optional): Number of output neurons in the output layer. Defaults to 3
10
+ seed (int, optional): Random seed value. Defaults to 42.
11
+
12
+ Returns:
13
+ torchvision.models.efficientnet_b2: EffNetB2 feature extractor model
14
+
15
+ """
16
+ # 1. Setup pretrained EffNMetB2 weights
17
+ effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
18
+ effnetb2_transform = effnetb2_weights.transforms()
19
+ # 2. Setup pretrained model
20
+ effnetb2 = torchvision.models.efficientnet_b2(weights=effnetb2_weights)
21
+ # 3. Freeze the base layers
22
+ for param in effnetb2.parameters():
23
+ param.requires_grad = False
24
+
25
+ # 4. Change the classsifier to 3 classes
26
+ torch.manual_seed(seed)
27
+ effnetb2.classifier = nn.Sequential(
28
+ nn.Dropout(p=0.3, inplace=True),
29
+ nn.Linear(in_features=1408, out_features=num_classes))
30
+
31
+ return effnetb2, effnetb2_transform
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4