File size: 3,408 Bytes
4f7e40d
8751f54
a474012
665b7ce
a474012
9faf370
8751f54
9917b41
a474012
9917b41
 
 
 
8751f54
 
a474012
665b7ce
8751f54
665b7ce
 
a474012
 
 
665b7ce
 
 
 
 
 
 
 
 
 
 
8751f54
a474012
 
 
 
 
 
8751f54
 
a474012
 
 
8dfb8e4
8751f54
a474012
 
9917b41
a474012
8751f54
 
a474012
 
 
 
 
8751f54
a474012
 
 
 
 
8751f54
a474012
 
 
8751f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a474012
665b7ce
8751f54
 
 
 
 
 
 
 
 
 
 
665b7ce
8751f54
665b7ce
e5039e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import threading
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Model configuration: load tokenizer and causal LM once at import time.
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # Half precision saves GPU memory; CPU needs full fp32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference mode: disables dropout / training-only behavior

# Visual theme injected into Gradio: dark background, JetBrains Mono everywhere.
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
* {
    font-family: 'JetBrains Mono', monospace !important;
}
html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}
textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""

# Global stop flag shared with generate_response(); flipped by the Stop button.
stop_signal = False

def stop_stream():
    """Request that the current streaming generation stop yielding."""
    global stop_signal
    stop_signal = True

# Streaming generation: run model.generate on a worker thread, consume tokens here.
def generate_response(message, max_tokens, temperature, top_p):
    """Stream a model response for *message*.

    Yields the accumulated (stripped) response text after each new token.
    Sampling uses the given temperature/top_p; the stream ends early when
    the global ``stop_signal`` flag is set by ``stop_stream``.
    """
    global stop_signal
    stop_signal = False

    prompt = f"Question: {message}\nThinking: \nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt keeps the echoed prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id
    )

    # generate() blocks, so it runs on a worker thread while we iterate the streamer.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_text = ""
    for token in streamer:
        if stop_signal:
            # NOTE(review): the worker thread keeps generating in the background;
            # we only stop consuming. A StoppingCriteria would cancel it for real.
            break
        full_text += token
        yield full_text.strip()

    # Fix: the original trailing `if stop_signal: return` was dead code (the
    # generator ends here either way). When the streamer is exhausted the
    # worker has finished, so join it to release the thread cleanly.
    if not stop_signal:
        thread.join()

# Gradio UI: chat window, message box, send/stop buttons, sampling sliders.
with gr.Blocks(css=css) as app:
    # Fix: label was "位" — GBK misdecoding of the UTF-8 lambda character.
    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
    msg = gr.Textbox(label="Mensagem", placeholder="Digite aqui...", lines=2)
    send_btn = gr.Button("Enviar")
    stop_btn = gr.Button("Parar")

    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    state = gr.State([])  # visual-only history: list of (user, bot) pairs

    def update_chat(message, chat_history):
        """Append the pending user message (answer=None) and clear the textbox."""
        message = message.strip()
        if not message:
            # Fix: ignore empty/whitespace submissions instead of sending a
            # blank prompt to the model.
            return "", chat_history
        return "", chat_history + [(message, None)]

    def generate_full(chat_history, max_tokens, temperature, top_p):
        """Stream the model's answer for the last pending history entry."""
        if not chat_history or chat_history[-1][1] is not None:
            # Fix: the original indexed chat_history[-1] unconditionally and
            # raised IndexError on an empty history (e.g. empty submission).
            yield chat_history, chat_history
            return

        message = chat_history[-1][0]           # last message sent
        visual_history = chat_history[:-1]      # drop the pending entry while streaming

        for chunk in generate_response(message, max_tokens, temperature, top_p):
            updated = visual_history + [(message, chunk)]
            # Same payload for both outputs: the visible chatbot and the state.
            yield updated, updated

    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])

    stop_btn.click(stop_stream, inputs=[], outputs=[])

app.launch(share=True)