import gradio as gr
from unsloth import FastLanguageModel
import torch

# Load the model and tokenizer locally
max_seq_length = 2048
model_name_or_path = "michailroussos/model_llama_8d"

# Load model and tokenizer using unsloth
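# (load_in_4bit=True loads the weights with 4-bit quantization via bitsandbytes
# to reduce GPU memory use; unsloth itself expects a CUDA-capable GPU)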
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference

# Define the response function
def respond(message, history, system_message, max_tokens, temperature, top_p):
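    """Chat handler for gr.ChatInterface with type="messages".

    `message` is the latest user input, `history` is the list of
    {"role": ..., "content": ...} dicts maintained by Gradio, and the
    remaining arguments come from `additional_inputs` in the order listed
    there. The return value is displayed as the assistant's reply for this turn.
    """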
    # Print to show the inputs at the start
    print(f"Received message: {message}")
    print(f"Current history: {history}")
    
    # Prepare the messages for the model: exclude the system message for now
    messages = []
    if history:
        for entry in history:
            # With type="messages", each history entry is already a
            # {"role": ..., "content": ...} dict and can be reused directly
            print(f"Adding {entry['role']} message from history: {entry['content']}")
            messages.append({"role": entry["role"], "content": entry["content"]})
    
    # Add the user's new message to the list
    print(f"Adding current user message: {message}")
    messages.append({"role": "user", "content": message})

    # Tokenize the input (prepare the data for the model)
    print("Preparing the input for the model...")
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    # Print the tokenized inputs
    print(f"Tokenized inputs: {inputs}")
    
    # Generate the response; the prompt is a single unpadded sequence, so
    # every position can be attended to
    attention_mask = torch.ones_like(inputs)
    print(f"Attention mask: {attention_mask}")
    generated_tokens = model.generate(
        input_ids=inputs,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
    )

    # Decode only the newly generated tokens; generate() echoes the prompt
    # tokens, so they are sliced off before decoding
    response = tokenizer.decode(
        generated_tokens[0][inputs.shape[-1]:], skip_special_tokens=True
    )
    print(f"Generated response: {response}")

    # gr.ChatInterface manages the conversation history itself, so only the
    # new assistant reply is returned for display
    return response



# Define the Gradio interface
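# The additional inputs below are passed to `respond` after (message, history),
# in the order listed; type="messages" keeps the history in role/content format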
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    type="messages",
)


if __name__ == "__main__":
    demo.launch(share=False)  # Use share=False for local testing