File size: 4,676 Bytes
6f8934c
 
 
abd9fd7
6f8934c
 
 
 
 
 
abd9fd7
6f8934c
 
2ad05c2
6f8934c
 
 
 
 
abd9fd7
6f8934c
 
 
 
 
 
abd9fd7
6f8934c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dded2
abd9fd7
6f8934c
 
 
 
 
 
47dded2
abd9fd7
 
de6224b
 
 
abd9fd7
6f8934c
abd9fd7
6f8934c
 
 
 
 
de6224b
 
6f8934c
 
 
de6224b
47dded2
11dd5a9
 
 
47dded2
f1d7efb
47dded2
804e8e0
 
6f8934c
 
 
abd9fd7
804e8e0
6f8934c
 
 
 
de6224b
6f8934c
804e8e0
6f8934c
abd9fd7
6f8934c
 
 
 
 
 
 
b951ea5
6f8934c
b951ea5
 
 
 
6f8934c
b951ea5
6f8934c
 
dffea0f
6f8934c
2a400f0
6f8934c
 
 
be15139
 
6f8934c
 
 
abd9fd7
 
 
 
 
 
 
 
 
 
 
 
de6224b
 
 
 
 
 
 
 
 
 
abd9fd7
6f8934c
3835819
 
 
abd9fd7
3835819
abd9fd7
6f8934c
2a400f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Read the Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")


# HTML header for the page: a centred <h1> title inside a <div>.
DESCRIPTION = (
    "\n"
    "<div>\n"
    '<h1 style="text-align: center;">Test Model</h1>\n'
    "</div>\n"
)

# Footer/licence snippet.
# NOTE(review): not referenced by the UI code in this file.
LICENSE = (
    "\n"
    "<p/>\n"
    "\n"
    "---\n"
)

# Placeholder text handed to gr.Chatbot (effectively blank).
PLACEHOLDER = "\n"


# Stylesheet for gr.Blocks: centres <h1> headings and styles the element whose
# id is "duplicate-button".
css = (
    "\n"
    "h1 {\n"
    "  text-align: center;\n"
    "  display: block;\n"
    "}\n"
    "\n"
    "#duplicate-button {\n"
    "  margin: auto;\n"
    "  color: white;\n"
    "  background: #1565c0;\n"
    "  border-radius: 100vh;\n"
    "}\n"
)

# Load the tokenizer and model
MODEL_ID = "Orenguteng/Llama-3-8B-Lexi-Uncensored"

# HF_TOKEN is forwarded so an access-restricted checkpoint can also be
# downloaded; when it is None, from_pretrained falls back to its default
# authentication behaviour.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# torch_dtype="auto" keeps the precision stored with the checkpoint instead of
# up-casting every weight of this ~8B model to the float32 default, which
# matters for GPU memory. device_map="auto" lets Accelerate place the weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, token=HF_TOKEN, torch_dtype="auto", device_map="auto"
)

# Generation stops at the tokenizer's EOS token or at "<|eot_id|>" (Llama-3's
# end-of-turn marker). Ids that fail to resolve (None) are dropped so
# generate() never receives None inside eos_token_id.
terminators = [
    token_id
    for token_id in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>"))
    if token_id is not None
]

def _build_conversation(message: str, history: list, system_prompt: str) -> list:
    """Turn ChatInterface inputs into the list of turns for the chat template.

    ``history`` may arrive either as ``[user, assistant]`` pairs (Gradio's
    "tuples" format) or as ``{"role": ..., "content": ...}`` dicts (Gradio's
    "messages" format), depending on how the ChatInterface/Chatbot pair is
    configured; both are accepted.

    Args:
        message (str): The new user message.
        history (list): Previous turns in either format described above.
        system_prompt (str): System instructions; omitted when empty.

    Returns:
        list: ``{"role", "content"}`` dicts ending with the new user message.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    for turn in history:
        if isinstance(turn, dict):
            # "messages" format: copy only the keys the chat template uses.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        user, assistant = turn
        # Skip missing halves of a pair rather than hand None content to the
        # chat template.
        if user is not None:
            conversation.append({"role": "user", "content": user})
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})

    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int,
                   top_p: float,
                   system_prompt: str
                   ) -> Iterator[str]:
    """
    Generate a streaming response using the llama3-8b model.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface, in
            either Gradio "tuples" or "messages" format.
        temperature (float): The temperature for generating the response;
            0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
        top_p (float): The top_p value for nucleus sampling (unused when
            temperature is 0).
        system_prompt (str): The system prompt to guide the conversation;
            omitted when empty.

    Yields:
        str: The response generated so far; each value extends the previous one.
    """
    conversation = _build_conversation(message, history, system_prompt)

    # add_generation_prompt=True appends the assistant header so the model
    # starts directly with its reply. Without it the model generates the header
    # itself, which leaks a leading "assistant" into the streamed text.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # The prompt is a single unpadded sequence, so every position is attended.
    # Deriving the mask from pad_token_id raises when the tokenizer defines no
    # pad token (None) and would mask prompt tokens that share the pad id.
    attention_mask = input_ids.new_ones(input_ids.shape)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        # Slider values may arrive as floats; generate() expects an int here.
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
    )
    if temperature == 0:
        # Sampling with temperature 0 is rejected by transformers, so decode
        # greedily and leave the sampling-only knobs unset.
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)

    # generate() runs in a worker thread and pushes decoded text into the
    # streamer, which this generator drains as chunks arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    response = ""
    for text in streamer:
        response += text
        yield response


# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface', type='messages')

with gr.Blocks(fill_height=True, css=css) as aida:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        # Use the customised chatbot above. Passing None discarded it and made
        # ChatInterface build a default one. chat_llama3_8b accepts whichever
        # history format results.
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Passed to chat_llama3_8b after (message, history), in this order:
        # temperature, max_new_tokens, top_p, system_prompt.
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.8,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=4096,
                      label="Max new tokens",
                      render=False),
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.9,
                      label="Top_p",
                      render=False),
            gr.Textbox(lines=2,
                       placeholder="Enter system prompt here...",
                       label="System Prompt",
                       render=False),
        ],
        examples=[
            ['Who Are you?'],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    aida.launch(share=True)