Gabolozano committed on
Commit 945a330 · verified · 1 Parent(s): 6564a3a

Update app.py

Files changed (1)
  1. app.py +21 -28
app.py CHANGED
@@ -1,33 +1,26 @@
- import os
- import gradio as gr
- from transformers import pipeline
- from transformers import DetrForSegmentation, DetrConfig

- # Initialize the configuration for DetrForObjectDetection
- config = DetrConfig.from_pretrained("facebook/detr-resnet-50")

- # Create the model for object detection using the specified configuration
- model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50", config=config)

- # Updated function call
- results = processed_image(model, image, size={'longest_edge': 800})

- def get_pipeline_prediction(pil_image):
-     # first get the pipeline output given the pil image
-     pipeline_output = od_pipe(pil_image)
-     # Then Process the image using the pipeline output
-     processed_image = render_results_in_image(pil_image,
-                                               pipeline_output)
-
-     return processed_image

-
- demo = gr.Interface(
-     fn=get_pipeline_prediction,
-     inputs=gr.Image(label="Input image",
-                     type="pil"),
-     outputs=gr.Image(label="Output image with predicted instances",
-                      type="pil")
- )
-
- demo.launch

+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ import torch
+ from PIL import Image
+ import requests

+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)

+ # you can specify the revision tag if you don't want the timm dependency
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

+ inputs = processor(images=image, return_tensors="pt")
+ outputs = model(**inputs)

+ # convert outputs (bounding boxes and class logits) to COCO API
+ # let's only keep detections with score > 0.9
+ target_sizes = torch.tensor([image.size[::-1]])
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+     box = [round(i, 2) for i in box.tolist()]
+     print(
+         f"Detected {model.config.id2label[label.item()]} with confidence "
+         f"{round(score.item(), 3)} at location {box}"
+     )
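
Side note on the diff above: the removed version exposed a Gradio UI for the Space, while the new app.py runs the DETR example once on a fixed COCO image and only prints detections to stdout. Below is a minimal sketch, not part of this commit, of how the new DetrImageProcessor / DetrForObjectDetection code could be wired back into an interface like the one removed; the function name detect_and_draw and the box-drawing details are illustrative assumptions, not code from the repository.

# Sketch only (not in the commit): combine the new DETR inference with the removed Gradio UI.
import gradio as gr
import torch
from PIL import ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

def detect_and_draw(pil_image):
    # Run DETR on the uploaded image
    inputs = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Keep detections with score > 0.9, rescaled to the original image size
    target_sizes = torch.tensor([pil_image.size[::-1]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
    # Draw predicted boxes and labels onto a copy of the image
    annotated = pil_image.copy()
    draw = ImageDraw.Draw(annotated)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = box.tolist()
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {score.item():.2f}", fill="red")
    return annotated

demo = gr.Interface(
    fn=detect_and_draw,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Image(label="Output image with predicted instances", type="pil"),
)

demo.launch()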