File size: 3,262 Bytes
a72fea7
832a4d2
f9e2c2e
a72fea7
832a4d2
 
 
a72fea7
832a4d2
 
 
 
 
 
 
a72fea7
832a4d2
 
1188d49
e8ace7a
584beb9
1188d49
9b00c4f
 
 
 
 
1188d49
a72fea7
9b00c4f
1188d49
832a4d2
 
 
1ea5080
832a4d2
db497f0
9b00c4f
 
0556c99
ebd9e26
1ea5080
0556c99
832a4d2
db497f0
a72fea7
 
1ea5080
ebd9e26
f5a59a6
1188d49
584beb9
6195f56
1188d49
 
 
 
9b00c4f
 
4668547
9b00c4f
1188d49
584beb9
6195f56
1188d49
9b00c4f
 
 
 
 
584beb9
 
ebd9e26
584beb9
 
9b00c4f
6195f56
832a4d2
a72fea7
9b00c4f
a72fea7
9b00c4f
 
a72fea7
9b00c4f
a72fea7
bad2083
a72fea7
 
9b00c4f
a72fea7
4668547
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
from unsloth import FastLanguageModel
import torch

# Checkpoint to serve and the context length it is loaded with.
max_seq_length = 2048
model_name_or_path = "michailroussos/model_llama_8d"

# Load the model and its tokenizer in-process through unsloth, with 4-bit
# weights (load_in_4bit=True).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    load_in_4bit=True,
    max_seq_length=max_seq_length,
)
# Put the model into unsloth's optimized inference mode before any generate().
FastLanguageModel.for_inference(model)

def _history_to_messages(history):
    """Convert Gradio chat history into role/content dicts for the chat template.

    Accepts openai-style ``{"role": ..., "content": ...}`` dicts (the format
    used with ``gr.ChatInterface(type="messages")``), as well as the legacy
    ``{"user": ..., "assistant": ...}`` dicts and ``[user, assistant]`` pairs.
    Only user/assistant turns whose content is a string are kept; anything
    else (e.g. non-text attachments, empty bot slots) is skipped because this
    text-only pipeline cannot feed it to the model. Extra keys are dropped.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict) and "role" in entry:
            turns = [(entry["role"], entry.get("content"))]
        elif isinstance(entry, dict):
            turns = [
                ("user", entry.get("user")),
                ("assistant", entry.get("assistant")),
            ]
        else:
            user_text, assistant_text = entry
            turns = [("user", user_text), ("assistant", assistant_text)]
        messages.extend(
            {"role": role, "content": content}
            for role, content in turns
            if role in ("user", "assistant") and isinstance(content, str)
        )
    return messages


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the assistant's reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The user's newest message.
        history: Prior turns as supplied by Gradio; see ``_history_to_messages``
            for the accepted formats. It is not modified.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's reply text only. ChatInterface appends it to the chat
        itself, so the conversation history must not be returned.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)
    # A single, unpadded prompt: every position is a real token, so the
    # attention mask is all ones.
    attention_mask = torch.ones_like(input_ids)

    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),  # slider values may arrive as floats
        do_sample=True,  # without sampling, temperature/top_p may be ignored
        temperature=temperature,
        top_p=top_p,
        use_cache=True,
    )
    # generate() returns prompt + continuation; decode only the new tokens so
    # the system prompt and earlier turns are not echoed into the reply.
    new_tokens = generated[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Gradio chat UI. Each additional input is forwarded to respond() positionally
# after (message, history): system message, max tokens, temperature, top-p.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a helpful assistant."),
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
)


if __name__ == "__main__":
    # share=False: serve locally only, without a public share link.
    demo.launch(share=False)