jays009 commited on
Commit
4f7c2c3
·
verified ·
1 Parent(s): b6755e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py CHANGED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import models, transforms
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image
7
+ import os
8
+ import logging
9
+ import requests
10
+ from io import BytesIO
11
+
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.INFO)
14
+
15
+ # Define the number of classes
16
+ num_classes = 3
17
+
18
+ # Download model from Hugging Face
19
+ def download_model():
20
+ model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
21
+ return model_path
22
+
23
+ # Load the model from Hugging Face
24
+ def load_model(model_path):
25
+ model = models.resnet50(pretrained=False)
26
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
27
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
28
+ model.eval()
29
+ logging.info("Model loaded successfully. Ready for inference.")
30
+ return model
31
+
32
+ # Download the model and load it
33
+ model_path = download_model()
34
+ model = load_model(model_path)
35
+
36
+ # Define the transformation for the input image
37
+ transform = transforms.Compose([
38
+ transforms.Resize(256),
39
+ transforms.CenterCrop(224),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
42
+ ])
43
+
44
+ # Prediction function for an uploaded image
45
+
46
+ def predict_from_image_url(image_url):
47
+ try:
48
+ # Download the image from the provided URL
49
+ response = requests.get(image_url)
50
+ response.raise_for_status()
51
+ image = Image.open(BytesIO(response.content))
52
+
53
+ # Apply transformations
54
+ image_tensor = transform(image).unsqueeze(0)
55
+
56
+ # Perform prediction
57
+ with torch.no_grad():
58
+ outputs = model(image_tensor)
59
+ predicted_class = torch.argmax(outputs, dim=1).item()
60
+
61
+ # Interpret the result
62
+ if predicted_class == 0:
63
+ return {"result": "The photo is of Fall Army Worm with problem ID 126."}
64
+ elif predicted_class == 1:
65
+ return {"result": "The photo shows symptoms of Phosphorus Deficiency with Problem ID 142."}
66
+ elif predicted_class == 2:
67
+ return {"result": "The photo shows symptoms of Bacterial Leaf Blight with Problem ID 203."}
68
+ else:
69
+ return {"error": "Unexpected class prediction."}
70
+
71
+ except Exception as e:
72
+ return {"error": str(e)}
73
+
74
+
75
+ demo = gr.Interface(
76
+ fn=predict_from_image_url,
77
+ inputs="text",
78
+ outputs="json",
79
+ title="Maize Disease Classification",
80
+ description="Enter a URL to an image for classification (Fall Army Worm, Phosphorus Deficiency, or Bacterial Leaf Blight).",
81
+ )
82
+
83
+ if __name__ == "__main__":
84
+ demo.launch()