File size: 3,367 Bytes
a72fea7
7c34777
832a4d2
0787acc
a72fea7
7c34777
832a4d2
3dc2f1d
b751030
 
a72fea7
7c34777
0787acc
832a4d2
 
 
3dc2f1d
832a4d2
 
7c34777
0787acc
35ddf38
7c34777
9202d9a
 
37be440
6300d69
7c34777
 
0787acc
 
7c34777
5830d67
99b9339
 
 
 
7c34777
3dc2f1d
5830d67
 
 
3bc8976
9202d9a
 
 
 
 
 
 
 
b55dba2
 
 
 
9202d9a
 
8891495
15bfa4e
b55dba2
343abde
b55dba2
 
 
 
37be440
0acf62f
15bfa4e
37be440
 
 
129e6d2
9202d9a
3dc2f1d
3bc8976
 
7c34777
6300d69
5830d67
9202d9a
7c34777
 
 
 
 
 
 
 
0787acc
7c34777
 
 
a72fea7
 
 
0787acc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from transformers import TextStreamer
from unsloth import FastLanguageModel
import torch

# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
max_seq_length = 2048  # context length handed to Unsloth at load time
dtype = None  # None defers the dtype choice to Unsloth (presumably auto-detect — confirm in Unsloth docs)
model_name_or_path = "michailroussos/model_llama_8d"
# Alternative checkpoint that has also been used: "Natassaf/lora_model-llama-new"

# Load the 4-bit quantised model and its tokenizer once, at import time, so
# every request served by the UI reuses the same instances.
print("Loading model...")
_load_kwargs = dict(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_load_kwargs)
# Put the model into Unsloth's inference mode (faster generation).
FastLanguageModel.for_inference(model)
print("Model loaded successfully!")

# Gradio Response Function
from transformers import TextStreamer

def respond(message, max_new_tokens, temperature, system_message="You are a helpful assistant. You should reply to the user's message without repeating the input."):
    """Generate one chat reply for the Gradio interface.

    Builds a conversation (optional system turn plus one user turn), formats
    it with the tokenizer's chat template, runs ``model.generate`` and returns
    only the text the model produced after the prompt.

    Args:
        message: The user's prompt.
        max_new_tokens: Upper bound on generated tokens. Cast to ``int``
            because UI sliders may deliver numeric values as floats.
        temperature: Sampling temperature. Sampling is enabled explicitly so
            this value actually takes effect.
        system_message: Optional system instruction. An empty string or
            ``None`` omits the system turn. When called from the UI, the
            "System Message" textbox supplies this value, so the default here
            only applies to direct calls.

    Returns:
        The decoded reply with surrounding whitespace stripped. If any step
        raises, returns ``"Error: <message>"`` instead, so the UI shows the
        failure rather than crashing.
    """
    try:
        # Prepare input messages
        messages = [{"role": "system", "content": system_message}] if system_message else []
        messages.append({"role": "user", "content": message})

        # Tokenize inputs and place them on the same device as the model
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # Ensure the input tensor has the correct dimensions
        if input_ids.dim() != 2:
            raise ValueError(f"`input_ids` must be a 2D tensor. Found shape: {input_ids.shape}")
        prompt_length = input_ids.shape[-1]

        with torch.no_grad():  # No need to track gradients for inference
            output = model.generate(
                input_ids=input_ids,
                # Single unpadded sequence: every position is a real token.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                # Without sampling, transformers decodes greedily and ignores
                # `temperature`, which would make the UI slider a no-op.
                do_sample=True,
                use_cache=True,
            )

        # Decode only the tokens generated after the prompt. Searching the
        # decoded text for a marker string depends on the chat template and
        # breaks when a system message precedes the user turn.
        result_text = tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        ).strip()
        print("[DEBUG] Result Text:", result_text)

        return result_text

    except Exception as e:
        # UI boundary: log the error and surface it to the user instead of
        # letting Gradio show a generic failure.
        print("[ERROR]", str(e))
        return f"Error: {str(e)}"



def _build_demo():
    """Assemble the Gradio interface wrapping ``respond``.

    Gradio passes input component values positionally, so the component order
    must match respond's parameters: (message, max_new_tokens, temperature,
    system_message).
    """
    input_components = [
        gr.Textbox(label="Your Message", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=512, step=1, value=128, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"),
        gr.Textbox(label="System Message", placeholder="Optional system instructions."),
    ]
    return gr.Interface(
        fn=respond,
        inputs=input_components,
        outputs="text",
        title="LLama-based Chatbot",
        description="Interact with the model. Enter a prompt and receive a response.",
    )


# Gradio UI
demo = _build_demo()

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)