import queue

from fish_speech.conversation import Conversation, Message
from fish_speech.tokenizer import IM_END_TOKEN
from tools.llama.generate import GenerateRequest


def prepare_messages(request, tokenizer, config):
    """
    Reorganise the provided list of messages into a conversation.
    Encode the conversation for inference.
    """
    # Convert the API-level messages to conversation Message objects
    messages = [msg.to_conversation_message() for msg in request.messages]

    if not messages:
        raise ValueError("At least one message is required")

    # Check the last message to determine the next step
    last_role = messages[-1].role
    match last_role:
        case "user":
            # The last message is from the user, ask the assistant to respond with a new message
            messages.append(
                Message(role="assistant", parts=[], add_im_end=False, modality="voice")
            )
        case "raw":
            # The last message is raw text, ask the assistant to complete it
            messages[-1].add_im_start = False
            messages[-1].add_im_end = False
            messages[-1].modality = "voice"
        case "assistant":
            # The last message is from the assistant, ask the assistant to continue
            messages[-1].add_im_end = False
        case _:
            # Any other role cannot start or continue a generation turn
            raise ValueError(
                f"Unsupported role for the last message: {last_role!r} "
                "(expected 'user', 'assistant' or 'raw')"
            )

    # Create a conversation object and encode it for inference
    conv = Conversation(messages=messages)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    )
    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)

    return prompt, im_end_id

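# Example (a sketch): `request` is the API-level chat request handled by the
# server layer, while `tokenizer` and `config` come from the loaded model;
# none of these names are defined in this module.
#
#     prompt, im_end_id = prepare_messages(request, tokenizer, config)
#     # `prompt` is the encoded conversation (moved to the model device before
#     # generation) and `im_end_id` is the stop token id used during decoding.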

def create_generation_request(prompt, request, im_end_id, device):
    """
    Convert the request into a dictionary that can be sent to the model for generation.
    """
    req = {
        "prompt": prompt.to(device),
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }
    return req

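# Example (a sketch): builds on the values returned by prepare_messages above;
# `device` is wherever the model lives, e.g. "cuda" or "cpu".
#
#     req = create_generation_request(prompt, request, im_end_id, device="cuda")
#     # All sampling knobs (temperature, top_p, repetition_penalty, ...) are
#     # taken verbatim from the incoming request object.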

def send_generation_request(input_queue, req):
    """
    Send the generation request to the model worker and return the queue on
    which the response will arrive.
    """
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))
    return response_queue
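

if __name__ == "__main__":
    # Minimal plumbing demo (a sketch, not part of the serving path): a fake
    # worker thread stands in for the real LLaMA generation loop so that the
    # queue handshake in send_generation_request can be exercised without a
    # model. Assumption: GenerateRequest exposes its two constructor arguments
    # as `.request` and `.response_queue`; the real worker streams decoded
    # outputs rather than plain strings.
    import threading

    input_queue = queue.Queue()

    def fake_worker():
        gen_req = input_queue.get()
        for chunk in ("chunk-0", "chunk-1"):
            gen_req.response_queue.put(chunk)
        gen_req.response_queue.put(None)  # sentinel: generation finished

    threading.Thread(target=fake_worker, daemon=True).start()

    response_queue = send_generation_request(input_queue, {"prompt": "demo"})
    while (item := response_queue.get()) is not None:
        print(item)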