# FastAPI service that proxies chat queries to Groq-hosted LLMs,
# falling back across models and API keys until one succeeds.
import json
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
import requests
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Tuple
# Load GROQ_* API keys (and any other settings) from a local .env file
# into the process environment before they are scanned below.
load_dotenv()
app = FastAPI()
# Groq model identifiers in server-side preference order: /chat tries them
# top to bottom and returns the first successful completion.
ranked_models = [
"llama-3.3-70b-versatile",
"llama3-70b-8192",
"meta-llama/llama-4-maverick-17b-128e-instruct",
"meta-llama/llama-4-scout-17b-16e-instruct",
"mistral-saba-24b",
"gemma2-9b-it",
"llama-3.1-8b-instant",
"llama3-8b-8192"
]
# Collect every Groq API key exposed as a GROQ_<number> environment variable
# (typically populated by load_dotenv above). Order follows os.environ.
_GROQ_KEY_RE = re.compile(r'^GROQ_\d+$')
api_keys = [value for name, value in os.environ.items() if _GROQ_KEY_RE.match(name)]
# Allow cross-origin browser clients (any origin) to call GET/POST endpoints.
# NOTE(review): browsers reject credentialed requests when the server answers
# with a wildcard origin, so allow_credentials=True together with
# allow_origins=["*"] is likely ineffective — confirm intended.
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_headers=["*"],
allow_methods=["GET", "POST"],
allow_origins=["*"]
)
class ChatRequest(BaseModel):
    # Body of POST /chat.
    # models: optional list of Groq model names the caller prefers; an empty
    #         list (the default) tells the server to use its own ranking.
    #         (Pydantic copies field defaults per instance, so the mutable
    #         [] default is safe here.)
    models: Optional[List[Any]] = []
    # The user prompt forwarded verbatim as a single "user" message.
    query: str
class ChatResponse(BaseModel):
    # Body of the /chat response: the model's completion text, or the
    # literal "ERROR !" sentinel when every model/key combination failed.
    output: str
@app.get("/")
def main_page():
    """Health-check endpoint: confirms the service is up."""
    payload = {"status": "ok"}
    return payload
GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"


def _ask_model(model: str, query: str) -> Optional[str]:
    """Try *model* with every configured API key; return the completion text,
    or None when no key produced a 200 response."""
    for key in api_keys:
        # TLS verification is left ON (the previous verify=False silently
        # disabled certificate checks against a public HTTPS endpoint).
        # timeout prevents a stalled upstream call from hanging the worker.
        resp = requests.post(
            GROQ_CHAT_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"model": model, "messages": [{"role": "user", "content": query}]},
            timeout=30,
        )
        if resp.status_code == 200:
            print("Asked to", model, "with the key ID", str(api_keys.index(key)), ":", query)
            return resp.json()["choices"][0]["message"]["content"]
        # Non-200: log and fall through to the next key.
        print(resp.status_code, resp.text)
    return None


@app.post("/chat", response_model=ChatResponse)
def ask_groq_llm(req: ChatRequest):
    """Forward *req.query* to Groq, trying candidate models in order.

    If the caller supplied models, they are tried in the server's ranking
    order (unknown names last, in caller order); otherwise the full
    server-side ranking is used. Returns the first successful completion,
    or the "ERROR !" sentinel when everything failed.

    Fixes vs. previous revision: the single-model branch logged an
    undefined `model` name (NameError on success), requests had no
    timeout, and TLS verification was disabled.
    """
    # Truthiness also covers models=None (JSON null), which previously
    # crashed on len(None).
    if req.models:
        rank = {name: pos for pos, name in enumerate(ranked_models)}
        # Stable sort: unknown models keep caller order, after known ones.
        candidates = sorted(req.models, key=lambda m: rank.get(m, float("inf")))
    else:
        candidates = ranked_models
    for model in candidates:
        answer = _ask_model(model, req.query)
        if answer is not None:
            return ChatResponse(output=answer)
    return ChatResponse(output="ERROR !")