from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# Llama 3.2 3B Instruct

Llama 3.2 3B is Meta's latest iteration of open LLMs.
This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
For more details, please check [our post](https://huggingface.co/blog/llama32).
"""

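# Generation limits: slider bounds for new tokens, and a cap on prompt length (older turns are trimmed beyond it).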
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 32000

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# EvaByte ships custom modeling code on the Hub, so trust_remote_code=True is required.
tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte-SFT", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte-SFT", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to(device)

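# Runs on a ZeroGPU worker for up to 120 s per call. Generation happens in a background
# thread and streams decoded text back through TextIteratorStreamer, so the chat UI can
# render the reply incrementally.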
@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]

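    # Render the running conversation with the model's chat template, then keep only the
    # most recent MAX_INPUT_TOKEN_LENGTH tokens so very long histories still fit.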
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

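    # Accumulate streamed chunks and yield the growing reply so Gradio updates the message as it arrives.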
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


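# The sliders below are forwarded to generate() as additional arguments, in the order they are listed.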
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Write me an English pangram."],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

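# Queue incoming requests (at most 20 waiting at a time) before launching the app.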
if __name__ == "__main__":
    demo.queue(max_size=20).launch()