radub23
Add retry logic and robust tensor handling for intermittent failures
f438e63
raw
history blame
5.68 kB
import gradio as gr
from fastai.vision.all import *
from fastai.learner import load_learner
from pathlib import Path
import pandas as pd
import os
import time
"""
Warning Lamp Detector using FastAI
This application allows users to upload images of warning lamps and get classification results.
"""
def get_labels(fname):
    """
    Label getter kept only so the exported learner can be loaded.

    The model was saved with this function attached, so it has to stay
    importable under this exact name. Inference never needs ground-truth
    labels, so a fresh empty list is returned for every input.

    Args:
        fname: Path to the image file (ignored).
    Returns:
        list: An empty list of active labels.
    """
    del fname  # intentionally unused; only the function's existence matters
    active_labels = []
    return active_labels
# Load the exported FastAI learner once, at import time. A failure is logged and
# re-raised, so the app never starts without a model.
# NOTE: load_learner unpickles the file; only ever point it at a trusted model file.
try:
    model_path = Path("WarningLampClassifier.pkl")
    learn_inf = load_learner(model_path)
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    raise
else:
    print("Model loaded successfully")
def _to_python(value):
    """Return ``value.tolist()`` for tensor/array-like values, otherwise ``value`` unchanged."""
    return value.tolist() if hasattr(value, "tolist") else value


def _predict_with_retry(img, max_retries=3, retry_delay=1.0):
    """
    Run ``learn_inf.predict`` on an image, retrying if the call raises.

    Only the model call is retried; formatting the result is deterministic and
    is not worth retrying.

    Args:
        img: Image already converted for the FastAI learner.
        max_retries: Total number of attempts (values below 1 are treated as 1).
        retry_delay: Seconds to sleep between attempts.
    Returns:
        The 3-tuple returned by ``learn_inf.predict``.
    Raises:
        The exception from the last attempt once every attempt has failed.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return learn_inf.predict(img)
        except Exception as e:
            print(f"Attempt {attempt} failed with error: {e}")
            if attempt == attempts:
                raise
            print(f"Retrying in {retry_delay:g} second(s)...")
            time.sleep(retry_delay)


def _prediction_confidence(pred_idx, probs):
    """
    Return the confidence (a 0-1 float) to report for a prediction.

    ``pred_idx`` may be a scalar class index (single-label model) or a per-class
    boolean mask (multi-label model). This learner appears to be multi-label,
    since its label function returns a list. For a mask, the highest probability
    among the active classes is used. If no class is active, the overall highest
    probability is reported instead.
    """
    prob_values = [float(p) for p in _to_python(probs)]
    index = _to_python(pred_idx)
    if isinstance(index, (list, tuple)):
        # Boolean mask: one flag per class.
        active = [p for p, is_on in zip(prob_values, index) if is_on]
        return max(active) if active else max(prob_values)
    return prob_values[int(index)]


def _format_prediction(pred_class, pred_idx, probs, vocab):
    """Build the chat message summarising one prediction and all class probabilities."""
    # A single-label model decodes to one string-like category. A multi-label
    # model decodes to a (possibly empty) collection of categories.
    if isinstance(pred_class, str):
        label = str(pred_class)
    else:
        label = ", ".join(str(c) for c in pred_class) or "None"
    confidence = _prediction_confidence(pred_idx, probs)
    lines = [
        f"Detected Warning Lamp: {label}",
        f"Confidence: {confidence:.2%}",
        "",
        "Probabilities for all classes:",
    ]
    for cls, prob in zip(vocab, _to_python(probs)):
        try:
            lines.append(f"- {cls}: {float(prob):.2%}")
        except (TypeError, ValueError) as prob_error:
            print(f"Error converting probability for {cls}: {prob_error}")
            lines.append(f"- {cls}: N/A")
    return "\n".join(lines)


def detect_warning_lamp(image, history: list[tuple[str, str]], system_message):
    """
    Classify an uploaded warning-lamp image and append the result to the chat.

    Args:
        image: PIL Image from Gradio, or None when nothing was uploaded.
        history: Chat history as (user, bot) message pairs. None is treated as empty.
        system_message: System prompt from the hidden textbox. It is unused
            because the FastAI model takes no prompt, and is kept so the Gradio
            wiring is unchanged.
    Returns:
        A new chat-history list with one bot message appended. The message is
        either the prediction summary, a request to upload an image, or an
        error description.
    """
    # Copy so the caller's list is never mutated; tolerate a missing history.
    history = list(history or [])
    if image is None:
        history.append((None, "Please upload an image first."))
        return history
    try:
        # Convert PIL image to FastAI compatible format
        img = PILImage(image)
        pred_class, pred_idx, probs = _predict_with_retry(img)
        response = _format_prediction(pred_class, pred_idx, probs, learn_inf.dls.vocab)
    except Exception as e:  # UI boundary: report the failure in the chat instead of crashing
        response = f"Error processing image: {e}"
        print(response)
    history.append((None, response))
    return history
# Create a custom interface with image upload
with gr.Blocks(title="Warning Lamp Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🚨 Warning Lamp Detector
Upload an image of a warning lamp to get its classification.
### Instructions:
1. Upload a clear image of the warning lamp
2. Wait for the analysis
3. View the detailed classification results
### Supported Warning Lamps:
""")
    # Display the classes the model can predict. learn_inf always exists here
    # because the model-loading code re-raises on failure, so the app never
    # reaches this point without a loaded learner.
    gr.Markdown("\n".join(f"- {cls}" for cls in learn_inf.dls.vocab))
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Warning Lamp Image",
                type="pil",
                sources="upload",
            )
            system_message = gr.Textbox(
                value="You are an expert in warning lamp classification. Analyze the image and provide detailed information about the type, color, and status of the warning lamp.",
                label="System Message",
                lines=3,
                visible=False,  # Hide this since we're using direct model inference
            )
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                # NOTE(review): Gradio documents avatar_images as image paths/URLs;
                # confirm an emoji string renders as intended.
                avatar_images=(None, "🚨"),
                height=400,
            )
    # Add a submit button
    submit_btn = gr.Button("Analyze Warning Lamp", variant="primary")
    # Inputs are passed positionally as (image, history, system_message).
    submit_btn.click(
        detect_warning_lamp,
        inputs=[image_input, chatbot, system_message],
        outputs=chatbot,
    )

if __name__ == "__main__":
    demo.launch()