mariusjabami committed
Commit 8751f54 (verified)
Parent: e5039e0

Update app.py

Files changed (1):
  1. app.py +47 -84
app.py CHANGED
@@ -1,21 +1,21 @@
-import os
-import time
 import threading
+import time
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
 
-# Load the local model
+# Model configuration
 model_id = "lambdaindie/lambda-1v-1B"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 )
-model.to("cuda" if torch.cuda.is_available() else "cpu")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
 model.eval()
 
-# Style
+# Visual CSS
 css = """
 @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');
 * {
@@ -30,115 +30,78 @@ textarea, input, button, select {
     color: #e0e0e0 !important;
     border: 1px solid #444 !important;
 }
-.markdown-think {
-    background-color: #1e1e1e;
-    border-left: 4px solid #555;
-    padding: 10px;
-    margin-bottom: 8px;
-    font-style: italic;
-    white-space: pre-wrap;
-    animation: pulse 1.5s infinite ease-in-out;
-}
-@keyframes pulse {
-    0% { opacity: 0.6; }
-    50% { opacity: 1.0; }
-    100% { opacity: 0.6; }
-}
 """
 
-theme = gr.themes.Base(
-    primary_hue="gray",
-    font=[gr.themes.GoogleFont("JetBrains Mono"), "monospace"]
-).set(
-    body_background_fill="#111",
-    body_text_color="#e0e0e0",
-    button_primary_background_fill="#333",
-    button_primary_text_color="#e0e0e0",
-    input_background_fill="#222",
-    input_border_color="#444",
-    block_title_text_color="#fff"
-)
-
-# Stop flag
+# Global stop control
 stop_signal = False
 
 def stop_stream():
     global stop_signal
     stop_signal = True
 
-def respond(history, system_message, max_tokens, temperature, top_p):
+# Streaming generation
+def generate_response(message, max_tokens, temperature, top_p):
     global stop_signal
     stop_signal = False
 
-    # Build the prompt
-    prompt = ""
-    if system_message:
-        prompt += system_message + "\n\n"
-
-    for msg in history:
-        role = msg["role"]
-        content = msg["content"]
-        if role == "user":
-            prompt += f"User: {content}\n"
-        elif role == "assistant":
-            prompt += f"Assistant: {content}\n"
-
-    prompt += "Assistant:"
+    prompt = f"Question: {message}\nThinking:"
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
     generation_kwargs = dict(
-        **inputs,
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
         streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
+        eos_token_id=tokenizer.eos_token_id
    )
 
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
 
-    output = ""
-    start = time.time()
-
+    full_text = ""
    for token in streamer:
        if stop_signal:
            break
-        output += token
-        yield history + [{"role": "assistant", "content": output}]
-
-    end = time.time()
-    yield history + [
-        {"role": "assistant", "content": output},
-        {"role": "system", "content": f"Thought for {end - start:.1f} seconds"}
-    ]
-
+        full_text += token
+        yield full_text.strip()
+
+    if stop_signal:
+        return
+
-# Interface
-with gr.Blocks(css=css, theme=theme) as app:
-    chatbot = gr.Chatbot(label="", type="messages")
-    state = gr.State([])
-
-    with gr.Row():
-        msg = gr.Textbox(label="Message")
-        send_btn = gr.Button("Send")
-        stop_btn = gr.Button("Stop")
-
-    with gr.Accordion("Advanced Settings", open=False):
-        system_message = gr.Textbox(label="System Message", value="")
-        max_tokens = gr.Slider(64, 2048, value=256, step=1, label="Max Tokens")
-        temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
-        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
-
-    def handle_user_msg(user_msg, chat_history):
-        if user_msg:
-            chat_history = chat_history + [{"role": "user", "content": user_msg}]
+# Gradio interface
+with gr.Blocks(css=css) as app:
+    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
+    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
+    send_btn = gr.Button("Send")
+    stop_btn = gr.Button("Stop")
+
+    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
+    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
+    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
+
+    state = gr.State([])  # visual history only
+
+    def update_chat(message, chat_history):
+        chat_history = chat_history + [(message, None)]  # append just the question
        return "", chat_history
 
-    send_btn.click(fn=handle_user_msg, inputs=[msg, state], outputs=[msg, state])\
-        .then(fn=respond, inputs=[state, system_message, max_tokens, temperature, top_p], outputs=[chatbot, state])
+    def generate_full(chat_history, max_tokens, temperature, top_p):
+        message = chat_history[-1][0]  # last message sent
+        visual_history = chat_history[:-1]  # temporarily drop the pending entry
+
+        full_response = ""
+        for chunk in generate_response(message, max_tokens, temperature, top_p):
+            full_response = chunk
+            yield visual_history + [(message, full_response)], visual_history + [(message, full_response)]
+
+    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
+        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])
 
-    stop_btn.click(fn=stop_stream, inputs=[], outputs=[])
+    stop_btn.click(stop_stream, inputs=[], outputs=[])
 
 app.launch(share=True)
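
For anyone sanity-checking the new streaming path outside the Gradio UI, a minimal standalone sketch of the threaded model.generate + TextIteratorStreamer pattern that generate_response relies on; the prompt text and generation settings below are illustrative assumptions, not part of the commit:

# Standalone sketch: stream tokens from model.generate without the Gradio UI.
import threading

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_id = "lambdaindie/lambda-1v-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

prompt = "Question: What is 2 + 2?\nThinking:"  # same template as generate_response
inputs = tokenizer(prompt, return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks until it finishes, so it runs on a worker thread
# while the main thread drains the streamer as tokens arrive.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=64,  # illustrative value
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    ),
)
thread.start()

for piece in streamer:  # yields decoded text fragments as they are generated
    print(piece, end="", flush=True)
thread.join()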