import gradio as gr
from unsloth import FastLanguageModel
import torch
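# NOTE: load_in_4bit relies on bitsandbytes, which typically requires an
# NVIDIA GPU; unsloth, transformers, gradio, and torch must all be installed.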

# Model and tokenizer configuration
max_seq_length = 2048
dtype = None  # None lets unsloth auto-detect float16/bfloat16 for the GPU
model_name_or_path = "michailroussos/model_llama_8d"

# Load model and tokenizer using unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference

# Define the response function
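# gr.ChatInterface calls this as fn(message, history, *additional_inputs),
# so the extra parameters below correspond, in order, to the System message
# textbox and the three sliders defined in the interface.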
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Build the conversation in the role/content format the chat template expects.
    messages = [{"role": "system", "content": system_message}]
    # With type="messages", Gradio passes history as a list of
    # {"role": ..., "content": ...} dicts rather than (user, assistant) tuples.
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": message})
    
    # Render the chat with the model's template; return_dict=True also
    # returns a proper attention mask instead of one guessed from pad tokens.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,  # temperature/top_p are ignored unless sampling is enabled
        temperature=temperature,
        top_p=top_p,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response


# Define the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    type="messages"
)

if __name__ == "__main__":
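    # share=True serves the app locally and also creates a temporary public
    # gradio.live link for sharing the demo.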
    demo.launch(share=True)