import threading
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Model configuration
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
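
# Note: model.generate() already runs under torch.no_grad() internally,
# so no explicit gradient-disabling context is needed for inference below.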

# UI styling: dark theme with a monospace font
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');

* {
    font-family: 'JetBrains Mono', monospace !important;
}

html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}

textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""

# Global stop flag used to cancel an in-progress stream
stop_signal = False

def stop_stream():
    global stop_signal
    stop_signal = True
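# Note: setting the flag only breaks the UI-side token loop in generate_response;
# the background model.generate() thread keeps running until it finishes on its own.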

# Streaming generation
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False

    prompt = f"Question: {message}\nThinking: \nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )
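
    # temperature and top_p only take effect because do_sample=True;
    # greedy decoding would ignore both sampling parameters.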
    # Run generate() in a background thread so this generator can consume the streamer
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    full_text = ""
    for token in streamer:
        if stop_signal:
            break
        full_text += token
        yield full_text.strip()
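
# Caveat: stop_signal and the model are module-level globals, so concurrent
# sessions share them; this layout suits a single-user demo.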

# Gradio interface
with gr.Blocks(css=css) as app:
    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
    send_btn = gr.Button("Send")
    stop_btn = gr.Button("Stop")
    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
    state = gr.State([])  # visual-only copy of the chat history

    def update_chat(message, chat_history):
        chat_history = chat_history + [(message, None)]  # append only the new question
        return "", chat_history

    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]  # the message just submitted
        visual_history = chat_history[:-1]  # drop the pending entry while streaming
        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk
            # Yield the updated history to both the chatbot display and the stored state
            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]

    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])
    stop_btn.click(stop_stream, inputs=[], outputs=[])

# share=True also publishes the app via a temporary public gradio.live link
app.launch(share=True)