File size: 3,706 Bytes
343c758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import json
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
import requests
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Tuple

load_dotenv()

app = FastAPI()

# Models to try, best first, when the client does not pin a specific model.
ranked_models = [
    "llama-3.3-70b-versatile",
    "llama3-70b-8192",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "mistral-saba-24b",
    "gemma2-9b-it",
    "llama-3.1-8b-instant",
    "llama3-8b-8192"
]

# Collect every Groq API key exported as GROQ_<n> (GROQ_1, GROQ_2, ...),
# in environment order.  fullmatch replaces the anchored re.match pattern.
api_keys = [v for k, v in os.environ.items() if re.fullmatch(r'GROQ_\d+', k)]

# NOTE(review): per the CORS spec, allow_credentials=True cannot be combined
# with a wildcard origin — browsers reject `Access-Control-Allow-Origin: *`
# on credentialed requests.  Confirm whether credentials are actually needed;
# if so, list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"]
)

class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # Optional list of model names to try; [] (the default) means "use the
    # server-side ranked_models list in order".
    # NOTE(review): annotated List[Any] but presumably model-name strings —
    # confirm with callers before tightening to List[str].
    models: Optional[List[Any]] = []
    # The user's prompt, forwarded verbatim as a single user chat message.
    query: str

class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    # The assistant message content from the first successful completion,
    # or the sentinel string "ERROR !" when every model/key attempt failed.
    output: str

@app.get("/")
def main_page():
    """Health-check endpoint: report that the service is up."""
    status_payload = {"status": "ok"}
    return status_payload

@app.post("/chat", response_model=ChatResponse)
def ask_groq_llm(req: ChatRequest):
    """Proxy a chat query to the Groq OpenAI-compatible completions API.

    Tries candidate models best-first (see _candidate_models), and for each
    model tries every configured API key, returning the first successful
    completion.  When every model/key pair fails, falls back to the legacy
    "ERROR !" payload (a 200 with a sentinel body) so existing clients are
    not broken.
    """
    for model in _candidate_models(req.models):
        for key_id, key in enumerate(api_keys):
            # TLS verification is deliberately left ON here: the previous
            # verify=False silently disabled certificate checks against
            # api.groq.com.  A timeout is set so a stalled upstream cannot
            # hang the request forever.
            resp = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {key}",
                },
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": req.query}],
                },
                timeout=60,
            )
            if resp.status_code == 200:
                # key_id replaces the former O(n) api_keys.index(key); this
                # also fixes the single-model branch, which printed an
                # unbound `model` name and crashed on success.
                print("Asked to", model, "with the key ID", str(key_id), ":", req.query)
                return ChatResponse(output=resp.json()["choices"][0]["message"]["content"])
            # Log the failure and fall through to the next key/model.
            print(resp.status_code, resp.text)
    return ChatResponse(output="ERROR !")

def _candidate_models(models: List[Any]) -> List[str]:
    """Return the models to try, best-first.

    - No models given: the full server-side ranked list.
    - Exactly one model: that model only.
    - Several models: the given models sorted by their position in
      ranked_models; unknown models sort last, keeping their relative order.
    """
    if not models:
        return ranked_models
    if len(models) == 1:
        return models
    order = {name: rank for rank, name in enumerate(ranked_models)}
    return sorted(models, key=lambda name: order.get(name, float("inf")))