# Hugging Face file-viewer header: last commit fad10f4 by masharpe,
# "Move style to CSS string" (file size 6.8 kB).
# Gradio demo of streaming generation of multiple LLM response pairs.
# NOTE: keep `spaces` first — Hugging Face ZeroGPU expects to be imported
# before torch sets up CUDA.
import spaces
import html
import logging
import os
import time

import accelerate
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import transformers

import util
# Print the installed versions of the main dependencies, one `name==version`
# line each, so `requirements.txt` can be pinned to match this environment.
print('Dependency versions:')
print(
    *(
        f'{module.__name__}=={module.__version__}'
        for module in (huggingface_hub, np, torch, transformers, accelerate)
    ),
    sep='\n',
)
print()
# Configure root logging at INFO level (a no-op if the root logger already has
# handlers) and create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s:%(name)s: %(message)s',
)
logger = logging.getLogger(__name__)
# gr.DataFrame currently has a bug when its values are updated
# (https://github.com/gradio-app/gradio/issues/8160), so the responses are
# rendered as a raw HTML table instead, styled by the CSS below:
# - `.response-table`: full width with a fixed layout, so the two columns
#   keep equal (50%) widths regardless of content.
# - `td`: monospace text, top-left aligned, with `pre-wrap` so newlines and
#   runs of spaces in responses are preserved while long lines still wrap.
# - `.highlight`: light-green background for highlighted spans.
css = '''
.response-table {
width: 100%;
table-layout: fixed;
}
.response-table th, .response-table td {
width: 50%;
}
.response-table td {
font-family: monospace;
white-space: pre-wrap;
text-align: left;
vertical-align: top;
}
.highlight {
background-color: #90FF90;
}
'''
def make_html_table(headers, data):
    """Render `headers` and `data` as an HTML table with class `response-table`.

    Header labels and cell values are interpolated verbatim (no escaping), so
    cells may contain pre-rendered HTML fragments.

    Args:
        headers: Iterable of header labels, emitted as the first row of `<th>`.
        data: Iterable of rows; each row is an iterable of cell values (`<td>`).

    Returns:
        The table markup as a string, with each `<tr>` on its own line.
    """
    def render_row(tag, cells):
        # One `<tr>` line whose cells are all wrapped in the given tag.
        inner = ''.join(f'<{tag}>{cell}</{tag}>' for cell in cells)
        return f'<tr>{inner}</tr>\n'

    lines = [render_row('th', headers)]
    lines.extend(render_row('td', row) for row in data)
    return '<table class="response-table">\n' + ''.join(lines) + '</table>\n'
def highlight_prefix(tokens, prefix_len):
    """Decode `tokens` to HTML, highlighting the text contributed by the first `prefix_len` tokens.

    The highlighted span is the longest common character prefix of the full
    decoding and the decoding of `tokens[:prefix_len]`. If the prefix decodes
    differently from the start of the full text (presumably when a character
    straddles the token boundary), the highlight stops before the first
    differing character.

    Uses the module-level `tokenizer`, which must be defined before this is
    called.

    Args:
        tokens: Token ids accepted by `tokenizer.decode`; must support slicing.
        prefix_len: Number of leading tokens whose text should be highlighted.

    Returns:
        HTML-escaped text, with the highlighted part wrapped in
        `<span class="highlight">`.
    """
    full_text = tokenizer.decode(tokens, skip_special_tokens=True)
    head_text = tokenizer.decode(tokens[:prefix_len], skip_special_tokens=True)
    # Compare per character; the result is used as the split index into full_text.
    split_at = util.longest_common_prefix(
        np.array(list(full_text)),
        np.array(list(head_text)),
    )
    marked = html.escape(full_text[:split_at])
    unmarked = html.escape(full_text[split_at:])
    return f'<span class="highlight">{marked}</span>{unmarked}'
def format_response_pair(tokens_a, tokens_b):
    """Render two token sequences as HTML, highlighting the prefix they share.

    Args:
        tokens_a: Token ids of the left response.
        tokens_b: Token ids of the right response.

    Returns:
        Tuple `(html_a, html_b)` as produced by `highlight_prefix`.
    """
    # The shared prefix is measured in tokens and only then mapped back onto
    # characters (inside highlight_prefix). This two-step approach correctly
    # handles grapheme clusters that span token boundaries.
    shared_tokens = util.longest_common_prefix(tokens_a, tokens_b)
    html_a = highlight_prefix(tokens_a, shared_tokens)
    html_b = highlight_prefix(tokens_b, shared_tokens)
    return html_a, html_b
HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"
DRY_RUN = True
if DRY_RUN:
    # Dry-run mode: load only the tokenizer and replay canned responses, so the
    # UI can be exercised without loading the model.
    from load import load_tokenizer
    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        """Stream a table of canned response pairs, revealing one token per step.

        Every row replays the same two hard-coded responses. `max_tokens` and
        both prompts are accepted only to match the Gradio inputs; they are
        ignored here.

        Args:
            max_tokens: Ignored in dry-run mode.
            num_responses: Number of table rows. Cast to int because gr.Slider
                values are documented as floats.
            prompt_x: Ignored in dry-run mode.
            prompt_y: Ignored in dry-run mode.

        Yields:
            The complete table HTML: first with empty rows, then after every step.
        """
        logger.info('Starting generation...')
        generation_start = time.perf_counter()
        num_responses = int(num_responses)
        rows = [['', ''] for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)
        # The canned responses never change, so tokenize them once, not per row.
        response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
        response_raw_b = 'Sure!\n\n1 2 3 4 5 &\n\n\n\n6.'
        response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
        response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]
        # Step i shows the first i tokens of each response. The extra step makes
        # the last frame show both responses in full.
        steps = 1 + max(len(response_tok_a), len(response_tok_b))
        for row in rows:
            for i in range(steps):
                time.sleep(0.01)  # Pace the stream (presumably to mimic generation latency).
                row[0], row[1] = format_response_pair(response_tok_a[:i], response_tok_b[:i])
                yield make_html_table(HEADERS, rows)
        generation_end = time.perf_counter()
        logger.info('Generation took %.3f s', generation_end - generation_start)
else:
    # Full mode: load the model and stream coupled response pairs produced by
    # `algorithms.apoc_streaming`.
    from load import load_model
    import algorithms
    # algorithms.logger.setLevel(logging.DEBUG)
    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
        """Return a chat message list: the system message followed by the user prompt."""
        return [
            {'role': 'system', 'content': system_msg},
            {'role': 'user', 'content': prompt},
        ]

    @spaces.GPU
    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        """Stream a table of response pairs for `prompt_x` (left) and `prompt_y` (right).

        Each row is a separate call to `algorithms.apoc_streaming`. The row is
        re-rendered, with the shared prefix highlighted, whenever either side
        produces a token.

        Args:
            max_tokens: Generation length limit forwarded to apoc_streaming. Cast
                to int because gr.Slider values are documented as floats.
            num_responses: Number of rows (response pairs) to generate; cast to int.
            prompt_x: User prompt for the left column.
            prompt_y: User prompt for the right column.

        Yields:
            The complete table HTML: first with empty rows, then after every update.
        """
        logger.info('Starting generation...')
        generation_start = time.perf_counter()
        # Is this necessary with ZeroGPU?
        torch.use_deterministic_algorithms(True)
        max_tokens = int(max_tokens)
        num_responses = int(num_responses)
        rows = [['', ''] for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)
        system_msg = "You are a helpful assistant."
        for row in rows:
            # Build fresh chat lists for every row, as the original did, in case
            # apoc_streaming mutates its inputs.
            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)
            # NOTE(review): `model` is passed twice — presumably one model per
            # side; confirm against apoc_streaming's signature.
            gen = algorithms.apoc_streaming(
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
            for token_a, token_b in gen:
                # A None token means that side produced nothing this step. Only
                # re-render when at least one side changed.
                if token_a is None and token_b is None:
                    continue
                if token_a is not None:
                    response_a_L.append(token_a)
                if token_b is not None:
                    response_b_L.append(token_b)
                row[0], row[1] = format_response_pair(np.array(response_a_L), np.array(response_b_L))
                yield make_html_table(HEADERS, rows)
        generation_end = time.perf_counter()
        logger.info('Generation took %.3f s', generation_end - generation_start)
# Gradio UI: generation settings plus two prompts in; one streamed HTML table
# of response pairs out.
demo = gr.Interface(
    title='All-Prefix-Optimal Coupling',
    description=f'Try similar prompts to see the effect of the difference between them. Model: `{repo_id}`.',
    fn=fn,
    inputs=[
        gr.Slider(minimum=1, maximum=512, value=48, label='Max Tokens'),
        gr.Slider(minimum=1, maximum=16, value=8, step=1, label='Num Responses'),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
    ],
    outputs=[gr.HTML()],
    css=css,
    # Each example row: [max tokens, num responses, left prompt, right prompt].
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
        # Would be a good example, but Qwen2-0.5B occasionally goes off-color:
        # [48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],
        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
        [48, 8, "1 3 5. What number is next?", "4 5 6. What number is next?"],
    ],
    # On Hugging Face Spaces this defaults to True, which makes startup very slow.
    cache_examples=False,
)
demo.launch()