File size: 3,863 Bytes
4375b7f
4e683ec
76a154f
f4e3549
8c0a7e8
76a154f
b1c12fa
76a154f
d534002
4e683ec
76a154f
 
4375b7f
76a154f
ad4597f
 
 
 
 
76a154f
 
 
 
e12dd90
0da65a1
2fe7e62
4e683ec
 
 
 
 
 
 
 
e12dd90
4e683ec
 
2fe7e62
6111f2c
 
 
 
4e683ec
 
 
6111f2c
4e683ec
 
2fe7e62
4e683ec
 
 
 
76a154f
4e683ec
 
 
 
76a154f
 
4e683ec
 
 
a400f4b
4e683ec
 
76a154f
 
 
 
4e683ec
 
 
76a154f
 
 
2fe7e62
4e683ec
 
 
76a154f
 
 
4e683ec
 
 
 
76a154f
 
 
 
4e683ec
 
 
 
 
 
 
 
 
 
 
af33034
 
4e683ec
cadad8a
e12dd90
4e683ec
 
bb98bf8
6336a63
 
7621468
bf2e7bf
6336a63
d6a1a38
6336a63
 
 
 
 
 
a636dc2
f5aa776
a636dc2
 
76a154f
4e683ec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
from threading import Thread
from typing import Iterator
import os
from huggingface_hub import login
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Upper bound of the "Max new tokens" slider in the chat UI.
MAX_MAX_NEW_TOKENS = 2048
# Starting value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are cut down (oldest tokens dropped) in
# ``generate``; can be overridden through the environment variable of the same name.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Module-level handles to the chat model and its tokenizer. They start empty and
# are filled in by ``load_model``, which ``generate`` calls on first use, so the
# checkpoint is only fetched when the first chat message arrives (i.e. after the
# login step in the UI has had a chance to run).
model = None
tokenizer = None


def load_model():
    """Load the chat model and tokenizer and publish them as module globals.

    ``generate`` reads the module-level ``model`` and ``tokenizer``, so both are
    assigned globally here rather than kept as locals.

    Returns:
        A ``(model, tokenizer)`` tuple with the freshly loaded objects.
    """
    global model, tokenizer
    model_id = "stabilityai/ar-stablelm-2-chat"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Turn off the tokenizer's default system prompt (for templates that honor
    # this flag); ``generate`` adds a system turn only when one is supplied.
    tokenizer.use_default_system_prompt = False
    return model, tokenizer


def generate(
    message: str,
    chat_history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 128,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the assistant's reply to ``message`` given the prior conversation.

    Args:
        message: The new user turn.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (the format Gradio uses with ``type="messages"``).
        system_prompt: Optional system turn prepended to the conversation;
            skipped when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply generated so far (cumulative text, not deltas), once per
        chunk produced by the streamer.
    """
    # Load lazily on first use; ``load_model`` fills in the module globals.
    if model is None or tokenizer is None:
        load_model()

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the newest turns survive truncation.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt: stream only newly generated text. timeout: stop waiting (and
    # raise) if no new text arrives for 10 s instead of blocking forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # Sampling must be enabled; with do_sample=False transformers decodes
        # greedily and ignores temperature / top_p / top_k.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until generation finishes, so run it in a worker
    # thread and consume the streamer from this generator.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()


# Chat UI bound to ``generate``. Gradio passes the values of ``additional_inputs``
# to the chat function after (message, chat_history), in list order, so the
# controls below mirror ``generate``'s signature: system_prompt, max_new_tokens,
# temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        # One slider per generation control: (label, minimum, maximum, step, initial value).
        *(
            gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
            for label, low, high, step, initial in (
                ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
                ("Temperature", 0.1, 4.0, 0.1, 0.7),
                ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
                ("Top-k", 1, 1000, 1, 50),
                ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
            )
        ),
    ],
    stop_btn=None,
    examples=[
        ["السلام عليكم"],
        ["اعرب الجملة التالية: ذهبت الى السوق"],
    ],
    cache_examples=False,
    type="messages",
)

with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    def authenticate_token(token):
        """Log in to the Hugging Face Hub with a user-supplied access token.

        Args:
            token: Text from the password textbox.

        Returns:
            A status message shown in the "Output" textbox.
        """
        # Blank input is not a credential; reject it without calling ``login``.
        if not token or not token.strip():
            return "Invalid token. Please try again."
        try:
            # Strip whitespace/newlines that often come along when pasting.
            login(token=token.strip())
            # load_model()
        except Exception:
            # Report any login failure to the user instead of crashing the
            # handler; unlike a bare ``except:`` this does not swallow
            # KeyboardInterrupt or SystemExit.
            return "Invalid token. Please try again."
        return "Authenticated successfully"

    # Components
    token_input = gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Enter your token here...")
    auth_button = gr.Button("Authenticate")
    output = gr.Textbox(label="Output")
    auth_button.click(fn=authenticate_token, inputs=token_input, outputs=output)
    # Place the chat UI (constructed outside this context) inside this layout.
    chat_interface.render()

    

if __name__ == "__main__":
    # Enable request queuing with at most 20 pending events, then start serving.
    _queued_app = demo.queue(max_size=20)
    _queued_app.launch()