Update app.py
app.py
CHANGED
@@ -8,7 +8,6 @@ from PIL import Image
 import requests
 from io import BytesIO
 from fastapi import FastAPI
-from gradio.routes import App
 
 # Define the number of classes
 num_classes = 2
@@ -18,11 +17,13 @@ results_cache = {}
 
 # Download model from Hugging Face
 def download_model():
+    print("Downloading model...")
     model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
     return model_path
 
 # Load the model from Hugging Face
 def load_model(model_path):
+    print("Loading model...")
     model = models.resnet50(pretrained=False)
     model.fc = nn.Linear(model.fc.in_features, num_classes)
     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
@@ -44,22 +45,16 @@ transform = transforms.Compose([
 # Function to predict from image content
 def predict_from_image(image):
     try:
-        # Log the image processing
         print(f"Processing image: {image}")
-
-        # Ensure the image is a PIL Image
         if not isinstance(image, Image.Image):
             raise ValueError("Invalid image format received. Please provide a valid image.")
 
         # Apply transformations
         image_tensor = transform(image).unsqueeze(0)
-
-        # Predict
         with torch.no_grad():
             outputs = model(image_tensor)
             predicted_class = torch.argmax(outputs, dim=1).item()
 
-        # Interpret the result
         if predicted_class == 0:
             return {"result": "The photo is of fall army worm with problem ID 126."}
         elif predicted_class == 1:
@@ -74,11 +69,10 @@ def predict_from_image(image):
 # Function to predict from URL
 def predict_from_url(url):
     try:
-
+        print(f"Fetching image from URL: {url}")
         response = requests.get(url)
-        response.raise_for_status()
+        response.raise_for_status()
         image = Image.open(BytesIO(response.content))
-        print(f"Fetched image from URL: {url}")
         return predict_from_image(image)
     except Exception as e:
         print(f"Error during URL processing: {e}")
@@ -87,21 +81,19 @@ def predict_from_url(url):
 # Main prediction function with caching
 def predict(image, url):
     try:
+        print("Starting prediction...")
         if image:
             result = predict_from_image(image)
         elif url:
             result = predict_from_url(url)
         else:
             result = {"error": "No input provided. Please upload an image or provide a URL."}
-
-        # Generate and store the event ID
+
         event_id = id(result)  # Use Python's id() function to generate a unique identifier
         results_cache[event_id] = result
-
-        # Log the result
+
         print(f"Event ID: {event_id}, Result: {result}")
         return {"event_id": event_id, "result": result}
-
     except Exception as e:
         print(f"Error in prediction function: {e}")
         return {"error": str(e)}
@@ -109,7 +101,7 @@ def predict(image, url):
 # Function to retrieve result by event_id
 def get_result(event_id):
     try:
-
+        print(f"Retrieving result for event ID: {event_id}")
         event_id = int(event_id)
         result = results_cache.get(event_id)
         if result: