atalaydenknalbant commited on
Commit
e4caec0
·
verified ·
1 Parent(s): 5e9ca47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -2,16 +2,23 @@ import spaces
2
  import supervision as sv
3
  import PIL.Image as Image
4
  from ultralytics import YOLO
5
- from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
 
8
  global repo_id
9
 
10
- def download_models(model_id):
11
- hf_hub_download(repo_id, filename=f"{model_id}", local_dir=f"./")
12
- return f"./{model_id}"
13
-
14
  repo_id = "atalaydenknalbant/asl-yolo-models"
 
 
 
 
 
 
 
 
 
 
 
15
  box_annotator = sv.BoxAnnotator()
16
  category_dict = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H', 8: 'I',
17
  9: 'J', 10: 'K', 11: 'L', 12: 'M', 13: 'N', 14: 'O', 15: 'P', 16: 'Q',
@@ -19,7 +26,9 @@ category_dict = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H',
19
 
20
  @spaces.GPU
21
  def yolo_inference(image, model_id, conf_threshold, iou_threshold, max_detection):
22
- model_path = download_models(model_id)
 
 
23
  model = YOLO(model_path)
24
  results = model(source=image, imgsz=640, iou=iou_threshold, conf=conf_threshold, verbose=False, max_det=max_detection)[0]
25
  detections = sv.Detections.from_ultralytics(results)
@@ -33,13 +42,19 @@ def yolo_inference(image, model_id, conf_threshold, iou_threshold, max_detection
33
  return annotated_image
34
 
35
  def app():
 
 
 
36
  with gr.Blocks():
37
  with gr.Row():
38
  with gr.Column():
39
  image = gr.Image(type="pil", label="Image", interactive=True)
40
 
41
- model_id = gr.Textbox(label="Model ID", placeholder="Enter model filename (.pt)")
42
-
 
 
 
43
  conf_threshold = gr.Slider(
44
  label="Confidence Threshold",
45
  minimum=0.1,
 
2
  import supervision as sv
3
  import PIL.Image as Image
4
  from ultralytics import YOLO
5
+ from huggingface_hub import hf_hub_download, list_repo_files
6
  import gradio as gr
7
 
8
  global repo_id
9
 
 
 
 
 
10
  repo_id = "atalaydenknalbant/asl-yolo-models"
11
+
12
+ def download_models(repo_id, model_id, file_extension=".pt"):
13
+ # Get list of files in the repository without using HfApi
14
+ files = list_repo_files(repo_id)
15
+ # Filter for model filenames with the given extension
16
+ model_filenames = [file for file in files if file.endswith(file_extension)]
17
+
18
+ # Download the selected model
19
+ hf_hub_download(repo_id, filename=model_id, local_dir=f"./")
20
+ return model_filenames, f"./{model_id}"
21
+
22
  box_annotator = sv.BoxAnnotator()
23
  category_dict = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H', 8: 'I',
24
  9: 'J', 10: 'K', 11: 'L', 12: 'M', 13: 'N', 14: 'O', 15: 'P', 16: 'Q',
 
26
 
27
  @spaces.GPU
28
  def yolo_inference(image, model_id, conf_threshold, iou_threshold, max_detection):
29
+ # Download models and get filenames within the same function
30
+ model_filenames, model_path = download_models(repo_id, model_id)
31
+
32
  model = YOLO(model_path)
33
  results = model(source=image, imgsz=640, iou=iou_threshold, conf=conf_threshold, verbose=False, max_det=max_detection)[0]
34
  detections = sv.Detections.from_ultralytics(results)
 
42
  return annotated_image
43
 
44
  def app():
45
+ # Fetch the model filenames directly in the app
46
+ model_filenames, _ = download_models(repo_id, "")
47
+
48
  with gr.Blocks():
49
  with gr.Row():
50
  with gr.Column():
51
  image = gr.Image(type="pil", label="Image", interactive=True)
52
 
53
+ model_id = gr.Dropdown(
54
+ label="Model",
55
+ choices=model_filenames,
56
+ value=model_filenames[0] if model_filenames else "",
57
+ )
58
  conf_threshold = gr.Slider(
59
  label="Confidence Threshold",
60
  minimum=0.1,