# Hugging Face Space: Gradio chat demo for the lambdaindie/lambda-1v-1B model.
import os
import threading
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Load the local model: fp16 on GPU (halves memory), fp32 on CPU
# (CPU kernels generally need full precision).
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()  # inference only — disables dropout/training behavior
# Global stylesheet: dark background, JetBrains Mono everywhere, plus a
# pulsing "thinking" box style.  The original paste had ' | |' table
# residue at the end of every line, which corrupted the CSS text.
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
* {
    font-family: 'JetBrains Mono', monospace !important;
}
html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}
textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
.markdown-think {
    background-color: #1e1e1e;
    border-left: 4px solid #555;
    padding: 10px;
    margin-bottom: 8px;
    font-style: italic;
    white-space: pre-wrap;
    animation: pulse 1.5s infinite ease-in-out;
}
@keyframes pulse {
    0% { opacity: 0.6; }
    50% { opacity: 1.0; }
    100% { opacity: 0.6; }
}
"""
# Gradio theme matching the CSS above: grayscale palette, monospace font,
# dark fills.  (Residue ' | |' stripped from every line.)
theme = gr.themes.Base(
    primary_hue="gray",
    font=[gr.themes.GoogleFont("JetBrains Mono"), "monospace"],
).set(
    body_background_fill="#111",
    body_text_color="#e0e0e0",
    button_primary_background_fill="#333",
    button_primary_text_color="#e0e0e0",
    input_background_fill="#222",
    input_border_color="#444",
    block_title_text_color="#fff",
)
# Cooperative stop flag: respond() polls it between streamed tokens.
stop_signal = False

def stop_stream():
    """Request that the current streaming generation stop at the next token."""
    global stop_signal
    stop_signal = True
def respond(history, system_message, max_tokens, temperature, top_p):
    """Stream an assistant reply for the given chat history.

    Builds a plain-text "User:/Assistant:" prompt from the optional system
    message and the message-dict history, runs ``model.generate`` in a
    worker thread, and yields the growing history (messages format) so the
    Gradio Chatbot can live-update.
    """
    global stop_signal
    stop_signal = False

    # Assemble the prompt: optional system preamble, then each dialogue turn.
    prompt = ""
    if system_message:
        prompt += system_message + "\n\n"
    for msg in history:
        role = msg["role"]
        content = msg["content"]
        if role == "user":
            prompt += f"User: {content}\n"
        elif role == "assistant":
            prompt += f"Assistant: {content}\n"
    prompt += "Assistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # generate() blocks, so run it in a thread and consume the streamer here.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    output = ""
    start = time.time()
    for token in streamer:
        if stop_signal:
            # NOTE(review): this only stops the UI stream; the generate()
            # thread keeps running until it finishes on its own.
            break
        output += token
        yield history + [{"role": "assistant", "content": output}]
    end = time.time()

    # Final yield with the timing note.  A messages-type gr.Chatbot only
    # accepts "user"/"assistant" roles, so the note must be an assistant
    # message — the original "system" role would be rejected by Gradio.
    yield history + [
        {"role": "assistant", "content": output},
        {"role": "assistant", "content": f"Pensou por {end - start:.1f} segundos"},
    ]
# UI: chat window, message row, advanced-settings accordion, event wiring.
with gr.Blocks(css=css, theme=theme) as app:
    chatbot = gr.Chatbot(label="λ", type="messages")
    state = gr.State([])  # chat history as a list of {"role", "content"} dicts

    with gr.Row():
        msg = gr.Textbox(label="Mensagem")
        send_btn = gr.Button("Enviar")
        stop_btn = gr.Button("Parar")

    with gr.Accordion("Configurações Avançadas", open=False):
        system_message = gr.Textbox(label="System Message", value="")
        max_tokens = gr.Slider(64, 2048, value=256, step=1, label="Max Tokens")
        temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    def handle_user_msg(user_msg, chat_history):
        """Append the user's message to the history and clear the textbox."""
        if user_msg:
            chat_history = chat_history + [{"role": "user", "content": user_msg}]
        return "", chat_history

    # respond is a generator yielding ONE value (the updated message list),
    # so it must be wired to exactly one output component.  The original
    # outputs=[chatbot, state] expected two values per yield and would fail.
    send_btn.click(fn=handle_user_msg, inputs=[msg, state], outputs=[msg, state])\
        .then(fn=respond, inputs=[state, system_message, max_tokens, temperature, top_p], outputs=[chatbot])

    stop_btn.click(fn=stop_stream, inputs=[], outputs=[])

app.launch(share=True)  # fixed: the original line was missing the closing parenthesis