import os
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DESCRIPTION = """\
# EvaByte-SFT

EvaByte is a byte-level language model that combines multibyte prediction with the efficient EVA attention mechanism.  
This page hosts [EvaByte/EvaByte-SFT](https://huggingface.co/EvaByte/EvaByte-SFT), fine-tuned on supervised instruction data to support chat and general instruction following.  
For full details on the architecture, training recipe, and benchmarks, see the blog post and the project repository:

- Blog: <https://hkunlp.github.io/blog/2025/evabyte>  
- GitHub: <https://github.com/OpenEvaByte/evabyte>
"""

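# Generation and context limits for the demo; EvaByte is byte-level, so these counts are (roughly) in bytes.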
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 32000

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

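# EvaByte ships custom modeling and tokenizer code, so trust_remote_code=True is required.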
tokenizer = AutoTokenizer.from_pretrained("EvaByte/EvaByte-SFT", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "EvaByte/EvaByte-SFT",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().to(device)


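# On ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU for each call (here for up to 60 seconds).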
@spaces.GPU(duration=60)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> str:
    """Generate one assistant reply for `message`, given the prior chat history."""
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    )

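    # Keep only the most recent tokens if the prompt exceeds the context budget.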
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input to the last {MAX_INPUT_TOKEN_LENGTH} tokens because it exceeded the limit."
        )

    input_ids = input_ids.to(model.device)

    # multi_byte_generate is EvaByte's custom generation method (loaded via trust_remote_code),
    # which decodes multiple bytes per forward step.
    output_ids = model.multi_byte_generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )

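    # Drop the prompt tokens and decode only the newly generated portion.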
    generated_segment = output_ids[0][input_ids.shape[1]:]
    return tokenizer.decode(generated_segment, skip_special_tokens=True)


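# Chat UI with sliders exposing the main sampling hyperparameters.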
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
    ],
    stop_btn=None,
    examples=[["Write me an English pangram."]],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    fill_height=True,
)

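# Queue requests (at most 20 waiting) so concurrent users are served in order.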
if __name__ == "__main__":
    demo.queue(max_size=20).launch()