add chat
Browse files
app.py
CHANGED
@@ -17,6 +17,23 @@ def load():
|
|
17 |
)
|
18 |
return model, tokenizer
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def generate():
|
21 |
pass
|
22 |
|
@@ -25,6 +42,7 @@ st.header(":dna: 遺伝カウンセリング対話AI")
|
|
25 |
|
26 |
|
27 |
# 初期化
|
|
|
28 |
if "messages" not in st.session_state:
|
29 |
st.session_state["messages"] = []
|
30 |
if "options" not in st.session_state:
|
@@ -64,9 +82,9 @@ if user_prompt := st.chat_input("質問を送信してください"):
|
|
64 |
with st.chat_message("user"):
|
65 |
st.text(user_prompt)
|
66 |
st.session_state["messages"].append({"role": "user", "content": user_prompt})
|
67 |
-
|
|
|
68 |
with st.chat_message("assistant"):
|
69 |
st.text(response)
|
70 |
st.session_state["messages"].append({"role": "assistant", "content": user_prompt})
|
71 |
|
72 |
-
model, tokenizer = load()
|
|
|
17 |
)
|
18 |
return model, tokenizer
|
19 |
|
20 |
+
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
|
21 |
+
prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
|
22 |
+
roles = ["指示", "応答"]
|
23 |
+
msgs = [": \n" + user_query, ": "]
|
24 |
+
if messages:
|
25 |
+
roles.insert(1, "入力")
|
26 |
+
msgs.insert(1, ": \n" + "\n".join(message for _, message in messages))
|
27 |
+
|
28 |
+
for role, msg in zip(roles, msgs):
|
29 |
+
prompt += sep + role + msg
|
30 |
+
return prompt
|
31 |
+
|
32 |
+
def get_input_token_length(user_query, system_prompt, messages=""):
|
33 |
+
prompt = get_prompt(user_query, system_prompt, messages)
|
34 |
+
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
|
35 |
+
return input_ids.shape[-1]
|
36 |
+
|
37 |
def generate():
|
38 |
pass
|
39 |
|
|
|
42 |
|
43 |
|
44 |
# 初期化
|
45 |
+
model, tokenizer = load()
|
46 |
if "messages" not in st.session_state:
|
47 |
st.session_state["messages"] = []
|
48 |
if "options" not in st.session_state:
|
|
|
82 |
with st.chat_message("user"):
|
83 |
st.text(user_prompt)
|
84 |
st.session_state["messages"].append({"role": "user", "content": user_prompt})
|
85 |
+
token_kength = get_input_token_length(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
|
86 |
+
response = f"{token_kength}: " + get_prompt(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
|
87 |
with st.chat_message("assistant"):
|
88 |
st.text(response)
|
89 |
st.session_state["messages"].append({"role": "assistant", "content": user_prompt})
|
90 |
|
|