dexter2389 committed on
Commit
a310155
·
1 Parent(s): a60bd23

Use Groq for inference for the majority of the models

Browse files
Files changed (5) hide show
  1. app/app/main.py +78 -39
  2. app/static/index.html +15 -8
  3. compose.yaml +1 -0
  4. pyproject.toml +1 -0
  5. uv.lock +29 -1
app/app/main.py CHANGED
@@ -14,6 +14,7 @@ from arcana_codex import (
14
  from bson.objectid import ObjectId
15
  from fastapi import Depends, FastAPI, Header, HTTPException, Request, status
16
  from fastapi.responses import JSONResponse
 
17
  from llama_cpp import Llama
18
  from pydantic import BaseModel, EmailStr
19
  from pymongo.mongo_client import MongoClient
@@ -22,11 +23,13 @@ from starlette.responses import FileResponse
22
  __version__ = "0.0.0"
23
 
24
 
25
- class SupportedModelPipes(StrEnum):
 
26
  Gemma3 = "gemma3"
27
- QwenOpenR1 = "qwen-open-r1"
28
- SmolLLM2 = "smollm2"
29
- SmolLLM2Reasoning = "smollm2-reasoning"
 
30
 
31
 
32
  class LogEvent(StrEnum):
@@ -34,24 +37,6 @@ class LogEvent(StrEnum):
34
  LOGIN = "login"
35
 
36
 
37
- smollm2_pipeline = Llama.from_pretrained(
38
- repo_id="HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
39
- filename="smollm2-360m-instruct-q8_0.gguf",
40
- verbose=False,
41
- )
42
-
43
- smollm2_reasoning_pipeline = Llama.from_pretrained(
44
- repo_id="tensorblock/Reasoning-SmolLM2-135M-GGUF",
45
- filename="Reasoning-SmolLM2-135M-Q8_0.gguf",
46
- verbose=False,
47
- )
48
-
49
- qwen_open_r1_pipeline = Llama.from_pretrained(
50
- repo_id="tensorblock/Qwen2.5-0.5B-Open-R1-Distill-GGUF",
51
- filename="Qwen2.5-0.5B-Open-R1-Distill-Q8_0.gguf",
52
- verbose=False,
53
- )
54
-
55
  gemma_3_pipeline = Llama.from_pretrained(
56
  repo_id="ggml-org/gemma-3-1b-it-GGUF",
57
  filename="gemma-3-1b-it-Q8_0.gguf",
@@ -60,7 +45,7 @@ gemma_3_pipeline = Llama.from_pretrained(
60
 
61
 
62
  class ChatRequest(BaseModel):
63
- model: SupportedModelPipes = SupportedModelPipes.SmolLLM2
64
  message: str
65
 
66
 
@@ -120,11 +105,31 @@ def verify_authorization_header(
120
  )
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  @asynccontextmanager
124
  async def lifespan(app: FastAPI): # noqa: ARG001
125
  # Set API key in FastAPI app
126
  app.ARCANA_API_KEY = os.environ.get("ARCANA_API_KEY", "")
127
 
 
 
128
  app.mongo_db = MongoClient(
129
  os.environ.get("MONGO_URI", "mongodb+srv://localhost:27017/")
130
  )[os.environ.get("MONGO_DB", "arcana_hf_demo_test")]
@@ -136,6 +141,8 @@ async def lifespan(app: FastAPI): # noqa: ARG001
136
  # Clear API key to avoid leaking it
137
  app.ARCANA_API_KEY = ""
138
 
 
 
139
  logging.info("Application stopped")
140
 
141
 
@@ -182,26 +189,53 @@ def chat(
182
 
183
  logger.info(f"Using {payload.model}")
184
 
 
185
  match payload.model:
186
- case SupportedModelPipes.Gemma3:
187
- ai_pipeline = gemma_3_pipeline
188
- case SupportedModelPipes.QwenOpenR1:
189
- ai_pipeline = qwen_open_r1_pipeline
190
- case SupportedModelPipes.SmolLLM2:
191
- ai_pipeline = smollm2_pipeline
192
- case SupportedModelPipes.SmolLLM2Reasoning:
193
- ai_pipeline = smollm2_reasoning_pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- inference_start_time = time.perf_counter()
196
- ai_response = ai_pipeline.create_chat_completion(
197
- messages=[{"role": "user", "content": f"{payload.message}"}],
198
- max_tokens=512,
199
- seed=8,
200
- )["choices"][0]["message"]["content"].strip()
201
  inference_end_time = time.perf_counter()
202
 
203
- elapsed_time = inference_end_time - inference_start_time
204
- logger.info(f"Inference took: {elapsed_time:.4f} seconds")
205
 
206
  integrate_payload = AdUnitsIntegrateModel(
207
  ad_unit_ids=[
@@ -210,10 +244,15 @@ def chat(
210
  base_content=ai_response,
211
  )
212
 
 
213
  integration_result = client.integrate_ad_units(integrate_payload)
214
  integrated_content = integration_result.get("response_data", {}).get(
215
  "integrated_content"
216
  )
 
 
 
 
217
 
218
  request.app.mongo_db["logs"].insert_one(
219
  {
 
14
  from bson.objectid import ObjectId
15
  from fastapi import Depends, FastAPI, Header, HTTPException, Request, status
16
  from fastapi.responses import JSONResponse
17
+ from groq import Groq
18
  from llama_cpp import Llama
19
  from pydantic import BaseModel, EmailStr
20
  from pymongo.mongo_client import MongoClient
 
23
  __version__ = "0.0.0"
24
 
25
 
26
+ class SupportedModels(StrEnum):
27
+ Gemma2 = "gemma2"
28
  Gemma3 = "gemma3"
29
+ Llama3_3 = "llama3_3"
30
+ Llama3_1 = "llama3_1"
31
+ Qwen2_5 = "qwen2_5"
32
+ Deepseek_R1 = "deepseek_r1"
33
 
34
 
35
  class LogEvent(StrEnum):
 
37
  LOGIN = "login"
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  gemma_3_pipeline = Llama.from_pretrained(
41
  repo_id="ggml-org/gemma-3-1b-it-GGUF",
42
  filename="gemma-3-1b-it-Q8_0.gguf",
 
45
 
46
 
47
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # Which model to run inference with; defaults to Gemma2
    # (served via Groq per the chat handler's match statement).
    model: SupportedModels = SupportedModels.Gemma2
    # The user's chat message, forwarded verbatim to the selected model.
    message: str
50
 
51
 
 
105
  )
106
 
107
 
108
def process_groq_chat_request(
    groq_client: Groq, message: str, model: str
) -> str | None:
    """Run a single-turn chat completion against the Groq API.

    Args:
        groq_client: An initialized Groq SDK client (set on the app in
            ``lifespan``).
        message: The user's chat message.
        model: Groq model identifier, e.g. ``"llama-3.3-70b-versatile"``.

    Returns:
        The assistant's reply text, or ``None`` when the API returns an
        empty message — callers are expected to handle ``None``.
    """
    completion = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            # `message` is already a str; no f-string wrapping needed.
            {"role": "user", "content": message},
        ],
        max_completion_tokens=1024,
        seed=8,  # fixed seed for (best-effort) reproducible outputs
        model=model,
    )
    return completion.choices[0].message.content
124
+
125
+
126
  @asynccontextmanager
127
  async def lifespan(app: FastAPI): # noqa: ARG001
128
  # Set API key in FastAPI app
129
  app.ARCANA_API_KEY = os.environ.get("ARCANA_API_KEY", "")
130
 
131
+ app.groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
132
+
133
  app.mongo_db = MongoClient(
134
  os.environ.get("MONGO_URI", "mongodb+srv://localhost:27017/")
135
  )[os.environ.get("MONGO_DB", "arcana_hf_demo_test")]
 
141
  # Clear API key to avoid leaking it
142
  app.ARCANA_API_KEY = ""
143
 
144
+ app.groq_client = None
145
+
146
  logging.info("Application stopped")
147
 
148
 
 
189
 
190
  logger.info(f"Using {payload.model}")
191
 
192
+ inference_start_time = time.perf_counter()
193
  match payload.model:
194
+ case SupportedModels.Gemma3:
195
+ llm_response = gemma_3_pipeline.create_chat_completion(
196
+ messages=[
197
+ {"role": "system", "content": "You are a helpful assistant."},
198
+ {"role": "user", "content": f"{payload.message}"},
199
+ ],
200
+ max_tokens=512,
201
+ seed=8,
202
+ )["choices"][0]["message"]["content"]
203
+ case SupportedModels.Gemma2:
204
+ llm_response = process_groq_chat_request(
205
+ groq_client=request.app.groq_client,
206
+ message=payload.message,
207
+ model="gemma2-9b-it",
208
+ )
209
+ case SupportedModels.Llama3_3:
210
+ llm_response = process_groq_chat_request(
211
+ groq_client=request.app.groq_client,
212
+ message=payload.message,
213
+ model="llama-3.3-70b-versatile",
214
+ )
215
+ case SupportedModels.Llama3_1:
216
+ llm_response = process_groq_chat_request(
217
+ groq_client=request.app.groq_client,
218
+ message=payload.message,
219
+ model="llama-3.1-8b-instant",
220
+ )
221
+ case SupportedModels.Qwen2_5:
222
+ llm_response = process_groq_chat_request(
223
+ groq_client=request.app.groq_client,
224
+ message=payload.message,
225
+ model="qwen-2.5-32b",
226
+ )
227
+ case SupportedModels.Deepseek_R1:
228
+ llm_response = process_groq_chat_request(
229
+ groq_client=request.app.groq_client,
230
+ message=payload.message,
231
+ model="deepseek-r1-distill-qwen-32b",
232
+ )
233
 
234
+ ai_response = "" if llm_response is None else llm_response.strip()
 
 
 
 
 
235
  inference_end_time = time.perf_counter()
236
 
237
+ inference_elapsed_time = inference_end_time - inference_start_time
238
+ logger.info(f"Inference took: {inference_elapsed_time:.4f} seconds")
239
 
240
  integrate_payload = AdUnitsIntegrateModel(
241
  ad_unit_ids=[
 
244
  base_content=ai_response,
245
  )
246
 
247
+ integration_start_time = time.perf_counter()
248
  integration_result = client.integrate_ad_units(integrate_payload)
249
  integrated_content = integration_result.get("response_data", {}).get(
250
  "integrated_content"
251
  )
252
+ integration_end_time = time.perf_counter()
253
+
254
+ integration_elapsed_time = integration_end_time - integration_start_time
255
+ logger.info(f"Integration took: {integration_elapsed_time:.4f} seconds")
256
 
257
  request.app.mongo_db["logs"].insert_one(
258
  {
app/static/index.html CHANGED
@@ -62,18 +62,25 @@
62
  <div id="model-dropdown"
63
  class="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg hidden">
64
  <ul class="py-1">
65
- <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="smollm2">
66
- SmolLM2</li>
67
  </ul>
68
  <ul class="py-1">
69
- <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer"
70
- data-value="smollm2-reasoning">
71
- SmolLLM2Reasoning</li>
 
 
 
 
 
 
 
72
  </ul>
73
  <ul class="py-1">
74
  <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer"
75
- data-value="qwen-open-r1">
76
- QwenOpenR1</li>
77
  </ul>
78
  <ul class="py-1">
79
  <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="gemma3">
@@ -325,7 +332,7 @@
325
  const selectedModelText = document.getElementById('selected-model-text');
326
 
327
  // Default model
328
- let selectedModel = 'smollm2';
329
 
330
  // Toggle dropdown
331
  modelDropdownButton.addEventListener('click', () => {
 
62
  <div id="model-dropdown"
63
  class="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg hidden">
64
  <ul class="py-1">
65
+ <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="gemma2">
66
+ Gemma2</li>
67
  </ul>
68
  <ul class="py-1">
69
+ <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="llama3_1">
70
+ Llama3.1</li>
71
+ </ul>
72
+ <ul class="py-1">
73
+ <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="llama3_3">
74
+ Llama3.3</li>
75
+ </ul>
76
+ <ul class="py-1">
77
+ <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="qwen2_5">
78
+ Qwen2.5</li>
79
  </ul>
80
  <ul class="py-1">
81
  <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer"
82
+ data-value="deepseek_r1">
83
+ Deepseek_R1</li>
84
  </ul>
85
  <ul class="py-1">
86
  <li class="model-option px-4 py-2 hover:bg-gray-100 cursor-pointer" data-value="gemma3">
 
332
  const selectedModelText = document.getElementById('selected-model-text');
333
 
334
  // Default model
335
+ let selectedModel = 'gemma2';
336
 
337
  // Toggle dropdown
338
  modelDropdownButton.addEventListener('click', () => {
compose.yaml CHANGED
@@ -12,6 +12,7 @@ services:
12
  - ./app:/app
13
  environment:
14
  ARCANA_API_KEY: "${ARCANA_API_KEY}"
 
15
  MONGO_URI: "${MONGO_URI}"
16
  OPENBLAS_NUM_THREADS: "${OPENBLAS_NUM_THREADS:-4}"
17
  healthcheck:
 
12
  - ./app:/app
13
  environment:
14
  ARCANA_API_KEY: "${ARCANA_API_KEY}"
15
+ GROQ_API_KEY: "${GROQ_API_KEY}"
16
  MONGO_URI: "${MONGO_URI}"
17
  OPENBLAS_NUM_THREADS: "${OPENBLAS_NUM_THREADS:-4}"
18
  healthcheck:
pyproject.toml CHANGED
@@ -17,6 +17,7 @@ dependencies = [
17
  "huggingface-hub>=0.29",
18
  "pymongo>=4.11",
19
  "email-validator>=2.2",
 
20
  ]
21
 
22
 
 
17
  "huggingface-hub>=0.29",
18
  "pymongo>=4.11",
19
  "email-validator>=2.2",
20
+ "groq>=0.20",
21
  ]
22
 
23
 
uv.lock CHANGED
@@ -45,6 +45,7 @@ dependencies = [
45
  { name = "arcana-codex" },
46
  { name = "email-validator" },
47
  { name = "fastapi-slim" },
 
48
  { name = "huggingface-hub" },
49
  { name = "llama-cpp-python" },
50
  { name = "pillow" },
@@ -68,8 +69,9 @@ test = [
68
  [package.metadata]
69
  requires-dist = [
70
  { name = "arcana-codex", specifier = ">=0.2" },
71
- { name = "email-validator", specifier = ">=2.2.0" },
72
  { name = "fastapi-slim", specifier = ">=0.115" },
 
73
  { name = "huggingface-hub", specifier = ">=0.29" },
74
  { name = "llama-cpp-python", specifier = ">=0.3" },
75
  { name = "pillow", specifier = ">=11.1" },
@@ -250,6 +252,15 @@ wheels = [
250
  { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
251
  ]
252
 
 
 
 
 
 
 
 
 
 
253
  [[package]]
254
  name = "dnspython"
255
  version = "2.7.0"
@@ -304,6 +315,23 @@ wheels = [
304
  { url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615 },
305
  ]
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  [[package]]
308
  name = "h11"
309
  version = "0.14.0"
 
45
  { name = "arcana-codex" },
46
  { name = "email-validator" },
47
  { name = "fastapi-slim" },
48
+ { name = "groq" },
49
  { name = "huggingface-hub" },
50
  { name = "llama-cpp-python" },
51
  { name = "pillow" },
 
69
  [package.metadata]
70
  requires-dist = [
71
  { name = "arcana-codex", specifier = ">=0.2" },
72
+ { name = "email-validator", specifier = ">=2.2" },
73
  { name = "fastapi-slim", specifier = ">=0.115" },
74
+ { name = "groq", specifier = ">=0.20.0" },
75
  { name = "huggingface-hub", specifier = ">=0.29" },
76
  { name = "llama-cpp-python", specifier = ">=0.3" },
77
  { name = "pillow", specifier = ">=11.1" },
 
252
  { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
253
  ]
254
 
255
+ [[package]]
256
+ name = "distro"
257
+ version = "1.9.0"
258
+ source = { registry = "https://pypi.org/simple" }
259
+ sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 }
260
+ wheels = [
261
+ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 },
262
+ ]
263
+
264
  [[package]]
265
  name = "dnspython"
266
  version = "2.7.0"
 
315
  { url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615 },
316
  ]
317
 
318
+ [[package]]
319
+ name = "groq"
320
+ version = "0.20.0"
321
+ source = { registry = "https://pypi.org/simple" }
322
+ dependencies = [
323
+ { name = "anyio" },
324
+ { name = "distro" },
325
+ { name = "httpx" },
326
+ { name = "pydantic" },
327
+ { name = "sniffio" },
328
+ { name = "typing-extensions" },
329
+ ]
330
+ sdist = { url = "https://files.pythonhosted.org/packages/8f/fc/e5a03586ffad7ae6c7996f388ca321a3bf8b9fa544a36a934ce4b6b44211/groq-0.20.0.tar.gz", hash = "sha256:2a201d41cae768c53d411dabcfea2333e2e138df22d909ed555ece426f1e016f", size = 121936 }
331
+ wheels = [
332
+ { url = "https://files.pythonhosted.org/packages/95/37/9b415df5dd1e6a685d3e8fd4e564a5e80f4f87c19d82829ad027fa2bb150/groq-0.20.0-py3-none-any.whl", hash = "sha256:c27b89903eb2b77f94ed95837ff3cadfc8c9e670953b1c5e5e2e855fea54b6c5", size = 124919 },
333
+ ]
334
+
335
  [[package]]
336
  name = "h11"
337
  version = "0.14.0"