p1atdev commited on
Commit
ff91c77
·
1 Parent(s): 8e2ba3c

chore: use gpu

Browse files
Files changed (2) hide show
  1. app.py +12 -9
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import time
3
-
4
 
5
  import torch
6
  from transformers import (
@@ -16,7 +16,7 @@ import gradio as gr
16
  MODEL_NAME = os.environ.get("MODEL_NAME", None)
17
  assert MODEL_NAME is not None
18
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
19
- DEVICE = torch.device("cpu")
20
 
21
 
22
  def fix_compiled_state_dict(state_dict: dict):
@@ -24,12 +24,11 @@ def fix_compiled_state_dict(state_dict: dict):
24
 
25
 
26
  def prepare_models():
27
- config = AutoConfig.from_pretrained(
28
- MODEL_NAME, use_cache=True, trust_remote_code=True
29
- )
30
  model = AutoModelForPreTraining.from_config(
31
  config, torch_dtype=torch.bfloat16, trust_remote_code=True
32
  )
 
33
  processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
34
 
35
  state_dict = load_file(MODEL_PATH)
@@ -38,7 +37,7 @@ def prepare_models():
38
 
39
  model.eval()
40
  model = model.to(DEVICE)
41
- model = torch.compile(model)
42
 
43
  return model, processor
44
 
@@ -46,6 +45,7 @@ def prepare_models():
46
  def demo():
47
  model, processor = prepare_models()
48
 
 
49
  @torch.inference_mode()
50
  def generate_tags(
51
  text: str,
@@ -109,8 +109,8 @@ def demo():
109
  label="Auto detect copyright tags.", value=False
110
  )
111
  copyright_tags = gr.Textbox(
112
- label="Custom tags",
113
- placeholder="Enter custom tags here. e.g.) hatsune miku",
114
  )
115
  translate_btn = gr.Button(value="Translate")
116
 
@@ -124,9 +124,12 @@ def demo():
124
  value=0.1,
125
  step=0.1,
126
  )
127
- top_k = gr.Number(
128
  label="Top k",
 
 
129
  value=10,
 
130
  )
131
  top_p = gr.Slider(
132
  label="Top p",
 
1
  import os
2
  import time
3
+ import spaces
4
 
5
  import torch
6
  from transformers import (
 
16
  MODEL_NAME = os.environ.get("MODEL_NAME", None)
17
  assert MODEL_NAME is not None
18
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
19
+ DEVICE = torch.device("cuda")
20
 
21
 
22
  def fix_compiled_state_dict(state_dict: dict):
 
24
 
25
 
26
  def prepare_models():
27
+ config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 
28
  model = AutoModelForPreTraining.from_config(
29
  config, torch_dtype=torch.bfloat16, trust_remote_code=True
30
  )
31
+ model.decoder_model.use_cache = True
32
  processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
33
 
34
  state_dict = load_file(MODEL_PATH)
 
37
 
38
  model.eval()
39
  model = model.to(DEVICE)
40
+ # model = torch.compile(model)
41
 
42
  return model, processor
43
 
 
45
  def demo():
46
  model, processor = prepare_models()
47
 
48
+ @spaces.GPU(duration=5)
49
  @torch.inference_mode()
50
  def generate_tags(
51
  text: str,
 
109
  label="Auto detect copyright tags.", value=False
110
  )
111
  copyright_tags = gr.Textbox(
112
+ label="Copyright tags",
113
+ placeholder="Enter copyright tags here. e.g.) hatsune miku",
114
  )
115
  translate_btn = gr.Button(value="Translate")
116
 
 
124
  value=0.1,
125
  step=0.1,
126
  )
127
+ top_k = gr.Slider(
128
  label="Top k",
129
+ minimum=1,
130
+ maximum=100,
131
  value=10,
132
+ step=1,
133
  )
134
  top_p = gr.Slider(
135
  label="Top p",
requirements.txt CHANGED
@@ -3,3 +3,4 @@ transformers
3
  accelerate
4
  safetensors
5
  huggingface_hub
 
 
3
  accelerate
4
  safetensors
5
  huggingface_hub
6
+ spaces