Moshe Ofer committed · Commit 8f0265c · 1 Parent(s): 59cd46d

GPT2

Files changed:
- __pycache__/app.cpython-312.pyc +0 -0
- app.py +63 -37
- templates/index.html +10 -0
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
app.py CHANGED

@@ -2,8 +2,7 @@ import eventlet
 eventlet.monkey_patch(socket=True, select=True, thread=True)
 
 import eventlet.wsgi
-
-from flask import Flask, render_template
+from flask import Flask, render_template, request
 from flask_socketio import SocketIO
 from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
 import torch
@@ -12,23 +11,44 @@ app = Flask(__name__)
 socketio = SocketIO(
     app,
     async_mode='eventlet',
-    message_queue=None,
+    message_queue=None,
     ping_timeout=60,
     ping_interval=25,
     cors_allowed_origins="*",
     logger=True,
     engineio_logger=True,
-    async_handlers=True
-)
-# Initialize model and tokenizer
-MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    torch_dtype="auto",
-    device_map="auto"
+    async_handlers=True
 )
 
+# Initialize models and tokenizers
+MODELS = {
+    "qwen": {
+        "name": "Qwen/Qwen2.5-0.5B-Instruct",
+        "tokenizer": None,
+        "model": None,
+        "uses_chat_template": True  # Qwen uses chat template
+    },
+    "gpt2": {
+        "name": "gpt2",
+        "tokenizer": None,
+        "model": None,
+        "uses_chat_template": False  # GPT2 doesn't use chat template
+    }
+}
+
+# Load models and tokenizers
+for model_key, model_info in MODELS.items():
+    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
+    model_info["model"] = AutoModelForCausalLM.from_pretrained(
+        model_info["name"],
+        torch_dtype="auto",
+        device_map="auto"
+    )
+    # Add pad token for GPT2 if it doesn't have one
+    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
+        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
+        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id
+
 
 class WebSocketBeamStreamer(MultiBeamTextStreamer):
     """Custom streamer that sends updates through websockets with adjustable speed"""
@@ -42,22 +62,19 @@ class WebSocketBeamStreamer(MultiBeamTextStreamer):
             on_beam_finished=self.on_beam_finished
         )
         self.beam_texts = {i: "" for i in range(num_beams)}
-        self.sleep_time = sleep_time
+        self.sleep_time = sleep_time
 
     def on_beam_update(self, beam_idx: int, new_text: str):
-        """Send beam updates through websocket with delay"""
         self.beam_texts[beam_idx] = new_text
         if self.sleep_time > 0:
-            eventlet.sleep(self.sleep_time / 1000)
-        # Force immediate emit and flush
+            eventlet.sleep(self.sleep_time / 1000)
         socketio.emit('beam_update', {
             'beam_idx': beam_idx,
             'text': new_text
         }, namespace='/', callback=lambda: eventlet.sleep(0))
-        socketio.sleep(0)
+        socketio.sleep(0)
 
     def on_beam_finished(self, final_text: str):
-        """Send completion notification through websocket"""
         socketio.emit('beam_finished', {
             'text': final_text
         })
@@ -70,31 +87,39 @@ def index():
 
 @socketio.on('generate')
 def handle_generation(data):
-    # Emit a generation start event
     socketio.emit('generation_started')
 
     prompt = data['prompt']
+    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
     max_new_tokens = data.get('max_tokens', 512)
-    sleep_time = data.get('sleep_time', 0)
-    [old lines 80-92: previous prompt/chat-template setup; not recoverable from this render]
+    sleep_time = data.get('sleep_time', 0)
+
+    # Get the selected model info
+    model_info = MODELS[model_name]
+    model = model_info["model"]
+    tokenizer = model_info["tokenizer"]
+
+    # Prepare input text based on model type
+    if model_info["uses_chat_template"]:
+        # For Qwen, use chat template
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+    else:
+        # For GPT2, use direct prompt
+        text = prompt
 
     # Prepare inputs
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Initialize streamer
+    # Initialize streamer
     streamer = WebSocketBeamStreamer(
         tokenizer=tokenizer,
         num_beams=num_beams,
@@ -113,10 +138,11 @@ def handle_generation(data):
             output_scores=True,
             return_dict_in_generate=True,
             early_stopping=True,
-            streamer=streamer
+            streamer=streamer,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
         )
     except Exception as e:
         socketio.emit('generation_error', {'error': str(e)})
     finally:
-
-        socketio.emit('generation_completed')
+        socketio.emit('generation_completed')
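For a quick end-to-end check of the updated handler, a minimal client sketch follows. It is not part of this commit: it assumes the python-socketio client package and a locally running instance (the URL/port and the sample prompt are placeholders). It sends the new model field alongside the existing parameters and prints the streamed beam updates.

import socketio  # assumption: pip install "python-socketio[client]"

SERVER_URL = "http://localhost:7860"  # assumption: adjust to wherever the app is served

sio = socketio.Client()

@sio.on('beam_update')
def on_beam_update(data):
    # Server emits {'beam_idx': int, 'text': str} on every beam refresh
    print(f"beam {data['beam_idx']}: ...{data['text'][-60:]}")

@sio.on('beam_finished')
def on_beam_finished(data):
    # Server emits {'text': str} when a beam completes
    print("finished:", data['text'][:80])

@sio.on('generation_completed')
def on_generation_completed(*_):
    # Emitted from the handler's finally block; stop listening
    sio.disconnect()

sio.connect(SERVER_URL)
sio.emit('generate', {
    'prompt': 'Write a short note about beam search.',  # placeholder prompt
    'model': 'gpt2',      # new field from this commit: 'gpt2' or 'qwen'
    'num_beams': 3,
    'max_tokens': 64,
    'sleep_time': 0,
})
sio.wait()  # returns once the handler above disconnects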
templates/index.html CHANGED

@@ -357,6 +357,14 @@
                 <label for="max_tokens">Max tokens</label>
                 <input type="number" id="max_tokens" value="512" min="1">
             </div>
+
+            <div class="input-group">
+                <label for="model_select">Model</label>
+                <select id="model_select" class="form-select">
+                    <option value="gpt2">GPT-2</option>
+                    <option value="qwen">Qwen</option>
+                </select>
+            </div>
         </div>
 
         <div class="slider-container">
@@ -517,6 +525,7 @@
             resetConnection();
 
             const prompt = document.getElementById('prompt').value;
+            const model = document.getElementById('model_select').value;
             const numBeams = parseInt(document.getElementById('num_beams').value);
             const maxTokens = parseInt(document.getElementById('max_tokens').value);
             const sleepTime = parseInt(document.getElementById('sleep_time').value);
@@ -527,6 +536,7 @@
 
             socket.emit('generate', {
                 prompt: prompt,
+                model: model,
                 num_beams: numBeams,
                 max_tokens: maxTokens,
                 sleep_time: sleepTime
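As background for the new model dropdown: a small illustration (again not part of the commit, using only the transformers tokenizers the app already downloads, with a placeholder prompt) of what the two choices mean for the prompt. Qwen prompts are wrapped by apply_chat_template before generation, while GPT-2 receives the raw text; GPT-2 also ships without a pad token, which is why app.py now reuses its eos token for padding and passes pad_token_id/eos_token_id to generate().

from transformers import AutoTokenizer

prompt = "Write a short note about beam search."  # placeholder prompt

# Qwen path: the prompt is embedded in the model's chat template (ChatML role markers).
qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
print(qwen_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

# GPT-2 path: the prompt is used verbatim, and the tokenizer has no pad token by default.
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
print(gpt2_tok.pad_token)  # None until app.py assigns the eos token as pad token
print(gpt2_tok.eos_token)  # '<|endoftext|>'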