import ctypes
import threading

import gradio as gr
import transformers
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
pipeline = transformers.pipeline(
    "text-generation",
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation as soon as any of the given strings appears in the
    last two decoded tokens."""

    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops or []

    def __call__(self, input_ids, scores):
        last_tokens = input_ids[0][-2:]
        decoded = tokenizer.decode(last_tokens)
        return any(stop in decoded for stop in self.stops)


stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])


class CancelableThread(threading.Thread):
    """A thread that can be cancelled by asynchronously injecting SystemExit
    into it via CPython's PyThreadState_SetAsyncExc."""

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        super().__init__(group=group, target=target, name=name)
        self.args = args
        self.kwargs = kwargs or {}

    def run(self):
        # PyThreadState_SetAsyncExc expects the Python thread id
        # (threading.get_ident()), not the OS-native id.
        self.id = threading.get_ident()
        self._target(*self.args, **self.kwargs)

    def get_id(self):
        return self.id

    def raise_exception(self):
        thread_id = self.get_id()
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if res > 1:
            # More than one thread state was modified: undo and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")


class ThreadManager:
    """Context manager that starts a thread on entry and makes sure it is
    stopped and joined on exit."""

    def __init__(self, thread: CancelableThread, **kwargs):
        self.thread = thread

    def __enter__(self):
        # Start the thread
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # Wait for the thread to finish; cancel it if it is still running
        if self.thread.is_alive():
            print("trying to terminate thread")
            self.thread.raise_exception()
        self.thread.join()
        print("Thread has been successfully joined.")


def respond(prompt, max_tokens):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    thread = CancelableThread(
        target=pipeline,
        kwargs=dict(
            text_inputs=prompt,
            max_new_tokens=max_tokens,
            return_full_text=False,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        ),
    )

    response = ""
    with ThreadManager(thread=thread):
        for output in streamer:
            if not output:
                continue
            response += output
            # Keep the buttons disabled while tokens are streaming in.
            yield response, gr.update(interactive=False), gr.update(interactive=False)

    # Re-enable the buttons once generation has finished.
    yield (
        response,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )


def reset_textbox():
    return gr.update(value=""), gr.update(value="")


def no_interactive():
    return gr.update(interactive=False), gr.update(interactive=False)


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        "A demo on 2 vCPUs and 16 GB of RAM, roughly the specs of a "
        "ten-year-old laptop. (It runs on a free Hugging Face instance "
        "with no GPU.) It should run much faster once vllm or llama.cpp "
        "support this model."
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15, label="input_text", placeholder="これからの人工知能技術は"
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
            outputs = gr.Textbox(lines=20, label="Output")
        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )

    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

demo.queue(max_size=20)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)