Miwa-Keita commited on
Commit
0594435
·
verified ·
1 Parent(s): 1d4dce2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -14,16 +14,16 @@ snapshot_download(
14
  ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視
15
  )
16
 
17
- # モデルのロード
18
- model_path = os.path.join(model_dir, "pytorch_model.bin") # 必要なモデルファイル
19
- model = torch.load(model_path, map_location="cpu")
20
 
21
  # 入力を調整する関数
22
  def preprocess_input(user_input):
23
  prefix = "\uEE00" # 前に付与する文字列
24
  suffix = "\uEE01" # 後ろに付与する文字列
25
  processed_input = prefix + user_input + suffix
26
- return model(processed_input)
27
 
28
  # 出力を調整する関数
29
  def postprocess_output(model_output):
@@ -33,12 +33,28 @@ def postprocess_output(model_output):
33
  return model_output.split(suffix)[1]
34
  return model_output
35
 
36
- # インターフェースを定義
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  iface = gr.Interface(
38
- fn=lambda x: postprocess_output(preprocess_input(x)),
39
  inputs=gr.Textbox(label="変換する文字列(カタカナ)"),
40
  outputs=gr.Textbox(label="変換結果"),
41
- title="ニューラルかな漢字変換モデルzenz-v1のデモ",
42
  description="変換したい文字列をカタカナを入力してください"
43
  )
44
 
 
14
  ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視
15
  )
16
 
17
+ # モデルとトークナイザーのロード(GPT-2 アーキテクチャ)
18
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
19
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32)
20
 
21
  # 入力を調整する関数
22
  def preprocess_input(user_input):
23
  prefix = "\uEE00" # 前に付与する文字列
24
  suffix = "\uEE01" # 後ろに付与する文字列
25
  processed_input = prefix + user_input + suffix
26
+ return processed_input
27
 
28
  # 出力を調整する関数
29
  def postprocess_output(model_output):
 
33
  return model_output.split(suffix)[1]
34
  return model_output
35
 
36
+ # 変換関数
37
+ def generate_text(user_input):
38
+ processed_input = preprocess_input(user_input)
39
+
40
+ # テキストをトークン化
41
+ inputs = tokenizer(processed_input, return_tensors="pt")
42
+
43
+ # モデルで生成
44
+ outputs = model.generate(**inputs, max_length=100)
45
+
46
+ # 出力のデコード
47
+ decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
+
49
+ # 出力の整形
50
+ return postprocess_output(decoded_output)
51
+
52
+ # Gradio インターフェース
53
  iface = gr.Interface(
54
+ fn=generate_text,
55
  inputs=gr.Textbox(label="変換する文字列(カタカナ)"),
56
  outputs=gr.Textbox(label="変換結果"),
57
+ title="ニューラルかな漢字変換モデル zenZ-v1 のデモ",
58
  description="変換したい文字列をカタカナを入力してください"
59
  )
60