Fta98 committed on
Commit
20734fc
·
1 Parent(s): 18b3695
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -17,6 +17,23 @@ def load():
17
  )
18
  return model, tokenizer
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def generate():
21
  pass
22
 
@@ -25,6 +42,7 @@ st.header(":dna: 遺伝カウンセリング対話AI")
25
 
26
 
27
  # 初期化
 
28
  if "messages" not in st.session_state:
29
  st.session_state["messages"] = []
30
  if "options" not in st.session_state:
@@ -64,9 +82,9 @@ if user_prompt := st.chat_input("質問を送信してください"):
64
  with st.chat_message("user"):
65
  st.text(user_prompt)
66
  st.session_state["messages"].append({"role": "user", "content": user_prompt})
67
- response = None
 
68
  with st.chat_message("assistant"):
69
  st.text(response)
70
  st.session_state["messages"].append({"role": "assistant", "content": user_prompt})
71
 
72
- model, tokenizer = load()
 
17
  )
18
  return model, tokenizer
19
 
20
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
    """Assemble an instruction-style prompt for the model.

    Builds: system prompt + fixed Japanese instruction preamble, then
    "### 指示" (the user query), optionally "### 入力" (conversation
    history), and a trailing "### 応答" section left open for the model.

    Args:
        user_query: The current user question.
        system_prompt: System-level instruction prepended to the prompt.
        messages: Conversation history; either a list of
            {"role": ..., "content": ...} dicts (as stored in
            st.session_state["messages"]) or (role, text) pairs.
            Falsy (default "") means no history section.
        sep: Section separator inserted before each role header.

    Returns:
        The fully assembled prompt string.
    """
    prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": "]
    if messages:
        roles.insert(1, "入力")
        # BUG FIX: callers pass a list of {"role": ..., "content": ...}
        # dicts; the original `for _, message in messages` unpacked each
        # dict's KEYS, so the history was always the literal string
        # "content" instead of the message text. Accept dicts (take
        # "content") and keep supporting (role, text) pairs.
        history = "\n".join(
            m["content"] if isinstance(m, dict) else m[1] for m in messages
        )
        msgs.insert(1, ": \n" + history)

    for role, msg in zip(roles, msgs):
        prompt += sep + role + msg
    return prompt
31
+
32
def get_input_token_length(user_query, system_prompt, messages=""):
    """Return how many tokens the assembled prompt occupies.

    Builds the full prompt via get_prompt() and tokenizes it with the
    module-level `tokenizer` (loaded by load()), without adding special
    tokens, then reports the sequence length.
    """
    full_prompt = get_prompt(user_query, system_prompt, messages)
    encoded = tokenizer([full_prompt], return_tensors='np', add_special_tokens=False)
    return encoded['input_ids'].shape[-1]
36
+
37
def generate():
    """Placeholder for the response-generation step; not implemented yet."""
    return None
39
 
 
42
 
43
 
44
  # 初期化
45
+ model, tokenizer = load()
46
  if "messages" not in st.session_state:
47
  st.session_state["messages"] = []
48
  if "options" not in st.session_state:
 
82
# Render the user's message and record it in the chat history.
with st.chat_message("user"):
    st.text(user_prompt)
st.session_state["messages"].append({"role": "user", "content": user_prompt})
# FIX: renamed misspelled local `token_kength` -> `token_length`.
token_length = get_input_token_length(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
# Debug-style response: prefix the assembled prompt with its token count.
response = f"{token_length}: " + get_prompt(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
with st.chat_message("assistant"):
    st.text(response)
# FIX: store the assistant's actual reply; the original appended the
# user's prompt as the assistant message, corrupting the history.
st.session_state["messages"].append({"role": "assistant", "content": response})
90