from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

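# Assisted generation pairs a large target model with a much smaller assistant that
# shares its tokenizer: the assistant drafts candidate tokens cheaply and the target
# model validates them, which can speed up decoding (with greedy decoding the output
# matches what the target model would produce on its own).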
model_id = "EleutherAI/pythia-6.9b-deduped"
assistant_id = "EleutherAI/pythia-70m-deduped"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())


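# On GPU, load the 6.9B target model in 8-bit (via bitsandbytes) to reduce memory use;
# on CPU, fall back to full precision.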
if torch_device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
else:
    model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
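# The 70M-parameter assistant is small enough to run in full precision on the same device.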
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)


def run_generation(user_text, use_assistant, top_p, temperature, top_k, max_new_tokens):
    # A temperature of 0.0 maps to greedy decoding; anything above it enables sampling.
    do_sample = temperature > 0.0

    # Tokenize the user text and move the tensors to the model's device.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread so that we don't block the UI. The text is pulled from the
    # streamer in the main thread. A timeout on the streamer keeps exceptions in the generation thread
    # from hanging the loop below.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
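    # Passing `assistant_model` to `generate` enables assisted generation; when the checkbox
    # is off we pass None and `generate` falls back to regular decoding.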
    generate_kwargs = dict(
        model_inputs,
        assistant_model=assistant_model if use_assistant else None,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output


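# Helper to clear the input textbox; defined for convenience but not wired to an event below.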
def reset_textbox():
    return gr.update(value='')


with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Assisted Generation Demo\n"
        f"Model: {model_id} (using INT8)\n\n"
        f"Assistant Model: {assistant_id}"
    )

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="Question: What is the meaning of life? Answer:",
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            use_assistant = gr.Checkbox(label="Use Assistant", value=True)
            max_new_tokens = gr.Slider(
                minimum=1, maximum=500, value=250, step=1, interactive=True, label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
            )

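    # Pressing Enter in the textbox or clicking Submit both call the streaming generator,
    # so the output box updates incrementally as text is produced.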
    generate_inputs = [user_text, use_assistant, top_p, temperature, top_k, max_new_tokens]
    user_text.submit(run_generation, generate_inputs, model_output)
    button_submit.click(run_generation, generate_inputs, model_output)

    demo.queue(max_size=32).launch()