p1atdev commited on
Commit
25fff87
·
1 Parent(s): d457afd

chore: run on cpu

Browse files
Files changed (1) hide show
  1. app.py +51 -43
app.py CHANGED
@@ -16,6 +16,7 @@ import gradio as gr
16
  MODEL_NAME = os.environ.get("MODEL_NAME", None)
17
  assert MODEL_NAME is not None
18
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
 
19
 
20
 
21
  def fix_compiled_state_dict(state_dict: dict):
@@ -36,6 +37,7 @@ def prepare_models():
36
  model.load_state_dict(state_dict)
37
 
38
  model.eval()
 
39
  model = torch.compile(model)
40
 
41
  return model, processor
@@ -48,7 +50,7 @@ def demo():
48
  def generate_tags(
49
  text: str,
50
  auto_detect: bool,
51
- copyright_tags: str,
52
  max_new_tokens: int = 128,
53
  do_sample: bool = False,
54
  temperature: float = 0.1,
@@ -70,10 +72,10 @@ def demo():
70
 
71
  start_time = time.time()
72
  outputs = model.generate(
73
- input_ids=inputs["input_ids"].to("cuda"),
74
- attention_mask=inputs["attention_mask"].to("cuda"),
75
- encoder_input_ids=inputs["encoder_input_ids"].to("cuda"),
76
- encoder_attention_mask=inputs["encoder_attention_mask"].to("cuda"),
77
  max_new_tokens=max_new_tokens,
78
  do_sample=do_sample,
79
  temperature=temperature,
@@ -93,44 +95,50 @@ def demo():
93
  )
94
  return [deocded, f"Time elapsed: {elapsed:.2f} seconds"]
95
 
 
 
 
 
 
96
  with gr.Blocks() as ui:
97
- with gr.Row():
98
- with gr.Column():
99
- text = gr.Text(label="Text", lines=4)
100
- auto_detect = gr.Checkbox(
101
- label="Auto detect copyright tags.", value=False
102
- )
103
- copyright_tags = gr.Textbox(
104
- label="Custom tags",
105
- placeholder="Enter custom tags here. e.g.) hatsune miku",
106
- )
107
- translate_btn = gr.Button(value="Translate")
108
-
109
- with gr.Accordion(label="Advanced", open=False):
110
- max_new_tokens = gr.Number(label="Max new tokens", value=128)
111
- do_sample = gr.Checkbox(label="Do sample", value=False)
112
- temperature = gr.Slider(
113
- label="Temperature",
114
- minimum=0.1,
115
- maximum=1.0,
116
- value=0.1,
117
- step=0.1,
118
- )
119
- top_k = gr.Number(
120
- label="Top k",
121
- value=10,
122
  )
123
- top_p = gr.Slider(
124
- label="Top p",
125
- minimum=0.1,
126
- maximum=1.0,
127
- value=0.1,
128
- step=0.1,
129
  )
130
-
131
- with gr.Column():
132
- output = gr.Textbox(label="Output", lines=4, interactive=False)
133
- time_elapsed = gr.Markdown(value="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  gr.Examples(
136
  examples=[["Miku is looking at viewer.", True]],
@@ -139,9 +147,9 @@ def demo():
139
 
140
  gr.on(
141
  triggers=[
142
- text.change,
143
- auto_detect.change,
144
- copyright_tags.change,
145
  translate_btn.click,
146
  ],
147
  fn=generate_tags,
 
16
  MODEL_NAME = os.environ.get("MODEL_NAME", None)
17
  assert MODEL_NAME is not None
18
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
19
+ DEVICE = torch.device("cpu")
20
 
21
 
22
  def fix_compiled_state_dict(state_dict: dict):
 
37
  model.load_state_dict(state_dict)
38
 
39
  model.eval()
40
+ model = model.to(DEVICE)
41
  model = torch.compile(model)
42
 
43
  return model, processor
 
50
  def generate_tags(
51
  text: str,
52
  auto_detect: bool,
53
+ copyright_tags: str = "",
54
  max_new_tokens: int = 128,
55
  do_sample: bool = False,
56
  temperature: float = 0.1,
 
72
 
73
  start_time = time.time()
74
  outputs = model.generate(
75
+ input_ids=inputs["input_ids"].to(model.device),
76
+ attention_mask=inputs["attention_mask"].to(model.device),
77
+ encoder_input_ids=inputs["encoder_input_ids"].to(model.device),
78
+ encoder_attention_mask=inputs["encoder_attention_mask"].to(model.device),
79
  max_new_tokens=max_new_tokens,
80
  do_sample=do_sample,
81
  temperature=temperature,
 
95
  )
96
  return [deocded, f"Time elapsed: {elapsed:.2f} seconds"]
97
 
98
+ # warmup
99
+ print("warming up...")
100
+ print(generate_tags("Miku is looking at viewer.", True))
101
+ print("done.")
102
+
103
  with gr.Blocks() as ui:
104
+ with gr.Column():
105
+ with gr.Row():
106
+ with gr.Column():
107
+ text = gr.Text(label="Text", lines=4)
108
+ auto_detect = gr.Checkbox(
109
+ label="Auto detect copyright tags.", value=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  )
111
+ copyright_tags = gr.Textbox(
112
+ label="Custom tags",
113
+ placeholder="Enter custom tags here. e.g.) hatsune miku",
 
 
 
114
  )
115
+ translate_btn = gr.Button(value="Translate")
116
+
117
+ with gr.Accordion(label="Advanced", open=False):
118
+ max_new_tokens = gr.Number(label="Max new tokens", value=128)
119
+ do_sample = gr.Checkbox(label="Do sample", value=False)
120
+ temperature = gr.Slider(
121
+ label="Temperature",
122
+ minimum=0.1,
123
+ maximum=1.0,
124
+ value=0.1,
125
+ step=0.1,
126
+ )
127
+ top_k = gr.Number(
128
+ label="Top k",
129
+ value=10,
130
+ )
131
+ top_p = gr.Slider(
132
+ label="Top p",
133
+ minimum=0.1,
134
+ maximum=1.0,
135
+ value=0.1,
136
+ step=0.1,
137
+ )
138
+
139
+ with gr.Column():
140
+ output = gr.Textbox(label="Output", lines=4, interactive=False)
141
+ time_elapsed = gr.Markdown(value="")
142
 
143
  gr.Examples(
144
  examples=[["Miku is looking at viewer.", True]],
 
147
 
148
  gr.on(
149
  triggers=[
150
+ # text.change,
151
+ # auto_detect.change,
152
+ # copyright_tags.change,
153
  translate_btn.click,
154
  ],
155
  fn=generate_tags,