hysts HF staff commited on
Commit
7479a3a
·
1 Parent(s): ed7463d
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +29 -31
  3. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -18,13 +18,9 @@ sys.path.insert(0, 'bizarre-pose-estimator')
18
 
19
  from _util.twodee_v0 import I as ImageWrapper
20
 
21
- TITLE = 'ShuhongChen/bizarre-pose-estimator (tagger)'
22
- DESCRIPTION = 'This is an unofficial demo for https://github.com/ShuhongChen/bizarre-pose-estimator.'
23
 
24
- HF_TOKEN = os.getenv('HF_TOKEN')
25
- MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
26
- MODEL_FILENAME = 'tagger.pth'
27
- LABEL_FILENAME = 'tags.txt'
28
 
29
 
30
  def load_sample_image_paths() -> list[pathlib.Path]:
@@ -33,17 +29,14 @@ def load_sample_image_paths() -> list[pathlib.Path]:
33
  dataset_repo = 'hysts/sample-images-TADNE'
34
  path = huggingface_hub.hf_hub_download(dataset_repo,
35
  'images.tar.gz',
36
- repo_type='dataset',
37
- use_auth_token=HF_TOKEN)
38
  with tarfile.open(path) as f:
39
  f.extractall()
40
  return sorted(image_dir.glob('*'))
41
 
42
 
43
  def load_model(device: torch.device) -> torch.nn.Module:
44
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
45
- MODEL_FILENAME,
46
- use_auth_token=HF_TOKEN)
47
  state_dict = torch.load(path)
48
  model = torchvision.models.resnet50(num_classes=1062)
49
  model.load_state_dict(state_dict)
@@ -53,9 +46,7 @@ def load_model(device: torch.device) -> torch.nn.Module:
53
 
54
 
55
  def load_labels() -> list[str]:
56
- label_path = huggingface_hub.hf_hub_download(MODEL_REPO,
57
- LABEL_FILENAME,
58
- use_auth_token=HF_TOKEN)
59
  with open(label_path) as f:
60
  labels = [line.strip() for line in f.readlines()]
61
  return labels
@@ -88,20 +79,27 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
88
  model = load_model(device)
89
  labels = load_labels()
90
 
91
- func = functools.partial(predict, device=device, model=model, labels=labels)
92
-
93
- gr.Interface(
94
- fn=func,
95
- inputs=[
96
- gr.Image(label='Input', type='pil'),
97
- gr.Slider(label='Score Threshold',
98
- minimum=0,
99
- maximum=1,
100
- step=0.05,
101
- value=0.5),
102
- ],
103
- outputs=gr.Label(label='Output'),
104
- examples=examples,
105
- title=TITLE,
106
- description=DESCRIPTION,
107
- ).queue().launch(show_api=False)
 
 
 
 
 
 
 
 
18
 
19
  from _util.twodee_v0 import I as ImageWrapper
20
 
21
+ DESCRIPTION = '# [ShuhongChen/bizarre-pose-estimator (tagger)](https://github.com/ShuhongChen/bizarre-pose-estimator)'
 
22
 
23
+ MODEL_REPO = 'public-data/bizarre-pose-estimator-models'
 
 
 
24
 
25
 
26
  def load_sample_image_paths() -> list[pathlib.Path]:
 
29
  dataset_repo = 'hysts/sample-images-TADNE'
30
  path = huggingface_hub.hf_hub_download(dataset_repo,
31
  'images.tar.gz',
32
+ repo_type='dataset')
 
33
  with tarfile.open(path) as f:
34
  f.extractall()
35
  return sorted(image_dir.glob('*'))
36
 
37
 
38
  def load_model(device: torch.device) -> torch.nn.Module:
39
+ path = huggingface_hub.hf_hub_download(MODEL_REPO, 'tagger.pth')
 
 
40
  state_dict = torch.load(path)
41
  model = torchvision.models.resnet50(num_classes=1062)
42
  model.load_state_dict(state_dict)
 
46
 
47
 
48
  def load_labels() -> list[str]:
49
+ label_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'tags.txt')
 
 
50
  with open(label_path) as f:
51
  labels = [line.strip() for line in f.readlines()]
52
  return labels
 
79
  model = load_model(device)
80
  labels = load_labels()
81
 
82
+ fn = functools.partial(predict, device=device, model=model, labels=labels)
83
+
84
+ with gr.Blocks(css='style.css') as demo:
85
+ gr.Markdown(DESCRIPTION)
86
+ with gr.Row():
87
+ with gr.Column():
88
+ image = gr.Image(label='Input', type='pil')
89
+ threshold = gr.Slider(label='Score Threshold',
90
+ minimum=0,
91
+ maximum=1,
92
+ step=0.05,
93
+ value=0.5)
94
+ run_button = gr.Button('Run')
95
+ with gr.Column():
96
+ result = gr.Label(label='Output')
97
+
98
+ inputs = [image, threshold]
99
+ gr.Examples(examples=examples,
100
+ inputs=inputs,
101
+ outputs=result,
102
+ fn=fn,
103
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
104
+ run_button.click(fn=fn, inputs=inputs, outputs=result, api_name='predict')
105
+ demo.queue(max_size=15).launch()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }