import os
from threading import Thread
from typing import Iterator

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import bitsandbytes  # noqa: F401 -- imported only to fail fast if the 8-bit quantization backend is missing

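# Generation limits: cap on new tokens per reply and on how much prompt history is fed to the model.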
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

DESCRIPTION = """\
# Chat with Mistral-Nemo-12B-R1
"""

# Load model with appropriate device configuration
def load_model():
    model_id = "CreitinGameplays/Mistral-Nemo-12B-R1-v0.1"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # If using CPU, load in 32-bit to avoid potential issues with 16-bit operations
    if device == "cpu":
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
    else:
        # On GPU, load in float16 with 8-bit quantization (bitsandbytes) to reduce memory use
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True
        )
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
    tokenizer.use_default_system_prompt = False
    
    return model, tokenizer, device

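# Load the model and tokenizer once at startup so every request reuses them.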
model, tokenizer, device = load_model()

system_prompt_text = "You are a helpful AI assistant."

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = system_prompt_text,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> Iterator[str]:
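    """Stream a reply to `message`, conditioned on `chat_history` and the sampling settings."""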
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

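    # Apply the model's chat template; if the prompt is too long, keep only the most recent tokens.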
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(device)

    # Streamer that yields decoded text as it is generated; a generous timeout gives
    # slow (e.g. CPU) generation time to produce the first token.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so run it in a background thread while the main
    # thread consumes text from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

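    # Yield the growing reply so the UI updates as tokens arrive.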
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

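# Chat UI: the additional controls below map onto generate()'s keyword arguments
# (system prompt and sampling parameters) in order.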
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value=system_prompt_text),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=0,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

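# Assemble the page: markdown header plus the chat interface.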
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    
if __name__ == "__main__":
    # Queue requests (at most 10 pending) and surface server-side errors in the browser.
    demo.queue(max_size=10).launch(show_error=True)