|
import ctranslate2 |
|
from transformers import AutoTokenizer |
|
|
|
import threading |
|
import gradio as gr |
|
|
|
from typing import Optional |
|
from queue import Queue |
|
|
|
|
|
|
|
|
|
class TokenIteratorStreamer:
    """Blocking iterator over token ids fed from another thread.

    The generation thread pushes token ids with ``put``; the consumer
    iterates this object and receives tokens until ``end_token_id``
    arrives, which terminates the iteration.
    """

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.timeout = timeout
        # Unbounded FIFO shared between producer and consumer threads.
        self.queue = Queue()

    def put(self, token: int):
        """Enqueue one generated token id (producer side)."""
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the producer supplies a token (or timeout elapses).
        next_token = self.queue.get(timeout=self.timeout)
        if next_token == self.end_token_id:
            raise StopIteration()
        return next_token
|
|
|
|
|
|
|
def generate_prompt(history):
    """Render *history* as a <human>/<bot> dialogue prompt and tokenize it.

    history: list of [human_message, bot_message] pairs; the last pair is
    the turn currently awaiting a bot reply (its bot slot is ignored).
    Returns the prompt as a list of token strings for ``translate_batch``.
    """
    turns = [f"<human>: {pair[0]}\n<bot>: {pair[1]}\n" for pair in history[:-1]]
    # Final turn ends at "<bot>:" so the model continues from there.
    turns.append(f"<human>: {history[-1][0]}\n<bot>:")
    prompt = "".join(turns)
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
|
|
|
def generate(streamer, history):
    """Generate the bot reply for the latest turn, streaming token ids.

    Runs ``translator.translate_batch`` with a per-step callback that
    forwards each token id into *streamer*. Intended to run in a
    background thread while the consumer iterates the streamer.

    Returns the raw translate_batch results (unused by the streaming path).
    """

    def step_result_callback(result):
        # Forward each decoded token id to the consumer thread.
        streamer.put(result.token_id)
        # Decoding can finish without emitting the end token (e.g. when
        # max_decoding_length is reached); push it explicitly so the
        # consumer's iteration terminates.
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    tokens = generate_prompt(history)

    try:
        results = translator.translate_batch(
            [tokens],
            beam_size=1,
            max_decoding_length=256,
            repetition_penalty=1.8,
            callback=step_result_callback,
        )
    except BaseException:
        # BUGFIX: if translate_batch raises before the last step, the end
        # token was never enqueued and the consumer would block forever on
        # Queue.get(timeout=None). Unblock it, then let the error propagate.
        streamer.put(end_token_id)
        raise
    return results
|
|
|
|
|
|
|
# Module-level model/tokenizer setup — runs once at import time.
# Loads a CTranslate2-converted model from the local "model" directory.
translator = ctranslate2.Translator("model", intra_threads=2)

tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")

end_token = "</s>"

# Take the first id of the encoded end token — assumes encode("</s>")
# yields the EOS id in position 0 for this tokenizer; verify if swapped.
end_token_id = tokenizer.encode(end_token)[0]
|
|
|
|
|
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the user's turn with an empty bot slot; clear the textbox.

        Robustness: the Clear button resets the chatbot to None, and some
        gradio versions then pass history=None here — `None + [...]` would
        raise TypeError, so normalize to a list first.
        """
        history = history or []
        return "", history + [[user_message, ""]]

    def bot(history):
        """Stream the bot reply for the last turn, yielding history updates."""
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()

        # Re-decode the whole accumulated sequence each step: decoding
        # token-by-token would break subword/multi-byte detokenization.
        for token in streamer:
            bot_message_tokens.append(token)
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Queuing is required for the generator-based `bot` handler to stream
# partial results to the browser.
demo.queue()

if __name__ == "__main__":

    demo.launch()