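# app.py: Flask-SocketIO demo that streams beam-search hypotheses to the
# browser as they are generated. eventlet.monkey_patch() must run before
# any other network-touching imports so blocking I/O cooperates with
# eventlet's green threads.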
import eventlet
eventlet.monkey_patch(socket=True, select=True, thread=True)
import eventlet.wsgi
from flask import Flask, render_template, request
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch
app = Flask(__name__)
socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,
    ping_timeout=60,
    ping_interval=25,
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True,
    async_handlers=True
)
# Initialize models and tokenizers
MODELS = {
"qwen": {
"name": "Qwen/Qwen2.5-0.5B-Instruct",
"tokenizer": None,
"model": None,
"uses_chat_template": True # Qwen uses chat template
},
"gpt2": {
"name": "gpt2",
"tokenizer": None,
"model": None,
"uses_chat_template": False # GPT2 doesn't use chat template
}
}
# Load models and tokenizers
for model_key, model_info in MODELS.items():
model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
model_info["model"] = AutoModelForCausalLM.from_pretrained(
model_info["name"],
torch_dtype="auto",
device_map="auto"
)
# Add pad token for GPT2 if it doesn't have one
if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id
class WebSocketBeamStreamer(MultiBeamTextStreamer):
"""Custom streamer that sends updates through websockets with adjustable speed"""
def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
super().__init__(
tokenizer,
num_beams=num_beams,
skip_prompt=skip_prompt,
on_beam_update=self.on_beam_update,
on_beam_finished=self.on_beam_finished
)
self.beam_texts = {i: "" for i in range(num_beams)}
self.sleep_time = sleep_time
def on_beam_update(self, beam_idx: int, new_text: str):
self.beam_texts[beam_idx] = new_text
if self.sleep_time > 0:
eventlet.sleep(self.sleep_time / 1000)
socketio.emit('beam_update', {
'beam_idx': beam_idx,
'text': new_text
}, namespace='/', callback=lambda: eventlet.sleep(0))
socketio.sleep(0)
def on_beam_finished(self, final_text: str):
socketio.emit('beam_finished', {
'text': final_text
})
@app.route('/')
def index():
    return render_template('index.html')
@socketio.on('generate')
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, use chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT2, use direct prompt
        text = prompt

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize streamer
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
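
# Assumed entry point (the snippet above ends without one); host and port
# are illustrative defaults, 7860 being the usual Hugging Face Spaces port.
if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=7860)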