muryshev committed · Commit fa7e09e · 1 Parent(s): f3d1581
Files changed (2)
  1. llm/deepinfra_api.py +9 -9
  2. llm/vllm_api.py +12 -12
llm/deepinfra_api.py CHANGED
@@ -25,7 +25,7 @@ class DeepInfraApi(LlmApi):
         """
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{super().params.url}/v1/openai/models", super().create_headers())
+                response = await client.get(f"{self.params.url}/v1/openai/models", super().create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
@@ -45,8 +45,8 @@ class DeepInfraApi(LlmApi):
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if super().params.predict_params and super().params.predict_params.system_prompt:
-            messages.append({"role": "system", "content": super().params.predict_params.system_prompt})
+        if self.params.predict_params and self.params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 
@@ -61,8 +61,8 @@ class DeepInfraApi(LlmApi):
             str: The prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if super().params.template is not None:
-            actual_prompt = super().params.template.replace("{{PROMPT}}", actual_prompt)
+        if self.params.template is not None:
+            actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -81,15 +81,15 @@ class DeepInfraApi(LlmApi):
         Returns:
             dict: A dictionary with the parameters for the request.
         """
-        print(super().params)
+        print(self.params)
         print(self.params)
 
         request = {
             "stream": False,
-            "model": super().params.model,
+            "model": self.params.model,
         }
 
-        predict_params = super().params.predict_params
+        predict_params = self.params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -150,6 +150,6 @@ class DeepInfraApi(LlmApi):
         request = await self.create_request(prompt)
 
         async with httpx.AsyncClient() as client:
-            response = client.post(f"{super().params.url}/v1/openai/chat/completions", super().create_headers(), json=request)
+            response = client.post(f"{self.params.url}/v1/openai/chat/completions", super().create_headers(), json=request)
             if response.status_code == 200:
                 return response.json()["choices"][0]["message"]["content"]
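
Note on the pattern this commit switches to: super() resolves attributes only through the class MRO (class attributes, properties, methods), never through the instance __dict__, so if the base class stores the configuration in __init__ as self.params = ..., then super().params raises AttributeError while self.params works. A minimal sketch of that behaviour, with assumed class names (BaseApi/ChildApi stand in for the real LlmApi hierarchy):

class BaseApi:
    def __init__(self, params):
        self.params = params  # stored in the instance __dict__, not on the class

class ChildApi(BaseApi):
    def show(self):
        print(self.params)  # normal lookup through the instance works
        try:
            print(super().params)  # super() only searches class dicts in the MRO
        except AttributeError as err:
            print("super().params fails:", err)

ChildApi({"model": "demo"}).show()
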
llm/vllm_api.py CHANGED
@@ -27,7 +27,7 @@ class LlmApi(LlmApi):
         """
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{super().params.url}/v1/models", super().create_headers())
+                response = await client.get(f"{self.params.url}/v1/models", super().create_headers())
                 if response.status_code == 200:
                     json_data = response.json()
                     return [item['id'] for item in json_data.get('data', [])]
@@ -37,8 +37,8 @@ class LlmApi(LlmApi):
 
     async def get_model(self) -> str:
         model = None
-        if super().params.model is not None:
-            model = super().params.model
+        if self.params.model is not None:
+            model = self.params.model
         else:
             models = await self.get_models()
             model = models[0] if models else None
@@ -60,8 +60,8 @@ class LlmApi(LlmApi):
         """
         actual_prompt = self.apply_llm_template_to_prompt(prompt)
         messages = []
-        if super().params.predict_params and super().params.predict_params.system_prompt:
-            messages.append({"role": "system", "content": super().params.predict_params.system_prompt})
+        if self.params.predict_params and self.params.predict_params.system_prompt:
+            messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
         messages.append({"role": "user", "content": actual_prompt})
         return messages
 
@@ -76,8 +76,8 @@ class LlmApi(LlmApi):
             str: The prompt with the template applied (or the original prompt if no template is set).
         """
         actual_prompt = prompt
-        if super().params.template is not None:
-            actual_prompt = super().params.template.replace("{{PROMPT}}", actual_prompt)
+        if self.params.template is not None:
+            actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
         return actual_prompt
 
     async def tokenize(self, prompt: str) -> Optional[dict]:
@@ -102,7 +102,7 @@ class LlmApi(LlmApi):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{super().params.url}/tokenize",
+                    f"{self.params.url}/tokenize",
                     json=request_data,
                     headers=super().create_headers(),
                 )
@@ -136,7 +136,7 @@ class LlmApi(LlmApi):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    f"{super().params.url}/detokenize",
+                    f"{self.params.url}/detokenize",
                     json=request_data,
                     headers=super().create_headers(),
                 )
@@ -170,7 +170,7 @@ class LlmApi(LlmApi):
             "model": model,
         }
 
-        predict_params = super().params.predict_params
+        predict_params = self.params.predict_params
         if predict_params:
             if predict_params.stop:
                 non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
@@ -251,7 +251,7 @@ class LlmApi(LlmApi):
         # Maximum number of tokens allowed for the sources
         max_length = (
             max_token_count
-            - (super().params.predict_params.n_predict or 0)
+            - (self.params.predict_params.n_predict or 0)
             - aux_token_count
             - system_prompt_token_count
         )
@@ -290,7 +290,7 @@ class LlmApi(LlmApi):
             request = await self.create_request(prompt)
 
             # Start the streaming request
-            async with client.stream("POST", f"{super().params.url}/v1/chat/completions", json=request) as response:
+            async with client.stream("POST", f"{self.params.url}/v1/chat/completions", json=request) as response:
                 if response.status_code != 200:
                     # On error, read the response body to get the details
                     error_content = await response.aread()
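
A related detail in the non-streaming predict path in llm/deepinfra_api.py above: httpx.AsyncClient.post returns a coroutine and accepts headers only as a keyword argument, so that call would normally be awaited and written with headers=. A minimal sketch of the intended call shape, as a standalone helper with assumed parameter names (base_url, headers, request); only the endpoint path and response handling are taken from the diff:

import httpx

async def chat_completion(base_url: str, headers: dict, request: dict) -> str:
    # Awaited POST with headers passed as a keyword argument; raises on HTTP errors.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/v1/openai/chat/completions",
            headers=headers,
            json=request,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

# Hypothetical usage, mirroring the names used in the diff:
#   content = await chat_completion(self.params.url, super().create_headers(), request)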