Shriharsh commited on
Commit
752ca4f
ยท
verified ยท
1 Parent(s): 77fdea5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -39
app.py CHANGED
@@ -6,82 +6,58 @@ from model import create_effnetb2_model
6
  from timeit import default_timer as timer
7
  from typing import Tuple, Dict
8
 
9
- # Setup class names
10
  try:
11
- with open("class_names.txt", "r") as f: # reading them in from class_names.txt
12
  class_names = [food_name.strip() for food_name in f.readlines()]
13
  except FileNotFoundError:
14
- raise FileNotFoundError("class_names.txt not found. Ensure it exists in the root directory.")
15
 
16
  ### 2. Model and transforms preparation ###
17
-
18
- # Create model
19
  try:
20
- effnetb2, effnetb2_transforms = create_effnetb2_model(
21
- num_classes=101, # could also use len(class_names)
22
- )
23
  except Exception as e:
24
  raise Exception(f"Error creating model: {str(e)}")
25
 
26
- # Load saved weights
27
  try:
28
  effnetb2.load_state_dict(
29
  torch.load(
30
- f="09_pretrained_effnetb2_feature_extractor_food101.pth",
31
- map_location=torch.device("cpu"), # load to CPU
32
  )
33
  )
34
  except FileNotFoundError:
35
- raise FileNotFoundError("Model weights file '09_pretrained_effnetb2_feature_extractor_food101.pth' not found.")
36
  except Exception as e:
37
- raise Exception(f"Error loading model weights: {str(e)}")
38
 
39
  ### 3. Predict function ###
40
-
41
  def predict(img) -> Tuple[Dict, float]:
42
- """Transforms and performs a prediction on img and returns prediction and time taken."""
43
  try:
44
- # Start the timer
45
  start_time = timer()
46
-
47
- # Transform the target image and add a batch dimension
48
  if img is None:
49
- raise ValueError("Input image is None. Please provide a valid image.")
50
  img = effnetb2_transforms(img).unsqueeze(0)
51
-
52
- # Put model into evaluation mode and turn on inference mode
53
  effnetb2.eval()
54
  with torch.inference_mode():
55
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
56
  pred_probs = torch.softmax(effnetb2(img), dim=1)
57
-
58
- # Create a prediction label and prediction probability dictionary for each prediction class
59
- pred_labels_and_probs = {
60
- class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
61
- }
62
-
63
- # Calculate the prediction time
64
  pred_time = round(timer() - start_time, 5)
65
-
66
- # Return the prediction dictionary and prediction time
67
  return pred_labels_and_probs, pred_time
68
  except Exception as e:
69
  return {"error": f"Prediction failed: {str(e)}"}, 0.0
70
 
71
  ### 4. Gradio app ###
72
-
73
- # Create title, description
74
  title = "FoodVision 101 ๐Ÿ”๐Ÿ‘"
75
- description = "An EfficientNetB2 feature extractor computer vision model to classify images of food into 101 different classes."
76
 
77
- # Create examples list from "examples/" directory
78
  try:
79
  example_list = [["examples/" + example] for example in os.listdir("examples")]
80
  except FileNotFoundError:
81
  example_list = []
82
- print("Warning: 'examples/' directory not found. No example images will be loaded.")
83
 
84
- # Create Gradio interface
85
  demo = gr.Interface(
86
  fn=predict,
87
  inputs=gr.Image(type="pil"),
@@ -94,5 +70,5 @@ demo = gr.Interface(
94
  description=description,
95
  )
96
 
97
- # Launch the app with share=True for Hugging Face Spaces
98
- demo.launch(share=True)
 
6
  from timeit import default_timer as timer
7
  from typing import Tuple, Dict
8
 
9
+ # Load class names
10
  try:
11
+ with open("class_names.txt", "r") as f:
12
  class_names = [food_name.strip() for food_name in f.readlines()]
13
  except FileNotFoundError:
14
+ raise FileNotFoundError("class_names.txt not found.")
15
 
16
  ### 2. Model and transforms preparation ###
 
 
17
  try:
18
+ effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=101)
 
 
19
  except Exception as e:
20
  raise Exception(f"Error creating model: {str(e)}")
21
 
22
+ # Load weights
23
  try:
24
  effnetb2.load_state_dict(
25
  torch.load(
26
+ "09_pretrained_effnetb2_feature_extractor_food101.pth",
27
+ map_location=torch.device("cpu"),
28
  )
29
  )
30
  except FileNotFoundError:
31
+ raise FileNotFoundError("Model weights file not found.")
32
  except Exception as e:
33
+ raise Exception(f"Error loading weights: {str(e)}")
34
 
35
  ### 3. Predict function ###
 
36
  def predict(img) -> Tuple[Dict, float]:
 
37
  try:
 
38
  start_time = timer()
 
 
39
  if img is None:
40
+ raise ValueError("Input image is None.")
41
  img = effnetb2_transforms(img).unsqueeze(0)
 
 
42
  effnetb2.eval()
43
  with torch.inference_mode():
 
44
  pred_probs = torch.softmax(effnetb2(img), dim=1)
45
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
 
 
 
 
 
46
  pred_time = round(timer() - start_time, 5)
 
 
47
  return pred_labels_and_probs, pred_time
48
  except Exception as e:
49
  return {"error": f"Prediction failed: {str(e)}"}, 0.0
50
 
51
  ### 4. Gradio app ###
 
 
52
  title = "FoodVision 101 ๐Ÿ”๐Ÿ‘"
53
+ description = "An EfficientNetB2 feature extractor to classify 101 food classes."
54
 
 
55
  try:
56
  example_list = [["examples/" + example] for example in os.listdir("examples")]
57
  except FileNotFoundError:
58
  example_list = []
59
+ print("Warning: 'examples/' directory not found.")
60
 
 
61
  demo = gr.Interface(
62
  fn=predict,
63
  inputs=gr.Image(type="pil"),
 
70
  description=description,
71
  )
72
 
73
+ # Launch without share=True for Hugging Face Spaces
74
+ demo.launch()