File size: 7,670 Bytes
e57a32e
778c9c3
7bf67d3
 
778c9c3
 
 
e57a32e
778c9c3
 
7bf67d3
778c9c3
 
7bf67d3
778c9c3
7bf67d3
 
 
 
9e506cf
7bf67d3
 
 
 
e57a32e
7bf67d3
 
 
 
 
 
778c9c3
 
 
 
 
 
7bf67d3
 
 
 
 
778c9c3
9e506cf
7bf67d3
778c9c3
 
7bf67d3
 
 
 
 
9e506cf
 
 
7bf67d3
778c9c3
7bf67d3
778c9c3
 
7bf67d3
778c9c3
 
 
 
 
 
 
 
 
 
7bf67d3
778c9c3
 
7bf67d3
778c9c3
 
7bf67d3
778c9c3
 
 
 
7bf67d3
 
778c9c3
 
 
 
7bf67d3
778c9c3
7bf67d3
 
 
778c9c3
 
 
 
 
 
7bf67d3
 
 
 
778c9c3
 
 
 
 
 
 
 
 
 
 
9e506cf
 
778c9c3
 
 
 
7bf67d3
 
 
778c9c3
 
 
 
 
 
 
7bf67d3
 
778c9c3
7bf67d3
778c9c3
7bf67d3
778c9c3
 
 
 
7bf67d3
 
 
 
 
 
 
 
778c9c3
7bf67d3
778c9c3
7bf67d3
 
778c9c3
 
7bf67d3
 
 
778c9c3
7bf67d3
 
 
 
 
 
 
778c9c3
 
 
 
7bf67d3
 
 
 
778c9c3
7bf67d3
 
 
 
 
 
 
778c9c3
 
 
 
7bf67d3
 
 
 
778c9c3
7bf67d3
 
 
778c9c3
e57a32e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import gradio as gr
from transformers import BertTokenizerFast, BertForMaskedLM
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import random
import json

model_name = "nycu-ai113-dl-final-project/bert-turtle-soup-pet-zh"
dataset_name = "nycu-ai113-dl-final-project/TurtleBench-extended-zh"

model = BertForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

answer_judge = SentenceTransformer('thenlper/gte-base-zh')

intro="""
### 玩法介紹

遊戲一開始,我會給你一個不完整的故事,這個故事通常有很多未知的細節,你需要透過提出問題來探索更多線索。你可以問我各種問題,不過請記住,我只能回答三種答案:「正確」、「錯誤」或「不知道」。你的目標是根據這些有限的答案,逐步推理出故事的完整脈絡,從而揭開事件的真相。

這個遊戲的名稱來自於其中一個最經典的題目,海龜湯的故事。由於這類型的遊戲強調水平思考,也就是用非傳統的方式解決問題,這些遊戲就被大家統稱為「海龜湯」,有點像是可樂成為所有碳酸飲料的代名詞。

在遊戲中,你的提問會讓你逐漸接近真相。準備好發揮你的推理能力,讓我們開始吧!
"""

class PuzzleGame:
    def __init__(self):
        """
        初始化遊戲類別。
        """
        self.template = ' 根據判定規則,此玩家的猜測為[MASK]'
        self.load_stories('stories.json')

    def load_stories(self, path):
        with open(path, mode='r', encoding='utf-8') as f:
            self.stories = json.load(f)

    def get_random_puzzle(self):
        """
        隨機選擇一個謎題並設定當前謎題的標題、故事和答案。
        """
        puzzle = random.choice(self.stories)
        # puzzle = self.stories[1]
        self.title = puzzle['title']
        self.surface = puzzle['surface']
        self.bottom = puzzle['bottom']

    def get_prompt(self):
        """
        返回填入謎題故事和答案的 prompt
        """
        few_shot = "1.玩家猜測:賣給他貨的人不是老闆;你的回答:是\n2.玩家猜測:他被嚇傻和零食本身有關;你的回答:不\n3.玩家猜測:零食是合法的;你的回答:不知道"

        prompt = f"你是遊戲的裁判,根據<湯麵>和<湯底>判斷玩家的猜測是否正確。你的回答只能是以下三種之一:1.是:玩家的猜測與故事相符。2.否:玩家的猜測與故事不符。3.不知道:無法從<湯麵>和<湯底>推理得出結論。注意:1. 玩家只能看到<湯麵>,你的判定也只能基於<湯麵>。2. 無法從故事中推理的問題,回答\"不知道\"。<湯麵>{self.surface}<湯底>{self.bottom}<範例>{few_shot}\n請判斷以下玩家猜測:"

        return prompt

# 初始化遊戲
game = PuzzleGame()

def predict_masked_token(text):
    """
    使用模型預測 [MASK] 位置的 token。
    :param text: 包含 [MASK] 的文本。
    :return: 預測的 token。
    """
    inputs = tokenizer(text, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # 找到 [MASK] 的位置
    mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]

    if mask_token_index.numel() == 0:
        raise "請確保輸入包含 [MASK]。"

    # 獲取 [MASK] 位置的 logits 並預測
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    predicted_id = torch.argmax(mask_token_logits, dim=-1)
    return tokenizer.decode(predicted_id)

def restart():
    """
    重新開始遊戲,初始化新的謎題。
    返回故事的開頭內容。
    """
    game.get_random_puzzle()
    story = [{"role": "assistant", "content": game.surface}]
    return story

def user(message, history):
    """
    處理用戶輸入的問題或答案,並將其添加到對話歷史中。
    :param message: 用戶的輸入消息。
    :param history: 當前的對話歷史。
    :return: 用戶的消息和更新後的歷史。
    """
    history.append({"role": "user", "content": message})
    return message, history

def check_question(question, history):
    """
    使用 Masked Language Model 檢查用戶提問的問題。
    :param question: 用戶的提問。
    :param history: 當前的對話歷史。
    :return: 空字符串和更新後的歷史。
    """
    # 將問題與遊戲提示和模板結合
    text = game.get_prompt() + question + game.template
    predicted = predict_masked_token(text)

    predicted_map = {
        '是':'正確',
        '否':'錯誤',
        '不':'不知道'
    }

    history.append({"role": "assistant", "content": predicted_map[predicted]})
    return "", history

def check_answer(answer, history):
    """
    使用語義相似度檢查用戶輸入的答案是否正確。
    :param answer: 用戶的答案。
    :param history: 當前的對話歷史。
    :return: 空字符串和更新後的歷史。
    """
    sentences = [answer, game.bottom]
    embeddings = answer_judge.encode(sentences)

    sim = cos_sim([embeddings[0]], [embeddings[1]])

    print("相似度: ", sim[0][0])

    # 根據相似度生成回應
    if sim[0][0] > 0.8:
        response = "正確!你猜對了! 完整故事:\n" + game.bottom
    elif sim[0][0] > 0.7:
        response = "接近了!再試一次!"
    else:
        response = "錯誤!再試一次!"
    
    history.append({"role": "assistant", "content": response})

    return "", history

# 使用 Gradio 創建界面
with gr.Blocks() as demo:
    # 頁面介紹
    gr.Markdown(intro)
    gr.Markdown("---")

    # 初始化故事
    story = restart()
    chatbot = gr.Chatbot(type='messages', value=story, height=600)
    
    # 問題提問功能
    with gr.Tab("提出問題"):
        question_input_box = gr.Textbox(
            show_label=False,
            placeholder="提問各種可能性的問題...",
            submit_btn=True,
        )

        # 用戶輸入的文本框
        # 1. 將用戶輸入的問題提交到 `user` 函數處理,將問題加入到歷史對話中。
        # 2. 將 `user` 處理的結果(問題和更新後的歷史)傳遞給 `check_question` 函數。
        # 3. `check_question` 會檢查問題並生成對應的回應,更新對話歷史。
        question_input_box.submit(user, [question_input_box, chatbot], [question_input_box, chatbot]).then(
            check_question, [question_input_box, chatbot], [question_input_box, chatbot]
        )

    # 答案輸入功能
    with gr.Tab("輸入答案"):
        answer_input_box = gr.Textbox(
            show_label=False,
            placeholder="請輸入你的答案...",
            submit_btn=True,
        )

        # 用戶輸入的答案框
        # 1. 將用戶輸入的答案提交到 `user` 函數處理,將答案加入到歷史對話中。
        # 2. 將 `user` 處理的結果(答案和更新後的歷史)傳遞給 `check_answer` 函數。
        # 3. `check_answer` 會檢查答案的正確性,生成對應的回應,並更新對話歷史。
        answer_input_box.submit(user, [answer_input_box, chatbot], [answer_input_box, chatbot]).then(
            check_answer, [answer_input_box, chatbot], [answer_input_box, chatbot]
        )

    # 重新開始按鈕
    restart_btn = gr.ClearButton(value='重新開始新遊戲', inputs=[question_input_box, chatbot])
    restart_btn.click(restart, outputs=[chatbot])

# 啟動應用
if __name__ == "__main__":
    demo.launch()