"""FastAPI proxy that forwards a chat query to Groq, falling back across
models and API keys until one request succeeds."""

import json
import os
import re
from typing import Any, List, Optional

import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

load_dotenv()

app = FastAPI()

# Models in preference order: earlier entries are tried first.
ranked_models = [
    "llama-3.3-70b-versatile",
    "llama3-70b-8192",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "mistral-saba-24b",
    "gemma2-9b-it",
    "llama-3.1-8b-instant",
    "llama3-8b-8192",
]

# Every environment variable named GROQ_<digits> is treated as an API key.
api_keys = [v for k, v in os.environ.items() if re.match(r'^GROQ_\d+$', k)]

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)


class ChatRequest(BaseModel):
    # Optional explicit model list; empty means "use ranked_models".
    models: Optional[List[Any]] = []
    query: str


class ChatResponse(BaseModel):
    output: str


@app.get("/")
def main_page():
    """Health-check endpoint."""
    return {"status": "ok"}


def _try_models(candidates: List[Any], query: str) -> Optional[str]:
    """Try each candidate model with each API key in order.

    Returns the first successful completion text, or None if every
    model/key combination failed. Failures are logged to stdout.
    """
    for model in candidates:
        for key_index, key in enumerate(api_keys):
            # NOTE: TLS verification is intentionally left enabled
            # (the original passed verify=False, which leaked bearer
            # keys to any man-in-the-middle).
            resp = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {key}",
                },
                data=json.dumps({
                    "model": model,
                    "messages": [{"role": "user", "content": query}],
                }),
            )
            if resp.status_code == 200:
                respJson = resp.json()
                print("Asked to", model, "with the key ID", str(key_index), ":", query)
                return respJson["choices"][0]["message"]["content"]
            print(resp.status_code, resp.text)
    return None


@app.post("/chat", response_model=ChatResponse)
def ask_groq_llm(req: ChatRequest):
    """Answer a chat query via Groq.

    Model selection:
      - no models given  -> try every ranked model in preference order;
      - one model given  -> try exactly that model;
      - several given    -> try them ordered by rank (unknown models last).

    Each model is attempted with every configured API key; the first
    HTTP 200 wins. Returns ChatResponse(output="ERROR !") if all fail.
    """
    models = req.models
    if not models:
        candidates = ranked_models
    elif len(models) == 1:
        candidates = list(models)
    else:
        order = {name: rank for rank, name in enumerate(ranked_models)}
        candidates = sorted(models, key=lambda m: order.get(m, float('inf')))
    output = _try_models(candidates, req.query)
    if output is not None:
        return ChatResponse(output=output)
    return ChatResponse(output="ERROR !")