"""Gradio Space serving the spider-natsql-wizard-coder model as a text-to-SQL REST API.

Loads the model in 8-bit on GPU 0, streams generations on a worker thread,
and exposes a single `bot` endpoint through a Gradio Interface.
"""

import os
from threading import Thread

import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Authenticate only when a token is actually configured: login(None) falls
# back to an interactive prompt, which breaks headless deployments (Spaces).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)

model_name = "richardr1126/spider-natsql-wizard-coder-8bit"
max_new_tokens = 1536

tok = AutoTokenizer.from_pretrained(model_name)

print(f"Starting to load the model {model_name}")
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,       # place the whole model on GPU 0
    load_in_8bit=True,  # 8-bit quantization (bitsandbytes) to fit in memory
)
# The model has no dedicated pad token; reuse EOS so generate() does not warn.
m.config.pad_token_id = m.config.eos_token_id
m.generation_config.pad_token_id = m.config.eos_token_id

# Tokens that terminate a completion (end of SQL statement / prompt markers).
stop_tokens = [";", "###", "Result"]
stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)
print(f"Successfully loaded the model {model_name} into memory")


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recently emitted token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the newest token of the (single) sequence needs checking each step.
        return input_ids[0][-1].item() in stop_token_ids


def bot(input_message: str, temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    """Generate a completion for ``input_message`` and return the full text.

    Generation runs on a worker thread while this thread drains the token
    streamer; the prompt and special tokens are excluded from the output.

    Args:
        input_message: Prompt text to complete.
        temperature: Sampling temperature; 0 selects greedy decoding.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff (0 disables).
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated text, accumulated from the stream.
    """
    input_ids = tok(input_message, return_tensors="pt").input_ids.to(m.device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # Greedy decoding at temperature 0, sampling otherwise.
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    worker = Thread(target=m.generate, kwargs=generate_kwargs)
    worker.start()

    # The streamer is exhausted exactly when generation finishes, so this
    # loop doubles as the completion signal (no extra Event needed).
    parts = []
    for new_text in streamer:
        parts.append(new_text)
    worker.join()  # ensure the generation thread has fully exited
    return "".join(parts)


gradio_interface = gr.Interface(
    fn=bot,
    inputs=[
        "text",
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01),
        gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.1),
    ],
    outputs="text",
    title="REST API with Gradio and Huggingface Spaces",
    description="This is a demo of how to build an AI powered REST API with Gradio and Huggingface Spaces – for free! See the **Use via API** link at the bottom of this page.",
)

if __name__ == "__main__":
    gradio_interface.launch()