yeq6x commited on
Commit
e18b12e
·
1 Parent(s): 730ac54
Files changed (1) hide show
  1. scripts/generate_prompt.py +3 -0
scripts/generate_prompt.py CHANGED
@@ -9,6 +9,8 @@ import numpy as np
9
  from tensorflow.keras.layers import TFSMLayer
10
  from huggingface_hub import hf_hub_download
11
 
 
 
12
  # 画像サイズの設定
13
  IMAGE_SIZE = 448
14
 
@@ -38,6 +40,7 @@ def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
38
  for file in sub_files:
39
  hf_hub_download(repo_id, file, subfolder=sub_dir, cache_dir=os.path.join(model_dir, sub_dir), force_download=True, force_filename=file)
40
 
 
41
  def load_wd14_tagger_model():
42
  """WD14タグ付けモデルをロード"""
43
  model_dir = "wd14_tagger_model"
 
9
  from tensorflow.keras.layers import TFSMLayer
10
  from huggingface_hub import hf_hub_download
11
 
12
+ import spaces
13
+
14
  # 画像サイズの設定
15
  IMAGE_SIZE = 448
16
 
 
40
  for file in sub_files:
41
  hf_hub_download(repo_id, file, subfolder=sub_dir, cache_dir=os.path.join(model_dir, sub_dir), force_download=True, force_filename=file)
42
 
43
+ @spaces.GPU
44
  def load_wd14_tagger_model():
45
  """WD14タグ付けモデルをロード"""
46
  model_dir = "wd14_tagger_model"