import os
import yaml
import numpy as np
from matplotlib import cm
import gradio as gr
import deeplabcut
import dlclibrary
import transformers

from PIL import Image
import requests

from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w_text, save_results_only_dlc
from detection_utils import predict_md, crop_animal_detections
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples

from deeplabcut.utils import auxiliaryfunctions
from dlclibrary.dlcmodelzoo.modelzoo_download import (
    download_huggingface_model,
    MODELOPTIONS,
)



# TESTING (passes): manually download one of the SuperAnimal models:
#model = 'superanimal_topviewmouse'
#train_dir = 'DLC_models/sa-tvm'
#download_huggingface_model(model, train_dir)

# Grab demo data: a COCO cat image (fetched at import time)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
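# NOTE: 'image' is not referenced by the pipeline below; it is presumably kept
# as convenient demo data for quick interactive testing of the helper
# functions (e.g. predict_md) from a Python shell.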

# MegaDetector model weights
MD_models_dict = {'md_v5a': "MD_models/md_v5a.0.0.pt",
                  'md_v5b': "MD_models/md_v5b.0.0.pt"}

# DLC model target directories
DLC_models_dict = {'superanimal_topviewmouse': "DLC_models/sa-tvm",
                   'superanimal_quadruped': "DLC_models/sa-q",
                   'full_human': "DLC_models/DLC_human_dancing/"}
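# Optional sanity check (commented out): dlclibrary exposes the valid model zoo
# names via MODELOPTIONS (imported above); the DLC model keys used here are
# assumed to be a subset of that list.
# assert set(DLC_models_dict) <= set(MODELOPTIONS)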


#####################################################
def predict_pipeline(img_input,
                     mega_model_input,
                     dlc_model_input_str,
                     flag_dlc_only,
                     flag_show_str_labels,
                     bbox_likelihood_th,
                     kpts_likelihood_th,
                     font_style,
                     font_size,
                     keypt_color,
                     marker_size,
                     ):
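    """Run the MegaDetector + DeepLabCut pipeline on a single input image.

    Unless flag_dlc_only is set, the selected MegaDetector model is run first
    and detections above bbox_likelihood_th are cropped. The selected DLC
    model zoo network is then located (downloaded if missing) together with
    its keypoint label map. The remaining arguments are display options for
    the keypoints and bounding boxes drawn on the output image (see viz_utils).
    """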

    if not flag_dlc_only:
        ############################################################
        # Run MegaDetector on the full image
        md_results = predict_md(img_input,
                                MD_models_dict[mega_model_input],
                                size=640)

        ############################################################
        # Obtain animal crops for bounding boxes with confidence above the threshold
        list_crops = crop_animal_detections(img_input,
                                            md_results,
                                            bbox_likelihood_th)

    ## Get DLC model and label map

    # If the model directory exists and is non-empty, reuse it (it was likely
    # downloaded by a previous run); otherwise download it from the
    # DeepLabCut model zoo on Hugging Face.
    # TODO: can we ask the user whether to re-download the DLC model
    # if a non-empty directory is found?
    path_to_DLCmodel = DLC_models_dict[dlc_model_input_str]
    if not (os.path.isdir(path_to_DLCmodel) and len(os.listdir(path_to_DLCmodel)) > 0):
        download_huggingface_model(dlc_model_input_str, path_to_DLCmodel)

    # Extract the map from keypoint label ids to keypoint names
    pose_cfg_path = os.path.join(path_to_DLCmodel, 'pose_cfg.yaml')
    with open(pose_cfg_path, "r") as stream:
        pose_cfg_dict = yaml.safe_load(stream)
    # 'all_joints' is a list of one-element lists, e.g. [[0], [1], ...]
    map_label_id_to_str = dict(zip([el[0] for el in pose_cfg_dict['all_joints']],
                                   pose_cfg_dict['all_joints_names']))
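    # Example of the resulting map (hypothetical keypoint names): with
    # all_joints = [[0], [1]] and all_joints_names = ['snout', 'tail_base'],
    # map_label_id_to_str == {0: 'snout', 1: 'tail_base'}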





#########################################################
# Define user interface and launch
inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict.keys()),
                                  list(DLC_models_dict.keys()))
outputs = gradio_outputs_for_MD_DLC()
gr_title, gr_description, examples = gradio_description_and_examples()

# Build the Gradio interface and launch the app
demo = gr.Interface(predict_pipeline, 
                    inputs=inputs,
                    outputs=outputs, 
                    title=gr_title, 
                    description=gr_description,
                    examples=examples,
                    theme="huggingface")

demo.launch(enable_queue=True, share=True)
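# Note: enable_queue/share as launch() arguments follow the older Gradio API;
# newer Gradio versions queue via demo.queue(), and share=True is typically
# ignored when the app runs on Hugging Face Spaces.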