import tempfile

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image

# Path to the trained detection weights.
WEIGHTS_PATH = "best_p6.pt"

# Load the trained detector with your weights.
# NOTE(review): the previous code called torch.hub.load('ultralytics/yolov8', ...),
# which (as far as I can tell) is not a torch.hub repo. It then passed a
# training checkpoint to load_state_dict. Ultralytics checkpoints are pickled
# dicts such as {'model': <nn.Module>, 'epoch': ..., ...}, not bare state_dicts,
# so that call cannot succeed.
# The YOLOv5 hub 'custom' entry point loads such a checkpoint directly. It wraps
# the model in AutoShape, whose results expose .render() (the API used to draw
# boxes in this app).
# If best_p6.pt was trained with the `ultralytics` (YOLOv8) package instead,
# load it with `ultralytics.YOLO(WEIGHTS_PATH)` and adapt the inference code.
model = torch.hub.load("ultralytics/yolov5", "custom", path=WEIGHTS_PATH)
model.eval()

# Tensor conversion for callers that feed the model raw tensors (only
# ToTensor; adjust if your dataset preprocessing requires more).
transform = T.Compose([T.ToTensor()])

# Define the inference function
def process_image(image):
    """Run waste detection on *image* and return the path of an annotated JPEG.

    Args:
        image: The uploaded picture as a PIL image. Gradio passes ``None``
            when the input component is empty.

    Returns:
        Filesystem path of a newly created JPEG with the model's bounding
        boxes drawn on it, suitable for a ``gr.File`` output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Give the model the PIL image itself rather than a hand-built tensor.
    # YOLOv5-style torch.hub (AutoShape) models letterbox and normalize PIL
    # input themselves and return a Detections object with .render().
    # A raw tensor bypasses that wrapper and yields bare predictions, which
    # have no .render().
    rgb_image = image.convert("RGB")
    with torch.no_grad():
        outputs = model(rgb_image)

    # render() draws the boxes and returns one annotated array per input image.
    result_image = outputs.render()[0]
    result_pil_image = Image.fromarray(result_image)

    # Use a unique file per request. A fixed path would let concurrent users
    # overwrite (and download) each other's results. delete=False keeps the
    # file after the handle closes so Gradio can still serve it.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        result_pil_image.save(tmp, format="JPEG")
    return tmp.name

# Define Gradio interface: a PIL image goes in, a downloadable file comes out.
image_upload = gr.Image(type="pil")  # Image input from user
result_download = gr.File(label="Download Processed Image")  # File offered for download

iface = gr.Interface(
    title="Waste Detection",
    description="Upload an image of floating waste, and the model will detect and label the objects in it.",
    fn=process_image,
    inputs=image_upload,
    outputs=result_download,
)

# Start the Gradio web server for the interface above.
iface.launch()