# Gradio demo of streaming generation of multiple LLM response pairs.

import logging
import time
import html
import numpy as np
import gradio as gr
import util

# gr.DataFrame currently has a bug that prevents updating cell values,
# so we render the results as a raw HTML table instead.
# https://github.com/gradio-app/gradio/issues/8160
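# Cell values in `data` are expected to be pre-rendered, already-escaped HTML
# (see highlight_prefix below), so they are inserted into the cells verbatim.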
def make_html_table(headers, data):
    rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap;">{v}</td>' for v in row) + '</tr>\n')
    return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'

def highlight_prefix(tokens, prefix_len):
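    # Highlight the first prefix_len tokens of `tokens` within the decoded text.
    # Decoding only the token prefix can yield a string that is not an exact
    # prefix of the full decoded string (e.g. when a multi-byte character or
    # grapheme cluster spans the token boundary), so the highlight length is
    # taken from the character-level longest common prefix of the two decodes.
    # util.longest_common_prefix is assumed to return the length of the shared
    # leading run of two 1-D arrays.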
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    #highlight_style = 'background-color: #FFFFAE;'
    #highlight_style = 'text-decoration: underline;'
    highlight_style = 'background-color: #90FF90;'

    return f'<span style="{highlight_style}">{prefix_html}</span>{suffix_html}'

def format_response_pair(tokens_a, tokens_b):
    # This is slightly convoluted, so as to properly handle grapheme clusters that span token boundaries.
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)

HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"

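# When True, only the tokenizer is loaded and canned responses are streamed,
# which is useful for exercising the UI without loading the model.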
DRY_RUN = False

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [[''] * 2 for _ in range(num_responses)]
        
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 & 6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            steps = 1 + max(len(response_tok_a), len(response_tok_b))

            for i in range(steps):
                time.sleep(0.1)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)

                rows[j][0] = content_a
                rows[j][1] = content_b

                yield make_html_table(HEADERS, rows)
else:
    from load import load_model
    import algorithms

    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
    algorithms.logger.setLevel(logging.INFO)

    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
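        # Build a minimal two-message chat (system + user) for the prompt.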
        chat = [
                {
                    'role': 'system',
                    'content': system_msg,
                },
                {
                    'role': 'user',
                    'content': prompt,
                },
        ]
        return chat

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [[''] * 2 for _ in range(num_responses)]
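        # Emit an empty table first so the two-column layout appears immediately.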
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            system_msg = "You are a helpful assistant."

            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

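            # The same model is used for both sides of the coupling, so only the
            # prompts differ; 'cpu' is assumed to be the device argument.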
            gen = algorithms.apoc_streaming(
                'cpu',
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
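            # The generator is assumed to yield (token_a, token_b) pairs as
            # generation progresses; a None entry means that response produced
            # no new token on this step.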
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True
                
                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))

                    rows[j][0] = content_a
                    rows[j][1] = content_b
                
                yield make_html_table(HEADERS, rows)

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
        ],
    outputs=[
        gr.HTML(),
        ],
    title='All-Prefix-Optimal Coupling',
    description='Try similar prompts to see the effect of the difference between them. '
        f'Model: `{repo_id}`.',
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],

        # This would be a good example, but Qwen2-0.5B occasionally goes off-color.
        #[48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],

        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
    ],
    # In HuggingFace Spaces, this defaults to True, which makes startup
    # take a very long time.
    cache_examples=False,
    )

demo.launch()