File size: 2,597 Bytes
a70311a 1be71e1 a70311a 1be71e1 a70311a 1be71e1 dd6a7c6 1be71e1 afae51e e72e55e a70311a 1be71e1 a70311a 1be71e1 a70311a afae51e a70311a 1be71e1 a70311a afae51e d7c49f6 a70311a 1be71e1 a70311a 1be71e1 a70311a 1be71e1 a70311a 1be71e1 d7c49f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import ctranslate2
from transformers import AutoTokenizer
import threading
import gradio as gr
from typing import Optional
from queue import Queue
class TokenIteratorStreamer:
    """Blocking iterator over generated token ids.

    A producer thread pushes token ids via ``put``; a consumer iterates
    until the configured end-of-sequence token id arrives, which ends
    the iteration (the end token itself is not yielded).
    """
    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.queue = Queue()
        self.timeout = timeout
    def put(self, token: int):
        """Enqueue one token id (may block up to ``timeout`` seconds)."""
        self.queue.put(token, timeout=self.timeout)
    def __iter__(self):
        return self
    def __next__(self):
        """Block until the next token arrives; stop on the end token."""
        next_token = self.queue.get(timeout=self.timeout)
        if next_token == self.end_token_id:
            raise StopIteration()
        return next_token
def generate_prompt(history):
    """Format chat *history* into the <human>/<bot> prompt and tokenize it.

    Each completed turn becomes a "<human>: ...\n<bot>: ...\n" segment;
    the last (pending) turn ends with an open "<bot>:" for the model to
    complete. Returns the token strings for the encoded prompt.
    """
    segments = [f"<human>: {chain[0]}\n<bot>: {chain[1]}\n" for chain in history[:-1]]
    segments.append(f"<human>: {history[-1][0]}\n<bot>:")
    prompt = "".join(segments)
    # Module-level tokenizer: encode to ids, then map back to token strings
    # (CTranslate2 consumes token strings, not ids).
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
def generate(streamer, history):
    """Run generation for *history*, streaming token ids into *streamer*.

    Intended to run in a worker thread; the consumer iterates *streamer*
    until it sees the module-level ``end_token_id`` sentinel.
    Returns the CTranslate2 translation results.
    """
    def step_result_callback(result):
        # Forward each generated token id to the consumer as it is produced.
        streamer.put(result.token_id)
        # Ensure the consumer's iteration terminates even when decoding
        # stops (e.g. max length reached) without emitting the end token.
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")
    tokens = generate_prompt(history)
    try:
        results = translator.translate_batch(
            [tokens],
            beam_size=1,
            max_decoding_length=256,
            repetition_penalty=1.8,
            callback=step_result_callback,
        )
    except Exception:
        # Bug fix: if translate_batch raises before the end token is pushed,
        # the consumer would block on queue.get() forever (timeout is None).
        # Unblock it with the sentinel, then re-raise so the failure is
        # not silently swallowed.
        streamer.put(end_token_id)
        raise
    return results
# Module-level model setup: load the converted CTranslate2 model from the
# local "model" directory and the matching HF tokenizer. Runs once at import
# time (heavy I/O / model load).
translator = ctranslate2.Translator("model", intra_threads=2)
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")
end_token = "</s>"
# First id of the encoded end token; used as the streaming stop sentinel.
end_token_id = tokenizer.encode(end_token)[0]
# Gradio UI: chatbot display, a textbox for input, and a clear button.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    def user(user_message, history):
        """Append the submitted message as a new turn (empty reply) and
        clear the textbox."""
        return "", history + [[user_message, ""]]
    def bot(history):
        """Stream the bot reply: run generation in a background thread and
        yield partial history updates as tokens arrive so the UI refreshes
        incrementally."""
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id = end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()
        # Iteration ends when the streamer receives the end token.
        for token in streamer:
            bot_message_tokens.append(token)
            # Re-decode the full token list each step so partial multi-byte
            # sequences render correctly.
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()
    # On submit: record the user turn immediately (unqueued), then stream
    # the bot reply via the generator-based `bot` handler.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
# Enable Gradio's request queue (required for generator/streaming handlers).
demo.queue()
if __name__ == "__main__":
    demo.launch()