import time
import threading
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
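
# Minimal Gradio chat UI for lambdaindie/lambda-1v-1B: generation runs in a
# worker thread and tokens are streamed into the chatbot as they arrive.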
# Load the local model
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
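
# Without a GPU the model falls back to float32 on CPU; token streaming keeps
# the UI responsive even when generation is slow.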
# Styling
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');

* {
    font-family: 'JetBrains Mono', monospace !important;
}

html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}

textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}

.markdown-think {
    background-color: #1e1e1e;
    border-left: 4px solid #555;
    padding: 10px;
    margin-bottom: 8px;
    font-style: italic;
    white-space: pre-wrap;
    animation: pulse 1.5s infinite ease-in-out;
}

@keyframes pulse {
    0% { opacity: 0.6; }
    50% { opacity: 1.0; }
    100% { opacity: 0.6; }
}
"""
theme = gr.themes.Base(
    primary_hue="gray",
    font=[gr.themes.GoogleFont("JetBrains Mono"), "monospace"],
).set(
    body_background_fill="#111",
    body_text_color="#e0e0e0",
    button_primary_background_fill="#333",
    button_primary_text_color="#e0e0e0",
    input_background_fill="#222",
    input_border_color="#444",
    block_title_text_color="#fff",
)
# Stop flag, polled by the streaming loop between tokens
stop_signal = False

def stop_stream():
    global stop_signal
    stop_signal = True
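
# Note: the flag is module-level state, so in a multi-user Space one user's
# Stop click would interrupt every in-flight generation.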
def respond(history, system_message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False

    # Build a plain-text prompt from the system message and chat history
    prompt = ""
    if system_message:
        prompt += system_message + "\n\n"
    for msg in history:
        role = msg["role"]
        content = msg["content"]
        if role == "user":
            prompt += f"User: {content}\n"
        elif role == "assistant":
            prompt += f"Assistant: {content}\n"
    prompt += "Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # sliders may deliver floats
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # generate() blocks, so it runs in a worker thread while this function streams
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
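
    # TextIteratorStreamer is an iterator fed by generate() from the worker
    # thread; iterating it here yields decoded text chunks as they are produced.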
output = ""
start = time.time()
for token in streamer:
if stop_signal:
break
output += token
yield history + [{"role": "assistant", "content": output}]
end = time.time()
yield history + [
{"role": "assistant", "content": output},
{"role": "system", "content": f"Pensou por {end - start:.1f} segundos"}
]
# Interface
with gr.Blocks(css=css, theme=theme) as app:
    chatbot = gr.Chatbot(label="λ", type="messages")
    state = gr.State([])

    with gr.Row():
        msg = gr.Textbox(label="Message")
        send_btn = gr.Button("Send")
        stop_btn = gr.Button("Stop")

    with gr.Accordion("Advanced Settings", open=False):
        system_message = gr.Textbox(label="System Message", value="")
        max_tokens = gr.Slider(64, 2048, value=256, step=1, label="Max Tokens")
        temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
    def handle_user_msg(user_msg, chat_history):
        if user_msg:
            chat_history = chat_history + [{"role": "user", "content": user_msg}]
        # Clear the textbox and echo the user message to the chat right away
        return "", chat_history, chat_history

    send_btn.click(fn=handle_user_msg, inputs=[msg, state], outputs=[msg, state, chatbot])\
        .then(fn=respond, inputs=[state, system_message, max_tokens, temperature, top_p],
              outputs=[chatbot, state])
    stop_btn.click(fn=stop_stream, inputs=[], outputs=[])
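
# Note: share=True has no effect on Hugging Face Spaces (a public URL already
# exists there); it only matters when running this script locally.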
app.launch(share=True)