curiouscurrent commited on
Commit
bc99e01
·
verified ·
1 Parent(s): 3c1f8c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -1,43 +1,43 @@
1
  import gradio as gr
2
  import torch
3
  from PIL import Image
4
-
5
- # Define the function to load the YOLOv8 model and perform processing
6
- def process_image(image_path, model_path="waste-detection-yolov8/best_p6.pt"):
7
- """
8
- Processes an image using a YOLOv8 model and returns the processed image.
9
-
10
- Args:
11
- image_path (str): Path to the input image.
12
- model_path (str, optional): Path to the YOLOv8 model weights file. Defaults to "waste-detection-yolov8/best_p6.pt".
13
-
14
- Returns:
15
- PIL.Image: The processed image.
16
- """
17
- # Load the YOLOv8 model from the specified path
18
- model = torch.hub.load('ultralytics/yolov8n', 'custom', path=model_path)
19
-
20
- # Read the input image
21
- image = Image.open(image_path)
22
-
23
- # Convert the image to a tensor
24
- image = model(image)
25
-
26
- # Get the processed image from the results
27
- processed_image = image.imgs[0]
28
-
29
- return processed_image
30
-
31
- # Define the Gradio interface
32
- interface = gr.Interface(
33
- fn=process_image,
34
- inputs=gr.Image(label="Input Image", type="filepath"),
35
- outputs="image",
36
- title="Image Processing with YOLOv8n",
37
- description="Upload an image to process it with the YOLOv8n model.",
38
- thumbnail=None,
39
- article="<p>This Gradio app allows you to upload an image and process it using a YOLOv8n model.</p>",
40
  )
41
 
42
  # Launch the interface
43
- interface.launch(server_port=11111, server_name="localhost", enable_queue=True, allow_screenshot=False, allow_user_code=False)
 
1
  import gradio as gr
2
  import torch
3
  from PIL import Image
4
+ import torchvision.transforms as T
5
+
6
+ # Load the trained model (YOLOv8n) with your weights
7
+ model = torch.hub.load('ultralytics/yolov8', 'yolov8n')
8
+ model.load_state_dict(torch.load("best_p6.pt"))
9
+ model.eval()
10
+
11
+ # Define the image transformation (if required, based on your dataset preprocessing)
12
+ transform = T.Compose([T.ToTensor()])
13
+
14
+ # Define the inference function
15
+ def process_image(image):
16
+ # Convert the image to tensor and make inference
17
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
18
+ with torch.no_grad():
19
+ outputs = model(image_tensor)
20
+
21
+ # Get the output image with bounding boxes (you can adjust this part based on your model's output)
22
+ result_image = outputs.render()[0] # This will render bounding boxes on the image
23
+
24
+ # Convert to PIL image for easy download
25
+ result_pil_image = Image.fromarray(result_image)
26
+
27
+ # Save the output image for download
28
+ output_path = "/tmp/output_image.jpg"
29
+ result_pil_image.save(output_path)
30
+
31
+ return output_path
32
+
33
+ # Define Gradio interface
34
+ iface = gr.Interface(
35
+ fn=process_image,
36
+ inputs=gr.Image(type="pil"), # Image input from user
37
+ outputs=gr.File(label="Download Processed Image"), # Provide the file output for download
38
+ title="Waste Detection", # Interface title
39
+ description="Upload an image of floating waste, and the model will detect and label the objects in it."
40
  )
41
 
42
  # Launch the interface
43
+ iface.launch()