pythia-uk / app.py
theodotus's picture
Remove doubled BOS
dd6a7c6
raw
history blame
2.63 kB
import ctranslate2
from transformers import AutoTokenizer
import threading
import gradio as gr
from typing import Optional
from queue import Queue
class TokenIteratorStreamer:
    """Blocking iterator over token ids produced by another thread.

    A producer thread feeds token ids in via :meth:`put`; the consumer
    simply iterates. Iteration ends when the configured end-of-sequence
    id arrives (the end token itself is not yielded).
    """

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.timeout = timeout
        self.queue = Queue()

    def put(self, token: int):
        """Enqueue one token id from the producer thread."""
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        next_id = self.queue.get(timeout=self.timeout)
        if next_id == self.end_token_id:
            raise StopIteration()
        return next_id
def generate_prompt(history):
    """Render the chat history into the <human>/<bot> prompt and tokenize it.

    Every exchange except the last is rendered as a completed turn
    terminated by the module-level ``end_token``; the final turn is left
    open after "<bot>:" so the model continues from there. Returns the
    prompt as a list of token strings (the form ctranslate2 expects).
    """
    finished_turns = [
        f"<human>: {chain[0]}\n<bot>: {chain[1]}{end_token}\n"
        for chain in history[:-1]
    ]
    prompt = "".join(finished_turns) + f"<human>: {history[-1][0]}\n<bot>:"
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
def generate(streamer, history):
    """Run greedy ctranslate2 decoding for *history*, streaming ids into *streamer*.

    Each decoded token id is pushed into the streamer from the per-step
    callback. If decoding stops without emitting the end token (e.g. the
    length limit was hit), the end id is pushed explicitly so the
    consumer's iteration still terminates. Returns the raw
    ``translate_batch`` results.
    """
    def on_step(result):
        streamer.put(result.token_id)
        # Guarantee stream termination even when the model stops for a
        # reason other than producing the end token.
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    prompt_tokens = generate_prompt(history)
    return translator.translate_batch(
        [prompt_tokens],
        beam_size=1,
        max_decoding_length=256,
        repetition_penalty=1.2,
        callback=on_step,
    )
# Load the converted ctranslate2 model from the local "model" directory.
translator = ctranslate2.Translator("model", intra_threads=2)
# Slow (SentencePiece) tokenizer matching the OpenLLaMA base model.
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b", use_fast=False)
# End-of-sequence marker: used in the prompt format and to stop streaming.
end_token = "</s>"
# NOTE(review): encode() on LLaMA-style tokenizers may prepend a BOS token;
# confirm that index 0 here really is the </s> id and not BOS.
end_token_id = tokenizer.encode(end_token)[0]
# Gradio chat UI: a chatbot pane, a free-text input, and a clear button.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    def user(user_message, history):
        # Clear the textbox and append a new [user, bot] pair with an
        # empty bot slot that bot() fills in incrementally.
        return "", history + [[user_message, ""]]
    def bot(history):
        # Stream the model's reply: generation runs in a background
        # thread pushing token ids into the streamer; the full token
        # list is re-decoded on every new token so partial multi-token
        # words render correctly, and each update is yielded to the UI.
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id = end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()
        for token in streamer:
            bot_message_tokens.append(token)
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()
    # Submitting text first records the user turn (unqueued, instant),
    # then chains into the streaming bot generator.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
# Queueing is required for generator (streaming) event handlers.
demo.queue()
if __name__ == "__main__":
    demo.launch()