Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,17 @@ model_dir = "models/Miwa-Keita/zenz-v2.5-medium"
|
|
11 |
snapshot_download(
|
12 |
repo_id="Miwa-Keita/zenz-v2.5-medium",
|
13 |
local_dir=model_dir,
|
14 |
-
allow_patterns=["*.
|
15 |
ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視
|
16 |
)
|
17 |
|
18 |
# モデルとトークナイザーのロード(GPT-2 アーキテクチャ)
|
19 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
20 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# 入力を調整する関数
|
23 |
def preprocess_input(user_input):
|
|
|
11 |
snapshot_download(
|
12 |
repo_id="Miwa-Keita/zenz-v2.5-medium",
|
13 |
local_dir=model_dir,
|
14 |
+
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.model"], # SafeTensorsに変更
|
15 |
ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視
|
16 |
)
|
17 |
|
18 |
# モデルとトークナイザーのロード(GPT-2 アーキテクチャ)
|
19 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
20 |
+
model = AutoModelForCausalLM.from_pretrained(
|
21 |
+
model_dir,
|
22 |
+
torch_dtype=torch.float32,
|
23 |
+
safe_serialization=True # SafeTensorsを有効化
|
24 |
+
)
|
25 |
|
26 |
# 入力を調整する関数
|
27 |
def preprocess_input(user_input):
|