jays009 commited on
Commit
5a8efa9
·
verified ·
1 Parent(s): b3f86f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -4,7 +4,6 @@ from torch import nn
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
- import os
8
  import logging
9
  import requests
10
  from io import BytesIO
@@ -16,31 +15,43 @@ logging.basicConfig(level=logging.INFO)
16
  num_classes = 3
17
 
18
  # Download model from Hugging Face
 
19
  def download_model():
20
- model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
21
- return model_path
 
 
 
 
 
22
 
23
  # Load the model from Hugging Face
24
- def load_model(model_path):
25
- model = models.resnet50(pretrained=False)
26
- num_features = model.fc.in_features
27
- model.fc = nn.Sequential(
28
- nn.Dropout(0.5),
29
- nn.Linear(num_features, 3) # 3 classes
30
- )
31
- checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
32
- model.load_state_dict(checkpoint['model_state_dict'])
33
-
34
- model.eval()
35
- return model
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Path to your model
39
- model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
40
- model = load_model(model_path)
41
-
42
-
43
- # Download the model and load it
44
  model_path = download_model()
45
  model = load_model(model_path)
46
 
@@ -59,7 +70,7 @@ def predict_from_image_url(image_url):
59
  # Download the image from the provided URL
60
  response = requests.get(image_url)
61
  response.raise_for_status()
62
- image = Image.open(BytesIO(response.content))
63
 
64
  # Apply transformations
65
  image_tensor = transform(image).unsqueeze(0)
@@ -67,27 +78,29 @@ def predict_from_image_url(image_url):
67
  # Perform prediction
68
  with torch.no_grad():
69
  outputs = model(image_tensor)
 
 
70
  predicted_class = torch.argmax(outputs, dim=1).item()
71
 
72
  # Interpret the result
73
- if predicted_class == 0:
74
- return {"result": "The photo is of Fall Army Worm with problem ID 126."}
75
- elif predicted_class == 1:
76
- return {"result": "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142."}
77
- elif predicted_class == 2:
78
- return {"result": "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."}
79
- else:
80
- return {"error": "Unexpected class prediction."}
81
 
82
  except Exception as e:
 
83
  return {"error": str(e)}
84
 
85
 
 
86
  demo = gr.Interface(
87
  fn=predict_from_image_url,
88
  inputs="text",
89
  outputs="json",
90
- title="Maize Disease Classification",
91
  description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
92
  )
93
 
 
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
 
7
  import logging
8
  import requests
9
  from io import BytesIO
 
15
  num_classes = 3
16
 
17
  # Download model from Hugging Face
18
+
19
  def download_model():
20
+ try:
21
+ model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
22
+ logging.info("Model downloaded successfully.")
23
+ return model_path
24
+ except Exception as e:
25
+ logging.error(f"Failed to download model: {e}")
26
+ raise
27
 
28
  # Load the model from Hugging Face
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ def load_model(model_path):
31
+ try:
32
+ model = models.resnet50(pretrained=False)
33
+ num_features = model.fc.in_features
34
+ model.fc = nn.Sequential(
35
+ nn.Dropout(0.5),
36
+ nn.Linear(num_features, num_classes)
37
+ )
38
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
39
+
40
+ # Ensure compatibility by handling key mismatches
41
+ model_state_dict = checkpoint['model_state_dict']
42
+ for key in list(model_state_dict.keys()):
43
+ if key.startswith('fc.1'):
44
+ model_state_dict[key.replace('fc.1', 'fc')] = model_state_dict.pop(key)
45
+
46
+ model.load_state_dict(model_state_dict)
47
+ model.eval()
48
+ logging.info("Model loaded successfully.")
49
+ return model
50
+ except Exception as e:
51
+ logging.error(f"Failed to load model: {e}")
52
+ raise
53
 
54
  # Path to your model
 
 
 
 
 
55
  model_path = download_model()
56
  model = load_model(model_path)
57
 
 
70
  # Download the image from the provided URL
71
  response = requests.get(image_url)
72
  response.raise_for_status()
73
+ image = Image.open(BytesIO(response.content)).convert('RGB')
74
 
75
  # Apply transformations
76
  image_tensor = transform(image).unsqueeze(0)
 
78
  # Perform prediction
79
  with torch.no_grad():
80
  outputs = model(image_tensor)
81
+ if outputs.shape[1] != num_classes:
82
+ raise ValueError(f"Unexpected number of output classes: {outputs.shape[1]} (expected {num_classes})")
83
  predicted_class = torch.argmax(outputs, dim=1).item()
84
 
85
  # Interpret the result
86
+ class_map = {
87
+ 0: "The photo is of Fall Army Worm with problem ID 126.",
88
+ 1: "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142.",
89
+ 2: "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."
90
+ }
91
+ return {"result": class_map.get(predicted_class, "Unexpected class prediction.")}
 
 
92
 
93
  except Exception as e:
94
+ logging.error(f"Error during prediction: {e}")
95
  return {"error": str(e)}
96
 
97
 
98
+ # Initialize Gradio interface
99
  demo = gr.Interface(
100
  fn=predict_from_image_url,
101
  inputs="text",
102
  outputs="json",
103
+ title="Crop anomaly Classification",
104
  description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
105
  )
106