File size: 7,670 Bytes
e57a32e 778c9c3 7bf67d3 778c9c3 e57a32e 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 9e506cf 7bf67d3 e57a32e 7bf67d3 778c9c3 7bf67d3 778c9c3 9e506cf 7bf67d3 778c9c3 7bf67d3 9e506cf 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 9e506cf 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 7bf67d3 778c9c3 e57a32e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import gradio as gr
from transformers import BertTokenizerFast, BertForMaskedLM
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import random
import json
model_name = "nycu-ai113-dl-final-project/bert-turtle-soup-pet-zh"
dataset_name = "nycu-ai113-dl-final-project/TurtleBench-extended-zh"
model = BertForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)
answer_judge = SentenceTransformer('thenlper/gte-base-zh')
intro="""
### 玩法介紹
遊戲一開始,我會給你一個不完整的故事,這個故事通常有很多未知的細節,你需要透過提出問題來探索更多線索。你可以問我各種問題,不過請記住,我只能回答三種答案:「正確」、「錯誤」或「不知道」。你的目標是根據這些有限的答案,逐步推理出故事的完整脈絡,從而揭開事件的真相。
這個遊戲的名稱來自於其中一個最經典的題目,海龜湯的故事。由於這類型的遊戲強調水平思考,也就是用非傳統的方式解決問題,這些遊戲就被大家統稱為「海龜湯」,有點像是可樂成為所有碳酸飲料的代名詞。
在遊戲中,你的提問會讓你逐漸接近真相。準備好發揮你的推理能力,讓我們開始吧!
"""
class PuzzleGame:
def __init__(self):
"""
初始化遊戲類別。
"""
self.template = ' 根據判定規則,此玩家的猜測為[MASK]'
self.load_stories('stories.json')
def load_stories(self, path):
with open(path, mode='r', encoding='utf-8') as f:
self.stories = json.load(f)
def get_random_puzzle(self):
"""
隨機選擇一個謎題並設定當前謎題的標題、故事和答案。
"""
puzzle = random.choice(self.stories)
# puzzle = self.stories[1]
self.title = puzzle['title']
self.surface = puzzle['surface']
self.bottom = puzzle['bottom']
def get_prompt(self):
"""
返回填入謎題故事和答案的 prompt
"""
few_shot = "1.玩家猜測:賣給他貨的人不是老闆;你的回答:是\n2.玩家猜測:他被嚇傻和零食本身有關;你的回答:不\n3.玩家猜測:零食是合法的;你的回答:不知道"
prompt = f"你是遊戲的裁判,根據<湯麵>和<湯底>判斷玩家的猜測是否正確。你的回答只能是以下三種之一:1.是:玩家的猜測與故事相符。2.否:玩家的猜測與故事不符。3.不知道:無法從<湯麵>和<湯底>推理得出結論。注意:1. 玩家只能看到<湯麵>,你的判定也只能基於<湯麵>。2. 無法從故事中推理的問題,回答\"不知道\"。<湯麵>{self.surface}<湯底>{self.bottom}<範例>{few_shot}\n請判斷以下玩家猜測:"
return prompt
# 初始化遊戲
game = PuzzleGame()
def predict_masked_token(text):
"""
使用模型預測 [MASK] 位置的 token。
:param text: 包含 [MASK] 的文本。
:return: 預測的 token。
"""
inputs = tokenizer(text, return_tensors="pt")
model.eval()
with torch.no_grad():
outputs = model(**inputs)
# 找到 [MASK] 的位置
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
if mask_token_index.numel() == 0:
raise "請確保輸入包含 [MASK]。"
# 獲取 [MASK] 位置的 logits 並預測
mask_token_logits = outputs.logits[0, mask_token_index, :]
predicted_id = torch.argmax(mask_token_logits, dim=-1)
return tokenizer.decode(predicted_id)
def restart():
"""
重新開始遊戲,初始化新的謎題。
返回故事的開頭內容。
"""
game.get_random_puzzle()
story = [{"role": "assistant", "content": game.surface}]
return story
def user(message, history):
"""
處理用戶輸入的問題或答案,並將其添加到對話歷史中。
:param message: 用戶的輸入消息。
:param history: 當前的對話歷史。
:return: 用戶的消息和更新後的歷史。
"""
history.append({"role": "user", "content": message})
return message, history
def check_question(question, history):
"""
使用 Masked Language Model 檢查用戶提問的問題。
:param question: 用戶的提問。
:param history: 當前的對話歷史。
:return: 空字符串和更新後的歷史。
"""
# 將問題與遊戲提示和模板結合
text = game.get_prompt() + question + game.template
predicted = predict_masked_token(text)
predicted_map = {
'是':'正確',
'否':'錯誤',
'不':'不知道'
}
history.append({"role": "assistant", "content": predicted_map[predicted]})
return "", history
def check_answer(answer, history):
"""
使用語義相似度檢查用戶輸入的答案是否正確。
:param answer: 用戶的答案。
:param history: 當前的對話歷史。
:return: 空字符串和更新後的歷史。
"""
sentences = [answer, game.bottom]
embeddings = answer_judge.encode(sentences)
sim = cos_sim([embeddings[0]], [embeddings[1]])
print("相似度: ", sim[0][0])
# 根據相似度生成回應
if sim[0][0] > 0.8:
response = "正確!你猜對了! 完整故事:\n" + game.bottom
elif sim[0][0] > 0.7:
response = "接近了!再試一次!"
else:
response = "錯誤!再試一次!"
history.append({"role": "assistant", "content": response})
return "", history
# 使用 Gradio 創建界面
with gr.Blocks() as demo:
# 頁面介紹
gr.Markdown(intro)
gr.Markdown("---")
# 初始化故事
story = restart()
chatbot = gr.Chatbot(type='messages', value=story, height=600)
# 問題提問功能
with gr.Tab("提出問題"):
question_input_box = gr.Textbox(
show_label=False,
placeholder="提問各種可能性的問題...",
submit_btn=True,
)
# 用戶輸入的文本框
# 1. 將用戶輸入的問題提交到 `user` 函數處理,將問題加入到歷史對話中。
# 2. 將 `user` 處理的結果(問題和更新後的歷史)傳遞給 `check_question` 函數。
# 3. `check_question` 會檢查問題並生成對應的回應,更新對話歷史。
question_input_box.submit(user, [question_input_box, chatbot], [question_input_box, chatbot]).then(
check_question, [question_input_box, chatbot], [question_input_box, chatbot]
)
# 答案輸入功能
with gr.Tab("輸入答案"):
answer_input_box = gr.Textbox(
show_label=False,
placeholder="請輸入你的答案...",
submit_btn=True,
)
# 用戶輸入的答案框
# 1. 將用戶輸入的答案提交到 `user` 函數處理,將答案加入到歷史對話中。
# 2. 將 `user` 處理的結果(答案和更新後的歷史)傳遞給 `check_answer` 函數。
# 3. `check_answer` 會檢查答案的正確性,生成對應的回應,並更新對話歷史。
answer_input_box.submit(user, [answer_input_box, chatbot], [answer_input_box, chatbot]).then(
check_answer, [answer_input_box, chatbot], [answer_input_box, chatbot]
)
# 重新開始按鈕
restart_btn = gr.ClearButton(value='重新開始新遊戲', inputs=[question_input_box, chatbot])
restart_btn.click(restart, outputs=[chatbot])
# 啟動應用
if __name__ == "__main__":
demo.launch()
|