File size: 2,199 Bytes
a72fea7
832a4d2
f9e2c2e
a72fea7
832a4d2
 
 
a72fea7
832a4d2
 
 
 
 
 
 
a72fea7
832a4d2
 
e8ace7a
 
 
 
 
 
a72fea7
832a4d2
 
 
 
1ea5080
832a4d2
db497f0
bad2083
 
 
ebd9e26
1ea5080
ebd9e26
832a4d2
db497f0
a72fea7
 
1ea5080
ebd9e26
f5a59a6
 
 
ebd9e26
832a4d2
a72fea7
 
 
 
 
 
832a4d2
a72fea7
bad2083
a72fea7
 
 
f5a59a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
from unsloth import FastLanguageModel
import torch

# Load the model and tokenizer locally.
# NOTE: this runs at import time, so importing this module downloads/loads
# the model before the Gradio interface below is even constructed.

# Maximum sequence length (in tokens) passed to the loader.
max_seq_length = 2048
# Model identifier passed to from_pretrained; presumably a Hugging Face Hub
# repo id (or a local path) — confirm.
model_name_or_path = "michailroussos/model_llama_8d"

# Load model and tokenizer using unsloth. load_in_4bit=True requests 4-bit
# quantized weights (NOTE(review): this usually requires a CUDA GPU — confirm
# the deployment target has one).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Switch unsloth into its inference mode before generate() is called

# Define the response function
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both Gradio history formats:

    * ``type="messages"`` — a list of ``{"role": ..., "content": ...}`` dicts
      (the format the ChatInterface in this file is configured to send);
    * legacy ``"tuples"`` — a list of ``(user_msg, assistant_msg)`` pairs.

    Empty turns are skipped, as are dict entries whose role is neither
    ``"user"`` nor ``"assistant"``. Only ``role`` and ``content`` are copied, so
    extra Gradio keys (e.g. metadata) never reach the tokenizer. The input is
    not modified.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})
        else:
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the assistant's reply to *message* given the chat so far.

    Args:
        message: The new user message.
        history: Prior turns as supplied by ``gr.ChatInterface`` — either
            message dicts or ``(user, assistant)`` pairs. It is NOT modified;
            ChatInterface maintains the displayed history itself.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated reply text only (prompt tokens are stripped).
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # One unpadded prompt: every token is real, so attend to all of them.
    # Deriving the mask from pad_token_id would zero out genuine prompt tokens
    # whenever the pad id also appears in the prompt (e.g. if it is aliased to
    # an end-of-turn token), and fails outright when pad_token_id is None.
    attention_mask = torch.ones_like(input_ids)

    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),
        use_cache=True,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the reply does not echo the system prompt and earlier turns.
    new_tokens = generated[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Gradio chat UI. type="messages" makes ChatInterface pass history to
# respond() as role/content dicts. The extra controls below are handed to
# respond() positionally after (message, history), in the listed order.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

if __name__ == "__main__":
    # share=False: serve locally only, without creating a public share link.
    demo.launch(share=False)