import threading

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Model configuration
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # float16 on GPU, float32 on CPU (float16 ops are poorly supported on CPU)
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
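# Note: on a CPU-only Space the model loads in float32, which roughly doubles
# RAM use versus float16. If memory is tight, passing low_cpu_mem_usage=True
# to from_pretrained() can reduce peak usage while loading (an optional tweak,
# not part of the original setup).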
# Custom CSS for the interface
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');

* {
    font-family: 'JetBrains Mono', monospace !important;
}

html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}

textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""
# Global stop flag, set by the Stop button and checked by the streaming loop
stop_signal = False

def stop_stream():
    global stop_signal
    stop_signal = True
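# A module-level bool is enough here (simple assignment is atomic under
# CPython's GIL), but threading.Event is the more idiomatic primitive for
# signalling across threads. A minimal alternative sketch:
#
#   stop_event = threading.Event()
#
#   def stop_stream():
#       stop_event.set()
#
#   # ...and in the streaming loop: if stop_event.is_set(): break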
# Streaming generation
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False

    prompt = f"Question: {message}\nThinking: \nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # The streamer yields decoded text chunks as generate() produces them
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generate() in a background thread so this function can consume
    # tokens from the streamer as they arrive
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_text = ""
    for token in streamer:
        if stop_signal:
            break
        full_text += token
        yield full_text.strip()
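# Note: the stop button only breaks the UI loop above; model.generate() keeps
# running in its background thread until it finishes on its own. A sketch of a
# harder stop using transformers' StoppingCriteria API (illustration only, not
# wired into generation_kwargs above):
from transformers import StoppingCriteria, StoppingCriteriaList

class StopSignalCriteria(StoppingCriteria):
    """Makes generate() itself halt as soon as stop_signal is set."""
    def __call__(self, input_ids, scores, **kwargs):
        return stop_signal

# To enable it, add this entry to generation_kwargs:
#   stopping_criteria=StoppingCriteriaList([StopSignalCriteria()])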
# Gradio interface
with gr.Blocks(css=css) as app:
    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
    send_btn = gr.Button("Send")
    stop_btn = gr.Button("Stop")

    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    state = gr.State([])  # chat history (visual only)

    def update_chat(message, chat_history):
        # Append only the question; the answer is streamed in afterwards
        chat_history = chat_history + [(message, None)]
        return "", chat_history

    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]       # most recent user message
        visual_history = chat_history[:-1]  # drop the pending entry while streaming
        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk
            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]

    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])

    stop_btn.click(stop_stream, inputs=[], outputs=[])

app.launch(share=True)
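# Streamed (generator) outputs go through Gradio's queue. On Gradio 3.x the
# queue must be enabled explicitly, so if tokens only appear once generation
# has finished, launch with:
#   app.queue().launch(share=True)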