minoD commited on
Commit
dc0a3cb
·
1 Parent(s): 80f184d

Add application file

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from peft import PeftModel
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ model_name = "rinna/japanese-gpt-neox-3.6b"
7
+ peft_name = "minoD/GOMESS"
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(
10
+ model_name,
11
+ device_map="cpu",
12
+ )
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
15
+
16
+ model = PeftModel.from_pretrained(
17
+ model,
18
+ peft_name,
19
+ device_map="cpu",
20
+ )
21
+
22
+ # プロンプトテンプレートの準備にカテゴリを追加
23
+ def generate_prompt(data_point, category=None):
24
+ category_part = f"### カテゴリ:\n{category}\n\n" if category else ""
25
+ result = f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 入力:\n{data_point['input']}\n\n### 回答:\n" if data_point["input"] else f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 回答:\n"
26
+ result = result.replace('\n', '<NL>')
27
+ return result
28
+
29
+ def generate(instruction, input=None, category=None, maxTokens=256):
30
+ # 推論
31
+ prompt = generate_prompt({'instruction':instruction, 'input':input}, category)
32
+ input_ids = tokenizer(prompt,
33
+ return_tensors="pt",
34
+ truncation=True,
35
+ add_special_tokens=False).input_ids
36
+ outputs = model.generate(
37
+ input_ids=input_ids,
38
+ max_new_tokens=maxTokens,
39
+ do_sample=True,
40
+ temperature=0.7,
41
+ top_p=0.75,
42
+ top_k=40,
43
+ no_repeat_ngram_size=2,
44
+ )
45
+ outputs = outputs[0].tolist()
46
+
47
+ # EOSトークンにヒットしたらデコード完了
48
+ if tokenizer.eos_token_id in outputs:
49
+ eos_index = outputs.index(tokenizer.eos_token_id)
50
+ decoded = tokenizer.decode(outputs[:eos_index])
51
+
52
+ # レスポンス内容のみ抽出
53
+ sentinel = "### 回答:"
54
+ sentinelLoc = decoded.find(sentinel)
55
+ if sentinelLoc >= 0:
56
+ result = decoded[sentinelLoc+len(sentinel):]
57
+ return result.replace("<NL>", "\n") # <NL>→改行
58
+ else:
59
+ return 'Warning: Expected prompt template to be emitted. Ignoring output.'
60
+ else:
61
+ return 'Warning: no <eos> detected ignoring output'
62
+
63
+ # 既存のgenerate関数を使用しますが、print文を削除し、結果を返すように変更します。
64
+ import gradio as gr
65
+
66
+ # generate関数をGradio用に調整します。入力とカテゴリは固定されます。
67
+ def generate_for_gradio(instruction):
68
+ return generate(instruction, category="ES2Q", maxTokens=200)
69
+
70
+ # Gradioインターフェースを定義します。
71
+ iface = gr.Interface(
72
+ fn=generate_for_gradio,
73
+ inputs=[
74
+ gr.Textbox(lines=10, placeholder="ESの回答を入力してください")
75
+ ],
76
+ outputs="text",
77
+ title="ESから質問を生成テスト",
78
+ description="エントリーシートから面接官が言いそうな質問を生成します。(精度:悪)"
79
+ )
80
+
81
+ iface.launch()