File size: 4,592 Bytes
1d58561
07e5e01
da278a5
 
8f0265c
a1b31ed
 
 
 
 
da278a5
 
 
8f0265c
07e5e01
 
da278a5
 
07e5e01
8f0265c
a1b31ed
 
8f0265c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1b31ed
 
 
 
 
 
 
 
 
 
 
 
 
8f0265c
a1b31ed
 
 
 
8f0265c
a1b31ed
 
 
07e5e01
8f0265c
a1b31ed
 
 
 
 
 
 
 
 
 
 
 
 
 
1d58561
 
 
8f0265c
1d58561
 
8f0265c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d58561
 
 
 
8f0265c
1d58561
 
 
 
 
 
 
 
 
 
 
 
a1b31ed
1d58561
 
 
 
 
8f0265c
 
 
a1b31ed
1d58561
 
 
8f0265c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import eventlet
eventlet.monkey_patch(socket=True, select=True, thread=True)

import eventlet.wsgi
from flask import Flask, render_template, request
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch

app = Flask(__name__)
# Socket.IO server bound to the Flask app. async_mode must match the
# eventlet.monkey_patch() call performed at the very top of the file.
socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,           # single-process deployment, no external broker
    ping_timeout=60,              # seconds before a silent client is dropped
    ping_interval=25,             # seconds between keep-alive pings
    cors_allowed_origins="*",     # NOTE(review): wide open — tighten for production
    logger=True,                  # verbose Socket.IO-level logging
    engineio_logger=True,         # verbose Engine.IO transport logging
    async_handlers=True           # run each event handler in its own green thread
)

# Registry of selectable models. "tokenizer"/"model" start as None and are
# populated by the loading loop below; "uses_chat_template" tells the
# generation handler whether to wrap the prompt with a chat template.
MODELS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-0.5B-Instruct",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": True  # Qwen expects chat-formatted prompts
    },
    "gpt2": {
        "name": "gpt2",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": False  # GPT-2 is a plain causal LM, raw prompt
    }
}

# Eagerly load every configured tokenizer/model pair at import time so the
# first websocket request does not pay the download/initialization cost.
for key, info in MODELS.items():
    tokenizer = AutoTokenizer.from_pretrained(info["name"])
    info["tokenizer"] = tokenizer
    info["model"] = AutoModelForCausalLM.from_pretrained(
        info["name"],
        torch_dtype="auto",
        device_map="auto"
    )
    # GPT-2 ships without a pad token; reuse EOS so padded generation works.
    if key == "gpt2" and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        info["model"].config.pad_token_id = info["model"].config.eos_token_id


class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Stream beam-search hypotheses to all connected clients over Socket.IO.

    Each beam update is broadcast as a ``beam_update`` event; a completed
    beam is announced with ``beam_finished``.  ``sleep_time`` (milliseconds)
    throttles updates so the client-side animation stays visible.
    """

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        # Latest text seen per beam index, kept for inspection/debugging.
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time  # per-update delay in milliseconds

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record the new text for one beam and broadcast it to clients."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)  # ms -> s
        # BUG FIX: the original passed ``callback=lambda: eventlet.sleep(0)``
        # here, but Flask-SocketIO raises ValueError when a callback is
        # supplied on a broadcast emit (no single client can acknowledge it).
        # The explicit socketio.sleep(0) below already yields to the eventlet
        # hub so the message is flushed promptly.
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, namespace='/')
        socketio.sleep(0)

    def on_beam_finished(self, final_text: str):
        """Broadcast a beam's final text once generation for it completes."""
        socketio.emit('beam_finished', {
            'text': final_text
        })


@app.route('/')
def index():
    """Serve the single-page beam-search UI."""
    page = render_template('index.html')
    return page


@socketio.on('generate')
def handle_generation(data):
    """Run beam-search generation for a client request, streaming results.

    Expected keys in ``data``: ``prompt`` (required) plus optional ``model``
    ('qwen' | 'gpt2', default 'qwen'), ``num_beams`` (default 5),
    ``max_tokens`` (default 512) and ``sleep_time`` in ms (default 0).

    Emits ``generation_started`` first, ``beam_update``/``beam_finished``
    via the streamer during generation, ``generation_error`` on failure and
    ``generation_completed`` in every case.
    """
    socketio.emit('generation_started')

    prompt = data.get('prompt')
    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    # Validate the request up front: previously a missing prompt or unknown
    # model name raised KeyError outside the try block, killing the handler
    # without ever reporting generation_error/generation_completed back.
    if not prompt:
        socketio.emit('generation_error', {'error': "Missing 'prompt' in request"})
        socketio.emit('generation_completed')
        return
    if model_name not in MODELS:
        socketio.emit('generation_error',
                      {'error': f"Unknown model '{model_name}'"})
        socketio.emit('generation_completed')
        return

    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, wrap the prompt in its chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT2, use the raw prompt directly
        text = prompt

    # Tokenize and move tensors to the model's device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Streamer pushes per-beam updates to the client as tokens arrive
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search; no_grad avoids building autograd state
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        # Surface the failure to the client instead of dying silently
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        # Always signal completion so the UI can re-enable its controls
        socketio.emit('generation_completed')