import enum
import heapq
from typing import Optional

from pydantic import BaseModel

from litellm import print_verbose
from litellm.caching.caching import DualCache, RedisCache


class SchedulerCacheKeys(enum.Enum):
    queue = "scheduler:queue"
    default_in_memory_ttl = 5  # cache queue in-memory for 5s when redis cache available


class DefaultPriorities(enum.Enum):
    High = 0
    Medium = 128
    Low = 255


class FlowItem(BaseModel):
    priority: int  # Priority between 0 and 255
    request_id: str
    model_name: str


class Scheduler:
    cache: DualCache

    def __init__(
        self,
        polling_interval: Optional[float] = None,
        redis_cache: Optional[RedisCache] = None,
    ):
        """
        polling_interval: float or null - frequency of polling queue. Default is 3ms.
        """
        self.queue: list = []
        default_in_memory_ttl: Optional[float] = None
        if redis_cache is not None:
            # if a redis cache is available, expire the in-memory copy quickly
            # so the shared redis queue is re-read frequently
            default_in_memory_ttl = SchedulerCacheKeys.default_in_memory_ttl.value
        self.cache = DualCache(
            redis_cache=redis_cache, default_in_memory_ttl=default_in_memory_ttl
        )
        self.polling_interval = polling_interval or 0.03  # default to 30ms

    async def add_request(self, request: FlowItem):
        # Lower priority values are served first; ties on priority are broken
        # lexicographically by request_id.
        # get the queue for this model group
        queue = await self.get_queue(model_name=request.model_name)
        # push the request onto the min-heap
        heapq.heappush(queue, (request.priority, request.request_id))

        # save the queue
        await self.save_queue(queue=queue, model_name=request.model_name)

    async def poll(self, id: str, model_name: str, health_deployments: list) -> bool:
        """
        Return if request can be processed.

        Returns:
        - True:
            * If healthy deployments are available
            * OR If request at the top of queue
        - False:
            * If no healthy deployments available
            * AND request not at the top of queue
        """
        queue = await self.get_queue(model_name=model_name)
        if not queue:
            raise Exception(
                "Queue is not set up correctly. Queue={}".format(queue)
            )

        print_verbose(f"len(health_deployments): {len(health_deployments)}")
        if len(health_deployments) == 0:
            print_verbose(f"queue: {queue}, seeking id={id}")
            # Check if the id is at the top of the heap
            if queue[0][1] == id:
                # Remove the item from the queue
                heapq.heappop(queue)
                print_verbose(f"Popped id: {id}")
                return True
            else:
                return False

        return True

    async def peek(self, id: str, model_name: str, health_deployments: list) -> bool:
        """Return whether the id is at the top of the queue, without popping it from the heap."""
        queue = await self.get_queue(model_name=model_name)
        if not queue:
            raise Exception(
                "Queue is not set up correctly. Queue={}".format(queue)
            )

        # Check if the id is at the top of the heap, without popping it
        return queue[0][1] == id

    def get_queue_status(self):
        """Get the status of items in the in-memory queue"""
        return self.queue

    async def get_queue(self, model_name: str) -> list:
        """
        Return a queue for that specific model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            response = await self.cache.async_get_cache(key=_cache_key)
            if isinstance(response, list):
                return response
            return []
        return self.queue

    async def save_queue(self, queue: list, model_name: str) -> None:
        """
        Save the updated queue of the model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            await self.cache.async_set_cache(key=_cache_key, value=queue)
        return None
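

# ---------------------------------------------------------------------------
# Example usage -- an illustrative sketch, not part of the module itself.
# Assumes an in-memory-only Scheduler (no RedisCache) and hypothetical
# request ids / model names ("req-high", "req-low", "gpt-4o").
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        scheduler = Scheduler()  # no redis_cache -> queue lives in-memory

        # Enqueue two requests against the same model group.
        await scheduler.add_request(
            FlowItem(
                priority=DefaultPriorities.High.value,  # 0
                request_id="req-high",
                model_name="gpt-4o",
            )
        )
        await scheduler.add_request(
            FlowItem(
                priority=DefaultPriorities.Low.value,  # 255
                request_id="req-low",
                model_name="gpt-4o",
            )
        )

        # peek() reports whether a request is next in line without popping it.
        print(
            await scheduler.peek(
                id="req-high", model_name="gpt-4o", health_deployments=[]
            )
        )  # True: priority 0 sits at the top of the min-heap

        # With no healthy deployments, only the request at the top of the
        # heap is let through (and popped); everything else must wait.
        print(
            await scheduler.poll(
                id="req-low", model_name="gpt-4o", health_deployments=[]
            )
        )  # False
        print(
            await scheduler.poll(
                id="req-high", model_name="gpt-4o", health_deployments=[]
            )
        )  # True

        # With a healthy deployment available, any queued request may proceed.
        print(
            await scheduler.poll(
                id="req-low",
                model_name="gpt-4o",
                health_deployments=[{"model_name": "gpt-4o"}],
            )
        )  # True

    asyncio.run(_demo())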