AbstractPhil committed
Commit 785df91 · verified · 1 Parent(s): 41c0c30

Update app.py

Files changed (1)
  1. app.py +20 -28
app.py CHANGED
@@ -4,35 +4,12 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
 import torch
 import gradio as gr
 import re
-from dataclasses import dataclass
 from pathlib import Path
 import spaces

 @spaces.GPU
-@dataclass
-class SymbolicConfig:
-    repo_id: str = "AbstractPhil/bert-beatrix-2048"
-    revision: str = "main"
-    symbolic_roles: list = (
-        "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
-        "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
-        "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
-        "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
-        "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
-        "<fabric>", "<jewelry>"
-    )
-
-config = SymbolicConfig()
-tokenizer = AutoTokenizer.from_pretrained(config.repo_id, revision=config.revision)
-model = AutoModelForMaskedLM.from_pretrained(
-    config.repo_id,
-    revision=config.revision,
-    trust_remote_code=True
-).eval().cuda()
-
-MASK_TOKEN = tokenizer.mask_token or "[MASK]"
-
 def mask_and_predict(text: str, selected_roles: list[str]):
+    MASK_TOKEN = tokenizer.mask_token or "[MASK]"
     results = []
     masked_text = text
     token_ids = tokenizer.encode(text, return_tensors="pt").cuda()
@@ -64,16 +41,32 @@ def mask_and_predict(text: str, selected_roles: list[str]):
     accuracy = sum(1 for r in results if r["Match"] == "✅") / max(len(results), 1)
     return results, f"Accuracy: {accuracy:.1%}"

-def build_interface():
-    role_checkboxes = [gr.Checkbox(label=role, value=False) for role in config.symbolic_roles]
+symbolic_roles = [
+    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
+    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
+    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
+    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
+    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
+    "<fabric>", "<jewelry>"
+]
+
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+REVISION = "main"
+tokenizer = AutoTokenizer.from_pretrained(REPO_ID, revision=REVISION)
+model = AutoModelForMaskedLM.from_pretrained(
+    REPO_ID,
+    revision=REVISION,
+    trust_remote_code=True
+).eval().cuda()

+def build_interface():
     with gr.Blocks() as demo:
         gr.Markdown("## 🔎 Symbolic BERT Inference Test")
         with gr.Row():
             with gr.Column():
                 input_text = gr.Textbox(label="Symbolic Input Caption", lines=3)
                 selected_roles = gr.CheckboxGroup(
-                    choices=config.symbolic_roles,
+                    choices=symbolic_roles,
                     label="Mask these symbolic roles"
                 )
                 run_btn = gr.Button("Run Mask Inference")
@@ -85,7 +78,6 @@ def build_interface():

     return demo

-
 if __name__ == "__main__":
     demo = build_interface()
     demo.launch()
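
For readers unfamiliar with the fill-mask workflow that the refactored mask_and_predict builds on, the following is a minimal sketch of the generic transformers masked-LM pattern: encode a caption, overwrite one position with the tokenizer's mask token, and read the model's top prediction at that position. It is illustrative only and not the committed function body; the example caption and the masked position are invented, and the real app masks the symbolic-role tokens selected in the UI and runs on GPU via @spaces.GPU.

# Minimal fill-mask sketch (illustrative only; not the app's mask_and_predict).
# The caption text and the masked position below are invented for the example.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

REPO_ID = "AbstractPhil/bert-beatrix-2048"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, revision="main")
model = AutoModelForMaskedLM.from_pretrained(
    REPO_ID, revision="main", trust_remote_code=True
).eval()

text = "<subject> knight <pose> kneeling <lighting> moonlight"  # made-up caption format
token_ids = tokenizer.encode(text, return_tensors="pt")

pos = 2  # arbitrary position to mask for the demo
original_id = token_ids[0, pos].item()
token_ids[0, pos] = tokenizer.mask_token_id

with torch.no_grad():
    logits = model(input_ids=token_ids).logits

predicted_id = logits[0, pos].argmax(dim=-1).item()
print(tokenizer.decode([original_id]), "->", tokenizer.decode([predicted_id]))

The app's version additionally compares each prediction against the original token to fill the "Match" column of the results and to build the "Accuracy: …" string returned to the interface.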
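
The hunks above do not show the event wiring between the "Run Mask Inference" button and the output widgets (the lines between the button definition and the end of build_interface() fall outside the diff). For orientation only, a typical Gradio hookup for a function with mask_and_predict's signature and two return values could look like the hypothetical sketch below; the output components, their labels, and the stub function are assumptions, not the app's actual code.

# Hypothetical wiring sketch; the output widgets and the click binding are
# assumptions, since the commit's hunks do not show that part of app.py.
import gradio as gr

symbolic_roles = ["<subject>", "<pose>", "<lighting>"]  # shortened for the sketch

def mask_and_predict(text: str, selected_roles: list[str]):
    # Stand-in for the app's real function, which returns (results, accuracy string).
    return [{"Role": r, "Match": "✅"} for r in selected_roles], "Accuracy: 100.0%"

with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Symbolic Input Caption", lines=3)
    selected_roles = gr.CheckboxGroup(choices=symbolic_roles, label="Mask these symbolic roles")
    run_btn = gr.Button("Run Mask Inference")
    results_out = gr.JSON(label="Predictions")    # assumed output component
    accuracy_out = gr.Textbox(label="Accuracy")   # assumed output component

    run_btn.click(
        fn=mask_and_predict,
        inputs=[input_text, selected_roles],
        outputs=[results_out, accuracy_out],
    )

if __name__ == "__main__":
    demo.launch()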