import time

from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
    initialize_decode_buffers,
    process_response_tokens,
    send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    send_generation_request,
)


def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Main generation function that handles the conversation, encodes the request,
    sends the generation request, and handles decoding/streaming.
    It returns a response generator (ServeResponse or ServeStreamResponse).
    """
    stats = {}
    start = time.time()
    stats["start_time"] = start
    stats["tokens_count"] = 0

    # Prepare and send the generation request
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)
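    # One decode buffer, parts list, and finished flag per requested sample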
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    while True:
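        # Block until the model worker publishes the next decoding result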
        response = response_queue.get()

        # Handle an abnormal finish or error, signalled as a plain string
        if response in ["stop", "error"]:
            finish_reason = response
            break

        # Process the response tokens
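        # The first token is flagged so timing stats can record first-token latency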
        is_first_token = stats["tokens_count"] == 0
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )

        # Yield the responses if streaming
        if request.streaming and responses:
            for r in responses:
                yield r

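        # Count one generated token per response received from the worker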
        stats["tokens_count"] += 1

        # Check if all samples are finished
        if all(finished):
            finish_reason = "stop"
            break

    # Finalize the response
    final_responses = finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
    for fr in final_responses:
        yield fr


def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Finalize the response by sending the remaining text buffers.
    """
    responses = []

    # Send the remaining text buffers
    for sample_id in range(request.num_samples):
        responses.extend(
            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
        )

    # Calculate the final stats (total_time is reported in milliseconds)
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    # If streaming, send the final chunks for each sample
    if request.streaming:
        for sample_id in range(request.num_samples):
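            # Only samples still marked unfinished receive a closing chunk here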
            if finished[sample_id]:
                continue
            responses.append(
                ServeStreamResponse(
                    finish_reason=finish_reason, stats=stats, sample_id=sample_id
                )
            )
    else:
        # If not streaming, send the full messages for each sample
        full_messages = [
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ]
        responses.append(
            ServeResponse(
                messages=full_messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )

    return responses
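

# Illustrative sketch (not part of the original module): one way a caller
# might drain the generate_responses generator, separating stream chunks
# from the final response. The server wiring (input_queue, request, etc.)
# is environment-specific, so nothing here is executed on import.
def _demo_collect_responses(responses):
    """Split a response generator into (stream_chunks, final_response)."""
    chunks, final = [], None
    for r in responses:
        if isinstance(r, ServeStreamResponse):
            chunks.append(r)  # streaming path: per-sample incremental chunks
        else:
            final = r  # non-streaming path: a single ServeResponse at the end
    return chunks, final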