Moshe Ofer committed
Commit 8f0265c · 1 Parent(s): 59cd46d
Files changed (3)
  1. __pycache__/app.cpython-312.pyc +0 -0
  2. app.py +63 -37
  3. templates/index.html +10 -0
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -2,8 +2,7 @@ import eventlet
 eventlet.monkey_patch(socket=True, select=True, thread=True)
 
 import eventlet.wsgi
-
-from flask import Flask, render_template
+from flask import Flask, render_template, request
 from flask_socketio import SocketIO
 from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
 import torch
@@ -12,23 +11,44 @@ app = Flask(__name__)
 socketio = SocketIO(
     app,
     async_mode='eventlet',
-    message_queue=None,  # Explicitly set to None for single-worker setup
+    message_queue=None,
     ping_timeout=60,
     ping_interval=25,
     cors_allowed_origins="*",
     logger=True,
     engineio_logger=True,
-    async_handlers=True  # Enable async handlers
-)
-# Initialize model and tokenizer
-MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    torch_dtype="auto",
-    device_map="auto"
+    async_handlers=True
 )
 
+# Initialize models and tokenizers
+MODELS = {
+    "qwen": {
+        "name": "Qwen/Qwen2.5-0.5B-Instruct",
+        "tokenizer": None,
+        "model": None,
+        "uses_chat_template": True  # Qwen uses chat template
+    },
+    "gpt2": {
+        "name": "gpt2",
+        "tokenizer": None,
+        "model": None,
+        "uses_chat_template": False  # GPT2 doesn't use chat template
+    }
+}
+
+# Load models and tokenizers
+for model_key, model_info in MODELS.items():
+    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
+    model_info["model"] = AutoModelForCausalLM.from_pretrained(
+        model_info["name"],
+        torch_dtype="auto",
+        device_map="auto"
+    )
+    # Add pad token for GPT2 if it doesn't have one
+    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
+        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
+        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id
+
 
 class WebSocketBeamStreamer(MultiBeamTextStreamer):
     """Custom streamer that sends updates through websockets with adjustable speed"""
@@ -42,22 +62,19 @@ class WebSocketBeamStreamer(MultiBeamTextStreamer):
             on_beam_finished=self.on_beam_finished
         )
         self.beam_texts = {i: "" for i in range(num_beams)}
-        self.sleep_time = sleep_time  # Sleep time in milliseconds
+        self.sleep_time = sleep_time
 
     def on_beam_update(self, beam_idx: int, new_text: str):
-        """Send beam updates through websocket with delay"""
         self.beam_texts[beam_idx] = new_text
         if self.sleep_time > 0:
-            eventlet.sleep(self.sleep_time / 1000)  # Convert milliseconds to seconds
-        # Force immediate emit and flush
+            eventlet.sleep(self.sleep_time / 1000)
         socketio.emit('beam_update', {
             'beam_idx': beam_idx,
             'text': new_text
         }, namespace='/', callback=lambda: eventlet.sleep(0))
-        socketio.sleep(0)  # Force context switch
+        socketio.sleep(0)
 
     def on_beam_finished(self, final_text: str):
-        """Send completion notification through websocket"""
         socketio.emit('beam_finished', {
             'text': final_text
         })
@@ -70,31 +87,39 @@ def index():
 
 @socketio.on('generate')
 def handle_generation(data):
-    # Emit a generation start event
     socketio.emit('generation_started')
 
     prompt = data['prompt']
+    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
     num_beams = data.get('num_beams', 5)
     max_new_tokens = data.get('max_tokens', 512)
-    sleep_time = data.get('sleep_time', 0)  # Get sleep time from frontend
-
-    # Create messages format
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt}
-    ]
-
-    # Apply chat template
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    sleep_time = data.get('sleep_time', 0)
+
+    # Get the selected model info
+    model_info = MODELS[model_name]
+    model = model_info["model"]
+    tokenizer = model_info["tokenizer"]
+
+    # Prepare input text based on model type
+    if model_info["uses_chat_template"]:
+        # For Qwen, use chat template
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+    else:
+        # For GPT2, use direct prompt
+        text = prompt
 
     # Prepare inputs
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Initialize streamer with sleep time
+    # Initialize streamer
     streamer = WebSocketBeamStreamer(
         tokenizer=tokenizer,
         num_beams=num_beams,
@@ -113,10 +138,11 @@ def handle_generation(data):
             output_scores=True,
             return_dict_in_generate=True,
             early_stopping=True,
-            streamer=streamer
+            streamer=streamer,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
        )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
-        # Emit generation completed event
-        socketio.emit('generation_completed')
+        socketio.emit('generation_completed')
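
Note: the final hunk shows only the tail of the model.generate(...) call inside the handler's try block; the leading arguments fall outside this diff's context. For orientation, a minimal sketch of how the full call plausibly fits together. Every argument above output_scores is an assumption inferred from the handler's variables, not part of the commit:

    # Hypothetical reconstruction of the elided head of the generate() call.
    # num_beams, max_new_tokens, and model_inputs come from handle_generation;
    # num_return_sequences=num_beams is an assumption so every beam is streamed.
    outputs = model.generate(
        **model_inputs,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        num_return_sequences=num_beams,
        output_scores=True,
        return_dict_in_generate=True,
        early_stopping=True,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

Passing pad_token_id explicitly matters for the new GPT-2 path: the stock GPT-2 tokenizer ships without a pad token, which is exactly why the loader above falls back to eos_token.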
 
templates/index.html CHANGED
@@ -357,6 +357,14 @@
         <label for="max_tokens">Max tokens</label>
         <input type="number" id="max_tokens" value="512" min="1">
     </div>
+
+    <div class="input-group">
+        <label for="model_select">Model</label>
+        <select id="model_select" class="form-select">
+            <option value="gpt2">GPT-2</option>
+            <option value="qwen">Qwen</option>
+        </select>
+    </div>
 </div>
 
 <div class="slider-container">
@@ -517,6 +525,7 @@
     resetConnection();
 
     const prompt = document.getElementById('prompt').value;
+    const model = document.getElementById('model_select').value;
     const numBeams = parseInt(document.getElementById('num_beams').value);
     const maxTokens = parseInt(document.getElementById('max_tokens').value);
     const sleepTime = parseInt(document.getElementById('sleep_time').value);
@@ -527,6 +536,7 @@
 
     socket.emit('generate', {
         prompt: prompt,
+        model: model,
        num_beams: numBeams,
        max_tokens: maxTokens,
        sleep_time: sleepTime
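
Since the generate event now carries a model field, the change can be smoke-tested without the browser UI. Below is a minimal sketch using the python-socketio client; the URL and port are assumptions, since the run configuration is not part of this diff:

    # Hypothetical test client; assumes the app listens on localhost:5000.
    import socketio

    sio = socketio.Client()

    @sio.on('beam_update')
    def on_beam_update(data):
        # The server emits one event per beam update: {'beam_idx': ..., 'text': ...}
        print(f"beam {data['beam_idx']}: {data['text']!r}")

    @sio.on('generation_completed')
    def on_completed():
        sio.disconnect()

    sio.connect('http://localhost:5000')
    sio.emit('generate', {
        'prompt': 'Hello',
        'model': 'gpt2',   # or 'qwen'
        'num_beams': 3,
        'max_tokens': 32,
        'sleep_time': 0
    })
    sio.wait()  # block until the server finishes and we disconnect

Selecting 'gpt2' exercises the new non-chat-template path, so the prompt is fed to the model verbatim; 'qwen' goes through apply_chat_template as before.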