import gradio as gr
from unsloth import FastLanguageModel
import torch

# Load the model and tokenizer locally
max_seq_length = 2048
model_name_or_path = "michailroussos/model_llama_8d"

# Load model and tokenizer using unsloth
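# Note: load_in_4bit uses bitsandbytes quantization and generally requires a CUDA GPU;
# set it to False to load the model in full precision instead.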
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference

# Define the response function
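# gr.ChatInterface calls this as respond(message, history, *additional_inputs),
# passing the extra inputs in the order they are declared in the interface below.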
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Print the inputs at the start
    print("===== Respond Function Called =====")
    print(f"Received message: {message}")
    print(f"Current history: {history}")
    
    # Prepare the messages for the model, starting with the system prompt
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    if history:
        print("Adding previous messages to the history...")
        for entry in history:
            messages.append({"role": "user", "content": entry[0]})
            messages.append({"role": "assistant", "content": entry[1]})

    # Add the current user message
    print(f"Adding current user message: {message}")
    messages.append({"role": "user", "content": message})

    # Print the messages list before tokenization
    print("Messages before tokenization:", messages)

    # Tokenize the input (prepare the data for the model)
    print("Preparing the input for the model...")
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    # Print the tokenized inputs
    print(f"Tokenized inputs: {inputs}")

    # Build the attention mask: a single, unpadded sequence attends to every token
    attention_mask = torch.ones_like(inputs)
    print(f"Attention mask: {attention_mask}")

    # Generate the response
    try:
        generated_tokens = model.generate(
            input_ids=inputs,
            attention_mask=attention_mask,
            max_new_tokens=int(max_tokens),
            use_cache=True,
            do_sample=True,  # temperature and top_p only take effect when sampling is enabled
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        print(f"Error during model generation: {e}")
        # gr.ChatInterface expects a string reply, so surface the error as text
        return f"Error during model generation: {e}"

    # Decode only the newly generated tokens (skip the prompt portion of the output)
    response = tokenizer.decode(
        generated_tokens[0][inputs.shape[-1]:], skip_special_tokens=True
    )
    print(f"Generated response: {response}")
    
    # Check and filter out unwanted system-level messages or metadata
    if "system" in response.lower():
        print("System message detected. Replacing with fallback response.")
        response = "Hello! How can I assist you today?"

    # gr.ChatInterface manages the chat history itself, so return only the assistant reply
    return response

# Define the Gradio interface
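# The additional inputs below are shown in an "Additional Inputs" accordion beneath the
# chat box and are forwarded to respond() after message and history.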
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch(share=False)  # Use share=False for local testing