Wei-Hsu-AI's picture
fix: random choice data and stories.json
9e506cf
import gradio as gr
from transformers import BertTokenizerFast, BertForMaskedLM
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import random
import json
model_name = "nycu-ai113-dl-final-project/bert-turtle-soup-pet-zh"
dataset_name = "nycu-ai113-dl-final-project/TurtleBench-extended-zh"
model = BertForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)
answer_judge = SentenceTransformer('thenlper/gte-base-zh')
intro="""
### 玩法介紹
遊戲一開始,我會給你一個不完整的故事,這個故事通常有很多未知的細節,你需要透過提出問題來探索更多線索。你可以問我各種問題,不過請記住,我只能回答三種答案:「正確」、「錯誤」或「不知道」。你的目標是根據這些有限的答案,逐步推理出故事的完整脈絡,從而揭開事件的真相。
這個遊戲的名稱來自於其中一個最經典的題目,海龜湯的故事。由於這類型的遊戲強調水平思考,也就是用非傳統的方式解決問題,這些遊戲就被大家統稱為「海龜湯」,有點像是可樂成為所有碳酸飲料的代名詞。
在遊戲中,你的提問會讓你逐漸接近真相。準備好發揮你的推理能力,讓我們開始吧!
"""
class PuzzleGame:
def __init__(self):
"""
初始化遊戲類別。
"""
self.template = ' 根據判定規則,此玩家的猜測為[MASK]'
self.load_stories('stories.json')
def load_stories(self, path):
with open(path, mode='r', encoding='utf-8') as f:
self.stories = json.load(f)
def get_random_puzzle(self):
"""
隨機選擇一個謎題並設定當前謎題的標題、故事和答案。
"""
puzzle = random.choice(self.stories)
# puzzle = self.stories[1]
self.title = puzzle['title']
self.surface = puzzle['surface']
self.bottom = puzzle['bottom']
def get_prompt(self):
"""
返回填入謎題故事和答案的 prompt
"""
few_shot = "1.玩家猜測:賣給他貨的人不是老闆;你的回答:是\n2.玩家猜測:他被嚇傻和零食本身有關;你的回答:不\n3.玩家猜測:零食是合法的;你的回答:不知道"
prompt = f"你是遊戲的裁判,根據<湯麵>和<湯底>判斷玩家的猜測是否正確。你的回答只能是以下三種之一:1.是:玩家的猜測與故事相符。2.否:玩家的猜測與故事不符。3.不知道:無法從<湯麵>和<湯底>推理得出結論。注意:1. 玩家只能看到<湯麵>,你的判定也只能基於<湯麵>。2. 無法從故事中推理的問題,回答\"不知道\"。<湯麵>{self.surface}<湯底>{self.bottom}<範例>{few_shot}\n請判斷以下玩家猜測:"
return prompt
# 初始化遊戲
game = PuzzleGame()
def predict_masked_token(text):
"""
使用模型預測 [MASK] 位置的 token。
:param text: 包含 [MASK] 的文本。
:return: 預測的 token。
"""
inputs = tokenizer(text, return_tensors="pt")
model.eval()
with torch.no_grad():
outputs = model(**inputs)
# 找到 [MASK] 的位置
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
if mask_token_index.numel() == 0:
raise "請確保輸入包含 [MASK]。"
# 獲取 [MASK] 位置的 logits 並預測
mask_token_logits = outputs.logits[0, mask_token_index, :]
predicted_id = torch.argmax(mask_token_logits, dim=-1)
return tokenizer.decode(predicted_id)
def restart():
"""
重新開始遊戲,初始化新的謎題。
返回故事的開頭內容。
"""
game.get_random_puzzle()
story = [{"role": "assistant", "content": game.surface}]
return story
def user(message, history):
"""
處理用戶輸入的問題或答案,並將其添加到對話歷史中。
:param message: 用戶的輸入消息。
:param history: 當前的對話歷史。
:return: 用戶的消息和更新後的歷史。
"""
history.append({"role": "user", "content": message})
return message, history
def check_question(question, history):
"""
使用 Masked Language Model 檢查用戶提問的問題。
:param question: 用戶的提問。
:param history: 當前的對話歷史。
:return: 空字符串和更新後的歷史。
"""
# 將問題與遊戲提示和模板結合
text = game.get_prompt() + question + game.template
predicted = predict_masked_token(text)
predicted_map = {
'是':'正確',
'否':'錯誤',
'不':'不知道'
}
history.append({"role": "assistant", "content": predicted_map[predicted]})
return "", history
def check_answer(answer, history):
"""
使用語義相似度檢查用戶輸入的答案是否正確。
:param answer: 用戶的答案。
:param history: 當前的對話歷史。
:return: 空字符串和更新後的歷史。
"""
sentences = [answer, game.bottom]
embeddings = answer_judge.encode(sentences)
sim = cos_sim([embeddings[0]], [embeddings[1]])
print("相似度: ", sim[0][0])
# 根據相似度生成回應
if sim[0][0] > 0.8:
response = "正確!你猜對了! 完整故事:\n" + game.bottom
elif sim[0][0] > 0.7:
response = "接近了!再試一次!"
else:
response = "錯誤!再試一次!"
history.append({"role": "assistant", "content": response})
return "", history
# 使用 Gradio 創建界面
with gr.Blocks() as demo:
# 頁面介紹
gr.Markdown(intro)
gr.Markdown("---")
# 初始化故事
story = restart()
chatbot = gr.Chatbot(type='messages', value=story, height=600)
# 問題提問功能
with gr.Tab("提出問題"):
question_input_box = gr.Textbox(
show_label=False,
placeholder="提問各種可能性的問題...",
submit_btn=True,
)
# 用戶輸入的文本框
# 1. 將用戶輸入的問題提交到 `user` 函數處理,將問題加入到歷史對話中。
# 2. 將 `user` 處理的結果(問題和更新後的歷史)傳遞給 `check_question` 函數。
# 3. `check_question` 會檢查問題並生成對應的回應,更新對話歷史。
question_input_box.submit(user, [question_input_box, chatbot], [question_input_box, chatbot]).then(
check_question, [question_input_box, chatbot], [question_input_box, chatbot]
)
# 答案輸入功能
with gr.Tab("輸入答案"):
answer_input_box = gr.Textbox(
show_label=False,
placeholder="請輸入你的答案...",
submit_btn=True,
)
# 用戶輸入的答案框
# 1. 將用戶輸入的答案提交到 `user` 函數處理,將答案加入到歷史對話中。
# 2. 將 `user` 處理的結果(答案和更新後的歷史)傳遞給 `check_answer` 函數。
# 3. `check_answer` 會檢查答案的正確性,生成對應的回應,並更新對話歷史。
answer_input_box.submit(user, [answer_input_box, chatbot], [answer_input_box, chatbot]).then(
check_answer, [answer_input_box, chatbot], [answer_input_box, chatbot]
)
# 重新開始按鈕
restart_btn = gr.ClearButton(value='重新開始新遊戲', inputs=[question_input_box, chatbot])
restart_btn.click(restart, outputs=[chatbot])
# 啟動應用
if __name__ == "__main__":
demo.launch()