muryshev committed
Commit 13e17af
1 Parent(s): a5ee5e1
Files changed (3)
  1. llm/common.py +3 -0
  2. llm/deepinfra_api.py +9 -10
  3. llm/vllm_api.py +16 -16
llm/common.py CHANGED
@@ -54,6 +54,9 @@ class LlmApi:
     params: LlmParams = None
 
 
+    def set_params(self, params: LlmParams):
+        self.params = params
+
     def create_headers(self) -> dict[str, str]:
         headers = {"Content-Type": "application/json"}
 
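
For orientation, here is a minimal, self-contained sketch (not part of the commit) of how the new setter is intended to be used: subclasses now hand their parameters to the base class via set_params instead of assigning through super. The LlmParams fields shown are assumptions based only on what appears in this diff.

# Minimal sketch, assuming LlmParams carries the fields referenced in this diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class LlmParams:
    url: str
    model: Optional[str] = None
    template: Optional[str] = None


class LlmApi:
    params: LlmParams = None

    def set_params(self, params: LlmParams):
        self.params = params

    def create_headers(self) -> dict[str, str]:
        return {"Content-Type": "application/json"}


class DeepInfraApi(LlmApi):
    def __init__(self, params: LlmParams):
        # Store the parameters via the base class instead of `super.params = params`.
        super().set_params(params)


api = DeepInfraApi(LlmParams(url="https://api.deepinfra.com"))  # hypothetical base URL
print(api.params.url, api.create_headers())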
 
llm/deepinfra_api.py CHANGED
@@ -9,8 +9,7 @@ class DeepInfraApi(LlmApi):
     """
 
     def __init__(self, params: LlmParams):
-        super.params = params
-
+        super().set_params(params)
 
     async def get_models(self) -> List[str]:
         """
@@ -25,7 +24,7 @@ class DeepInfraApi(LlmApi):
         """
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{super.params.url}/v1/openai/models", super.create_headers())
+                response = await client.get(f"{super().params.url}/v1/openai/models", super().create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
@@ -45,8 +44,8 @@ class DeepInfraApi(LlmApi):
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if super.params.predict_params and super.params.predict_params.system_prompt:
-            messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
+        if super().params.predict_params and super().params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": super().params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 
@@ -61,8 +60,8 @@ class DeepInfraApi(LlmApi):
             str: The prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if super.params.template is not None:
-            actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
+        if super().params.template is not None:
+            actual_prompt = super().params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -84,10 +83,10 @@ class DeepInfraApi(LlmApi):
 
         request = {
             "stream": False,
-            "model": super.params.model,
+            "model": super().params.model,
         }
 
-        predict_params = super.params.predict_params
+        predict_params = super().params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -148,6 +147,6 @@ class DeepInfraApi(LlmApi):
         request = await self.create_request(prompt)
 
         async with httpx.AsyncClient() as client:
-            response = client.post(f"{super.params.url}/v1/openai/chat/completions", super.create_headers(), json=request)
+            response = client.post(f"{super().params.url}/v1/openai/chat/completions", super().create_headers(), json=request)
             if response.status_code == 200:
                 return response.json()["choices"][0]["message"]["content"]
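
As a rough usage sketch (not part of the commit), the refactored class would be constructed and queried roughly like this; the base URL, model id, and LlmParams constructor arguments are assumptions inferred from the fields this diff touches.

# Hypothetical usage; only get_models() and the LlmParams fields seen above are assumed.
import asyncio


async def main():
    params = LlmParams(
        url="https://api.deepinfra.com",   # assumed DeepInfra-compatible base URL
        model="some-org/some-model",       # placeholder model id
        template="{{PROMPT}}",             # optional template; "{{PROMPT}}" is substituted
    )
    api = DeepInfraApi(params)             # __init__ now delegates to super().set_params()
    models = await api.get_models()        # GET {url}/v1/openai/models
    print(models)


asyncio.run(main())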
llm/vllm_api.py CHANGED
@@ -11,8 +11,8 @@ class LlmApi(LlmApi):
     """
 
     def __init__(self, params: LlmParams):
-        super.params = params
-
+        super().set_params(params)
+
     async def get_models(self) -> List[str]:
         """
         Performs a GET request to the API to fetch the list of available models.
@@ -26,7 +26,7 @@ class LlmApi(LlmApi):
         """
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{super.params.url}/v1/models", super.create_headers())
+                response = await client.get(f"{super().params.url}/v1/models", super().create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
@@ -36,8 +36,8 @@ class LlmApi(LlmApi):
 
     async def get_model(self) -> str:
         model = None
-        if super.params.model is not None:
-            model = super.params.model
+        if super().params.model is not None:
+            model = super().params.model
         else:
             models = await self.get_models()
             model = models[0] if models else None
@@ -59,8 +59,8 @@ class LlmApi(LlmApi):
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if super.params.predict_params and super.params.predict_params.system_prompt:
-            messages.append({"role": "system", "content": super.params.predict_params.system_prompt})
+        if super().params.predict_params and super().params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": super().params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 
@@ -75,8 +75,8 @@ class LlmApi(LlmApi):
             str: The prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if super.params.template is not None:
-            actual_prompt = super.params.template.replace("{{PROMPT}}", actual_prompt)
+        if super().params.template is not None:
+            actual_prompt = super().params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -101,9 +101,9 @@ class LlmApi(LlmApi):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{super.params.url}/tokenize",
+                    f"{super().params.url}/tokenize",
                     json=request_data,
-                    headers=super.create_headers(),
+                    headers=super().create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()
@@ -135,9 +135,9 @@ class LlmApi(LlmApi):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{super.params.url}/detokenize",
+                    f"{super().params.url}/detokenize",
                     json=request_data,
-                    headers=super.create_headers(),
+                    headers=super().create_headers(),
                 )
                 if response.status_code == 200:
                     data = response.json()
@@ -169,7 +169,7 @@ class LlmApi(LlmApi):
             "model": model,
         }
 
-        predict_params = super.params.predict_params
+        predict_params = super().params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -250,7 +250,7 @@ class LlmApi(LlmApi):
         # Maximum number of tokens allowed for the sources
         max_length = (
             max_token_count
-            - (super.params.predict_params.n_predict or 0)
+            - (super().params.predict_params.n_predict or 0)
             - aux_token_count
             - system_prompt_token_count
         )
@@ -289,7 +289,7 @@ class LlmApi(LlmApi):
             request = await self.create_request(prompt)
 
             # Start the streaming request
-            async with client.stream("POST", f"{super.params.url}/v1/chat/completions", json=request) as response:
+            async with client.stream("POST", f"{super().params.url}/v1/chat/completions", json=request) as response:
                 if response.status_code != 200:
                     # If there is an error, read the response body for details
                     error_content = await response.aread()
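
For completeness, a sketch (assumption: the server exposes an OpenAI-compatible /v1/chat/completions SSE stream, as the last hunk suggests) of how the streaming response opened with client.stream could be consumed:

# Hypothetical streaming consumer; the payload shape and SSE framing are assumptions.
import asyncio
import json

import httpx


async def stream_chat(url: str, request: dict) -> str:
    text = ""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{url}/v1/chat/completions", json=request) as response:
            if response.status_code != 200:
                # On error, read the body to surface the details, as the patched code does.
                error_content = await response.aread()
                raise RuntimeError(error_content.decode())
            async for line in response.aiter_lines():
                if not line.startswith("data: ") or line.endswith("[DONE]"):
                    continue
                chunk = json.loads(line[len("data: "):])
                delta = chunk["choices"][0].get("delta", {})
                text += delta.get("content", "") or ""
    return text


# Example (placeholder URL and model):
# asyncio.run(stream_chat("http://localhost:8000",
#                         {"model": "my-model", "stream": True,
#                          "messages": [{"role": "user", "content": "Hello"}]}))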