import json
import requests


def check_server_health(cloud_gateway_api: str) -> bool:
    """
    Probe the health endpoint to check whether the server is up.

    Args:
        cloud_gateway_api: base URL of the API to probe.

    Returns:
        True if the server is active, False otherwise.
    """
    try:
        # Bound the request so an unresponsive server cannot hang the health check.
        response = requests.get(cloud_gateway_api + "/health", timeout=5)
        if response.status_code == 200:
            return True
    except requests.RequestException:
        print("Failed to establish connection to the server.")

    return False


def request_generation(
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """
    Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
    token-by-token generation from LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
                or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

    Returns:

    """

    payload = {
        "model": "google/gemma-3-27b-it",
        "messages": [
            *(
                [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_prompt}],
                    }
                ]
                if system_prompt
                else []
            ),
            {"role": "user", "content": [{"type": "text", "text": message}]},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True,  # Enable streaming
    }

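    # NOTE: The gateway is assumed to stream OpenAI-style server-sent events,
    # one JSON object per line, e.g.:
    #   data: {"choices": [{"delta": {"content": "Hello"}}]}
    #   data: [DONE]
    # This shape is inferred from the parsing below, not from the gateway's documentation.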
    with requests.post(
        cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True
    ) as response:
        # Surface HTTP errors instead of silently skipping unparseable error bodies.
        response.raise_for_status()

        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string before parsing it as JSON
                chunk_str = chunk.decode("utf-8")

                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: ") :]

                # Stop once the server signals the end of the stream
                if chunk_str.strip() == "[DONE]":
                    break

                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)

                    # Extract the "content" field from the first choice's delta
                    content = chunk_json["choices"][0]["delta"].get("content", "")

                    # Yield the generated content as it is streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip chunks that are not valid JSON (e.g., keep-alive lines)
                    continue
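

# Example usage: a minimal sketch, assuming the gateway runs locally on port 8000
# (the URL below is a placeholder, not a real endpoint).
if __name__ == "__main__":
    gateway = "http://localhost:8000"

    if check_server_health(gateway):
        for token in request_generation(
            message="Explain what a REST API is in one sentence.",
            system_prompt="You are a concise assistant.",
            cloud_gateway_api=gateway,
        ):
            print(token, end="", flush=True)
        print()
    else:
        print("Server is not reachable.")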