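"""Gradio demo Space for THUDM/LongWriter-glm4-9b.

Streams long-form completions (10,000+ words) from the checkpoint's
remote-code `stream_chat` API into a chat UI with tunable sampling parameters.
"""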
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

MODEL = "THUDM/LongWriter-glm4-9b"

TITLE = "<h1><center>LongWriter-glm4-9b</center></h1>"

PLACEHOLDER = """
<center>
<p>Hi! I'm LongWriter-glm4-9b, capable of generating 10,000+ words. How can I assist you today?</p>
</center>
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

# `device_map="auto"` lets Accelerate place the weights, so no manual
# device handling is needed here.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model = model.eval()
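
# For reference, the LongWriter model card documents a blocking,
# non-streaming call of roughly this shape via the same remote code:
#   response, history = model.chat(tokenizer, query, history=[],
#                                  max_new_tokens=32768, temperature=0.5)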

def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.5,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    print(f'message: {message}')
    print(f'history: {history}')

    # Rebuild the (prompt, answer) tuple history expected by the model's
    # remote-code stream_chat.
    chat_history = [(prompt, answer) for prompt, answer in history]

    # Note: `system_prompt` is collected from the UI but not forwarded below;
    # the tuple-style history used by this stream_chat has no system slot.

    # Stream partial responses back to the UI as they are generated.
    for response, _ in model.stream_chat(
        tokenizer,
        message,
        chat_history,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    ):
        yield response

chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful assistant capable of generating long-form content.",
                label="System Prompt",
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.5,
                label="Temperature",
            ),
            gr.Slider(
                minimum=1024,
                maximum=32768,
                step=1024,
                value=32768,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="Top p",
            ),
            gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=50,
                label="Top k",
            ),
        ],
        examples=[
            ["Write a 10000-word comprehensive guide on artificial intelligence and its applications."],
            ["Create a detailed 5000-word business plan for a space tourism company."],
            ["Compose a 3000-word short story about time travel and its consequences."],
            ["Develop a 7000-word research proposal on the potential of quantum computing in cryptography."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
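
# A minimal sketch of driving the streaming API directly, without the UI,
# assuming the same `stream_chat` signature used in stream_chat() above:
#
#   for partial, _ in model.stream_chat(
#       tokenizer,
#       "Write a 1000-word essay on renewable energy.",
#       [],  # empty history
#       max_new_tokens=4096,
#       top_p=1.0,
#       top_k=50,
#       temperature=0.5,
#   ):
#       print(partial)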