|
import gradio as gr |
|
from transformers import BertTokenizerFast, BertForMaskedLM |
|
import torch |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers.util import cos_sim |
|
import random |
|
import json |
|
|
|
model_name = "nycu-ai113-dl-final-project/bert-turtle-soup-pet-zh" |
|
dataset_name = "nycu-ai113-dl-final-project/TurtleBench-extended-zh" |
|
|
|
model = BertForMaskedLM.from_pretrained(model_name) |
|
tokenizer = BertTokenizerFast.from_pretrained(model_name) |
|
|
|
answer_judge = SentenceTransformer('thenlper/gte-base-zh') |
|
|
|
intro=""" |
|
### 玩法介紹 |
|
|
|
遊戲一開始,我會給你一個不完整的故事,這個故事通常有很多未知的細節,你需要透過提出問題來探索更多線索。你可以問我各種問題,不過請記住,我只能回答三種答案:「正確」、「錯誤」或「不知道」。你的目標是根據這些有限的答案,逐步推理出故事的完整脈絡,從而揭開事件的真相。 |
|
|
|
這個遊戲的名稱來自於其中一個最經典的題目,海龜湯的故事。由於這類型的遊戲強調水平思考,也就是用非傳統的方式解決問題,這些遊戲就被大家統稱為「海龜湯」,有點像是可樂成為所有碳酸飲料的代名詞。 |
|
|
|
在遊戲中,你的提問會讓你逐漸接近真相。準備好發揮你的推理能力,讓我們開始吧! |
|
""" |
|
|
|
class PuzzleGame: |
|
def __init__(self): |
|
""" |
|
初始化遊戲類別。 |
|
""" |
|
self.template = ' 根據判定規則,此玩家的猜測為[MASK]' |
|
self.load_stories('stories.json') |
|
|
|
def load_stories(self, path): |
|
with open(path, mode='r', encoding='utf-8') as f: |
|
self.stories = json.load(f) |
|
|
|
def get_random_puzzle(self): |
|
""" |
|
隨機選擇一個謎題並設定當前謎題的標題、故事和答案。 |
|
""" |
|
puzzle = random.choice(self.stories) |
|
|
|
self.title = puzzle['title'] |
|
self.surface = puzzle['surface'] |
|
self.bottom = puzzle['bottom'] |
|
|
|
def get_prompt(self): |
|
""" |
|
返回填入謎題故事和答案的 prompt |
|
""" |
|
few_shot = "1.玩家猜測:賣給他貨的人不是老闆;你的回答:是\n2.玩家猜測:他被嚇傻和零食本身有關;你的回答:不\n3.玩家猜測:零食是合法的;你的回答:不知道" |
|
|
|
prompt = f"你是遊戲的裁判,根據<湯麵>和<湯底>判斷玩家的猜測是否正確。你的回答只能是以下三種之一:1.是:玩家的猜測與故事相符。2.否:玩家的猜測與故事不符。3.不知道:無法從<湯麵>和<湯底>推理得出結論。注意:1. 玩家只能看到<湯麵>,你的判定也只能基於<湯麵>。2. 無法從故事中推理的問題,回答\"不知道\"。<湯麵>{self.surface}<湯底>{self.bottom}<範例>{few_shot}\n請判斷以下玩家猜測:" |
|
|
|
return prompt |
|
|
|
|
|
game = PuzzleGame() |
|
|
|
def predict_masked_token(text): |
|
""" |
|
使用模型預測 [MASK] 位置的 token。 |
|
:param text: 包含 [MASK] 的文本。 |
|
:return: 預測的 token。 |
|
""" |
|
inputs = tokenizer(text, return_tensors="pt") |
|
model.eval() |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
|
|
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1] |
|
|
|
if mask_token_index.numel() == 0: |
|
raise "請確保輸入包含 [MASK]。" |
|
|
|
|
|
mask_token_logits = outputs.logits[0, mask_token_index, :] |
|
predicted_id = torch.argmax(mask_token_logits, dim=-1) |
|
return tokenizer.decode(predicted_id) |
|
|
|
def restart(): |
|
""" |
|
重新開始遊戲,初始化新的謎題。 |
|
返回故事的開頭內容。 |
|
""" |
|
game.get_random_puzzle() |
|
story = [{"role": "assistant", "content": game.surface}] |
|
return story |
|
|
|
def user(message, history): |
|
""" |
|
處理用戶輸入的問題或答案,並將其添加到對話歷史中。 |
|
:param message: 用戶的輸入消息。 |
|
:param history: 當前的對話歷史。 |
|
:return: 用戶的消息和更新後的歷史。 |
|
""" |
|
history.append({"role": "user", "content": message}) |
|
return message, history |
|
|
|
def check_question(question, history): |
|
""" |
|
使用 Masked Language Model 檢查用戶提問的問題。 |
|
:param question: 用戶的提問。 |
|
:param history: 當前的對話歷史。 |
|
:return: 空字符串和更新後的歷史。 |
|
""" |
|
|
|
text = game.get_prompt() + question + game.template |
|
predicted = predict_masked_token(text) |
|
|
|
predicted_map = { |
|
'是':'正確', |
|
'否':'錯誤', |
|
'不':'不知道' |
|
} |
|
|
|
history.append({"role": "assistant", "content": predicted_map[predicted]}) |
|
return "", history |
|
|
|
def check_answer(answer, history): |
|
""" |
|
使用語義相似度檢查用戶輸入的答案是否正確。 |
|
:param answer: 用戶的答案。 |
|
:param history: 當前的對話歷史。 |
|
:return: 空字符串和更新後的歷史。 |
|
""" |
|
sentences = [answer, game.bottom] |
|
embeddings = answer_judge.encode(sentences) |
|
|
|
sim = cos_sim([embeddings[0]], [embeddings[1]]) |
|
|
|
print("相似度: ", sim[0][0]) |
|
|
|
|
|
if sim[0][0] > 0.8: |
|
response = "正確!你猜對了! 完整故事:\n" + game.bottom |
|
elif sim[0][0] > 0.7: |
|
response = "接近了!再試一次!" |
|
else: |
|
response = "錯誤!再試一次!" |
|
|
|
history.append({"role": "assistant", "content": response}) |
|
|
|
return "", history |
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
|
gr.Markdown(intro) |
|
gr.Markdown("---") |
|
|
|
|
|
story = restart() |
|
chatbot = gr.Chatbot(type='messages', value=story, height=600) |
|
|
|
|
|
with gr.Tab("提出問題"): |
|
question_input_box = gr.Textbox( |
|
show_label=False, |
|
placeholder="提問各種可能性的問題...", |
|
submit_btn=True, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
question_input_box.submit(user, [question_input_box, chatbot], [question_input_box, chatbot]).then( |
|
check_question, [question_input_box, chatbot], [question_input_box, chatbot] |
|
) |
|
|
|
|
|
with gr.Tab("輸入答案"): |
|
answer_input_box = gr.Textbox( |
|
show_label=False, |
|
placeholder="請輸入你的答案...", |
|
submit_btn=True, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
answer_input_box.submit(user, [answer_input_box, chatbot], [answer_input_box, chatbot]).then( |
|
check_answer, [answer_input_box, chatbot], [answer_input_box, chatbot] |
|
) |
|
|
|
|
|
restart_btn = gr.ClearButton(value='重新開始新遊戲', inputs=[question_input_box, chatbot]) |
|
restart_btn.click(restart, outputs=[chatbot]) |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|