|
import os |
|
from threading import Event, Thread |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
StoppingCriteria, |
|
StoppingCriteriaList, |
|
TextIteratorStreamer, |
|
) |
|
from huggingface_hub import login |
|
import gradio as gr |
|
import torch |
|
|
|
# Authenticate with the Hugging Face Hub only when a token is supplied.
# login() without a token falls back to an interactive prompt, which cannot
# be answered when the app runs headless (e.g. in a hosted Space).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)

model_name = "richardr1126/spider-natsql-wizard-coder-8bit"
tok = AutoTokenizer.from_pretrained(model_name)

# Upper bound on the number of tokens generated per request.
max_new_tokens = 1536

print(f"Starting to load the model {model_name}")

# Load the checkpoint with 8-bit weights, placing it on device index 0.
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,
    load_in_8bit=True,
)

# Reuse EOS as the padding token in both the model config and the generation
# config, so generate() has a pad id to work with.
m.config.pad_token_id = m.config.eos_token_id
m.generation_config.pad_token_id = m.config.eos_token_id

# Generation stops when the last generated token is one of these.
stop_tokens = [";", "###", "Result"]
# NOTE(review): convert_tokens_to_ids maps strings that are not a single
# vocabulary token to the unknown-token id (or None) — confirm each entry
# exists as one token in this tokenizer's vocab.
stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)

print(f"Successfully loaded the model {model_name} into memory")
|
|
|
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Only the most recent token of the first sequence in the batch is checked,
    which matches this app's single-prompt usage.
    """

    def __init__(self, stop_ids=None):
        """Create the criterion.

        Args:
            stop_ids: Token ids that end generation. When omitted, the
                module-level ``stop_token_ids`` is read on every call.
        """
        super().__init__()
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the last token of sequence 0 is one of the stop ids."""
        # Late-bind the global when no explicit ids were given, as before.
        stop_ids = stop_token_ids if self.stop_ids is None else self.stop_ids
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_ids)
|
|
|
def bot(input_message: str, temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    """Generate a completion for ``input_message`` with the loaded model.

    Generation runs on a background thread that feeds a
    ``TextIteratorStreamer``. This function drains the streamer and returns the
    concatenated text once generation finishes. The prompt itself is not
    included in the result.

    Args:
        input_message: Raw prompt text, tokenized as-is.
        temperature: Sampling temperature; values <= 0 disable sampling.
        top_p: Nucleus-sampling probability mass, passed to ``generate()``.
        top_k: Top-k cutoff, passed to ``generate()``.
        repetition_penalty: Repetition penalty, passed to ``generate()``.

    Returns:
        The generated text with special tokens stripped.

    Raises:
        Exception: Any exception raised by ``m.generate`` on the worker thread
            is re-raised here once the stream has been closed.
    """
    # Keep the attention mask alongside the ids: pad_token_id equals
    # eos_token_id, so generate() cannot reliably infer the mask itself.
    inputs = tok(input_message, return_tensors="pt").to(m.device)

    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    worker_errors = []

    def _generate():
        """Run generation on the worker thread, capturing any failure."""
        try:
            m.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the caller's thread below
            worker_errors.append(exc)
            # Close the stream so the consumer loop ends immediately instead
            # of waiting for the streamer timeout and raising queue.Empty.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial_text = "".join(streamer)
    worker.join()

    if worker_errors:
        raise worker_errors[0]
    return partial_text
|
|
|
# Web UI / REST endpoint. The inputs map positionally onto bot()'s parameters:
# (input_message, temperature, top_p, top_k, repetition_penalty).
gradio_interface = gr.Interface(
    fn=bot,
    inputs=[
        "text",
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01),
        gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1),
        # The step must put the 1.08 default on the slider's grid; with
        # step=0.1 from minimum=1.0 the default is off-grid (may be snapped
        # to 1.1) and can never be re-selected once the slider is moved.
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01),
    ],
    outputs="text",
    title="REST API with Gradio and Huggingface Spaces",
    description="This is a demo of how to build an AI powered REST API with Gradio and Huggingface Spaces – for free! See the **Use via API** link at the bottom of this page.",
)

gradio_interface.launch()
|
|