import gradio as gr
import torch
from threading import Thread
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer

# Load the model and tokenizer locally
max_seq_length = 2048  # maximum context length used for the chat template
dtype = None  # let unsloth auto-detect the dtype (float16 / bfloat16)
model_name_or_path = "michailroussos/model_llama_8d"

# Load model and tokenizer using unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference

# Define the response function
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # With type="messages", Gradio passes the history as a list of
    # {"role": ..., "content": ...} dicts, so it can be reused directly.
    messages = [{"role": "system", "content": system_message}]
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": message})

    # Tokenize the conversation with the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    # TextIteratorStreamer queues decoded tokens so this generator can
    # yield partial responses to Gradio as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread and consume the stream here
    generation_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,  # required for temperature / top_p to take effect
        temperature=temperature,
        top_p=top_p,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate and stream the reply back to the chat window
    response = ""
    for new_text in streamer:
        response += new_text
        yield response



# Define the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    type="messages"
)

if __name__ == "__main__":
    demo.launch(share=True)