import threading
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
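
# Streaming chat demo for lambdaindie/lambda-1v-1B: a Gradio Blocks UI that
# streams tokens from the model through a TextIteratorStreamer on a
# background thread.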
# Model configuration
model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
# Visual styling (CSS)
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
* {
    font-family: 'JetBrains Mono', monospace !important;
}
html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}
textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""
# Global stop control
stop_signal = False
def stop_stream():
    global stop_signal
    stop_signal = True
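
# The flag above only halts the UI streaming loop in generate_response below;
# model.generate keeps running in its thread until EOS or max_new_tokens.
# A minimal sketch of a harder stop, assuming transformers' StoppingCriteria
# API (defined here but not wired into the generation call):
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnFlag(StoppingCriteria):
    """Tell generate() to halt as soon as the global stop flag is set."""
    def __call__(self, input_ids, scores, **kwargs):
        return stop_signal

# To use it, add stopping_criteria=StoppingCriteriaList([StopOnFlag()])
# to generation_kwargs in generate_response.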
# Streaming generation
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False
    prompt = f"Question: {message}\nThinking: \nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id
    )
    # Run generate() on a background thread; the streamer then yields decoded
    # tokens here as they are produced.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    full_text = ""
    for token in streamer:
        if stop_signal:
            break
        full_text += token
        yield full_text.strip()
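
# Minimal usage sketch (outside Gradio; the question string is only an
# example): each yielded value is the full accumulated answer so far, so the
# last one is the complete response.
#
#   for partial in generate_response("What is entropy?", 128, 0.7, 0.95):
#       print(partial)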
# Gradio interface
with gr.Blocks(css=css) as app:
    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
    send_btn = gr.Button("Send")
    stop_btn = gr.Button("Stop")
    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
    state = gr.State([])  # chat history, for display only
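    # `state` mirrors the chatbot contents so generate_full can read the
    # pending message; both outputs are updated together on every streamed yield.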

    def update_chat(message, chat_history):
        chat_history = chat_history + [(message, None)]  # append only the question
        return "", chat_history

    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]  # most recent user message
        visual_history = chat_history[:-1]  # drop the pending entry while streaming
        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk  # each chunk is the full answer so far
            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]

    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])
    stop_btn.click(stop_stream, inputs=[], outputs=[])
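
# share=True publishes a temporary public gradio.live URL in addition to the
# local server.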
app.launch(share=True)