jays009 commited on
Commit
39dba98
·
verified ·
1 Parent(s): 97f39b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -25
app.py CHANGED
@@ -4,6 +4,7 @@ from torch import nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
 
7
  import logging
8
  import requests
9
  from io import BytesIO
@@ -15,18 +16,11 @@ logging.basicConfig(level=logging.INFO)
15
  num_classes = 3
16
 
17
  # Download model from Hugging Face
18
-
19
  def download_model():
20
- try:
21
- model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
22
- logging.info("Model downloaded successfully.")
23
- return model_path
24
- except Exception as e:
25
- logging.error(f"Failed to download model: {e}")
26
- raise
27
 
28
  # Load the model from Hugging Face
29
-
30
  def load_model(model_path):
31
  model = models.resnet50(pretrained=False)
32
  num_features = model.fc.in_features
@@ -34,7 +28,7 @@ def load_model(model_path):
34
  nn.Dropout(0.5),
35
  nn.Linear(num_features, 3) # 3 classes
36
  )
37
-
38
  # Load the checkpoint
39
  checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
40
 
@@ -50,9 +44,6 @@ def load_model(model_path):
50
  model.load_state_dict(new_state_dict, strict=False)
51
  model.eval()
52
  return model
53
- except Exception as e:
54
- logging.error(f"Failed to load model: {e}")
55
- raise
56
 
57
  # Path to your model
58
  model_path = download_model()
@@ -73,7 +64,7 @@ def predict_from_image_url(image_url):
73
  # Download the image from the provided URL
74
  response = requests.get(image_url)
75
  response.raise_for_status()
76
- image = Image.open(BytesIO(response.content)).convert('RGB')
77
 
78
  # Apply transformations
79
  image_tensor = transform(image).unsqueeze(0)
@@ -81,24 +72,22 @@ def predict_from_image_url(image_url):
81
  # Perform prediction
82
  with torch.no_grad():
83
  outputs = model(image_tensor)
84
- if outputs.shape[1] != num_classes:
85
- raise ValueError(f"Unexpected number of output classes: {outputs.shape[1]} (expected {num_classes})")
86
  predicted_class = torch.argmax(outputs, dim=1).item()
87
 
88
  # Interpret the result
89
- class_map = {
90
- 0: "The photo is of Fall Army Worm with problem ID 126.",
91
- 1: "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142.",
92
- 2: "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."
93
- }
94
- return {"result": class_map.get(predicted_class, "Unexpected class prediction.")}
 
 
95
 
96
  except Exception as e:
97
- logging.error(f"Error during prediction: {e}")
98
  return {"error": str(e)}
99
 
100
 
101
- # Initialize Gradio interface
102
  demo = gr.Interface(
103
  fn=predict_from_image_url,
104
  inputs="text",
@@ -108,4 +97,4 @@ demo = gr.Interface(
108
  )
109
 
110
  if __name__ == "__main__":
111
- demo.launch()
 
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
+ import os
8
  import logging
9
  import requests
10
  from io import BytesIO
 
16
  num_classes = 3
17
 
18
  # Download model from Hugging Face
 
19
  def download_model():
20
+ model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
21
+ return model_path
 
 
 
 
 
22
 
23
  # Load the model from Hugging Face
 
24
  def load_model(model_path):
25
  model = models.resnet50(pretrained=False)
26
  num_features = model.fc.in_features
 
28
  nn.Dropout(0.5),
29
  nn.Linear(num_features, 3) # 3 classes
30
  )
31
+
32
  # Load the checkpoint
33
  checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
34
 
 
44
  model.load_state_dict(new_state_dict, strict=False)
45
  model.eval()
46
  return model
 
 
 
47
 
48
  # Path to your model
49
  model_path = download_model()
 
64
  # Download the image from the provided URL
65
  response = requests.get(image_url)
66
  response.raise_for_status()
67
+ image = Image.open(BytesIO(response.content))
68
 
69
  # Apply transformations
70
  image_tensor = transform(image).unsqueeze(0)
 
72
  # Perform prediction
73
  with torch.no_grad():
74
  outputs = model(image_tensor)
 
 
75
  predicted_class = torch.argmax(outputs, dim=1).item()
76
 
77
  # Interpret the result
78
+ if predicted_class == 0:
79
+ return {"result": "The photo is of Fall Army Worm with problem ID 126."}
80
+ elif predicted_class == 1:
81
+ return {"result": "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142."}
82
+ elif predicted_class == 2:
83
+ return {"result": "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."}
84
+ else:
85
+ return {"error": "Unexpected class prediction."}
86
 
87
  except Exception as e:
 
88
  return {"error": str(e)}
89
 
90
 
 
91
  demo = gr.Interface(
92
  fn=predict_from_image_url,
93
  inputs="text",
 
97
  )
98
 
99
  if __name__ == "__main__":
100
+ demo.launch()