|
import json |
|
from fastapi import FastAPI, HTTPException |
|
from dotenv import load_dotenv |
|
import os |
|
import re |
|
import requests |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel, Field |
|
from typing import List, Optional, Literal, Union |
|
|
|
# Load environment variables from a local .env file, if present.
load_dotenv()

app = FastAPI()

# Collect Groq API keys from environment variables named GROQ_1, GROQ_2, ...
# Order follows os.environ iteration order; the "key ID" logged by /chat is
# the 1-based position in this list.
api_keys = [v for k, v in os.environ.items() if re.match(r'^GROQ_\d+$', k)]

# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by browsers per the CORS spec (credentialed requests may not use a
# wildcard origin) — confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)
|
|
|
class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat conversation.

    Mirrors the OpenAI/Groq chat-completions message schema: a required
    role, optional text content, and optional tool/function-calling fields.
    """

    # Who produced the message; restricted to the four OpenAI-style roles.
    role: Literal["system", "user", "assistant", "tool"]
    # Message text. Required field, but explicitly allowed to be None
    # (e.g. assistant messages that carry only tool_calls).
    content: Optional[str]
    # Optional participant name attached to the message.
    name: Optional[str] = None
    # Legacy single function-call payload (pre-tools API shape).
    function_call: Optional[dict] = None
    # For role="tool": the id of the tool call this message responds to.
    tool_call_id: Optional[str] = None
    # For role="assistant": tool invocations requested by the model.
    tool_calls: Optional[List[dict]] = None
|
|
|
class ChatRequest(BaseModel):
    """OpenAI-compatible chat-completion request accepted by /chat.

    `models` is an ordered list of candidate model names to try in turn;
    the remaining fields mirror the OpenAI chat-completions parameters.
    """

    # default_factory avoids declaring a shared mutable list as the default.
    models: Optional[List[str]] = Field(default_factory=list)
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Union[str, dict]] = None
    # NOTE(review): the /chat endpoint currently forwards only `model` and
    # `messages` upstream; the sampling fields above are accepted and
    # validated but not sent to Groq — confirm this is intentional.
|
|
|
def clean_message(msg: ChatMessage) -> dict:
    """Serialize *msg* to a plain dict, dropping every field set to None.

    Keeps the upstream payload minimal: unset optional fields (name,
    tool_calls, ...) are omitted instead of being sent as explicit nulls.
    """
    cleaned = {}
    for field, value in msg.model_dump().items():
        if value is not None:
            cleaned[field] = value
    return cleaned
|
|
|
@app.get("/")
def main_page():
    """Health-check endpoint for the service root."""
    payload = {"status": "ok"}
    return payload
|
|
|
@app.post("/chat")
def ask_groq_llm(req: ChatRequest):
    """Proxy a chat-completion request to the Groq API.

    Tries each requested model in order, and for each model tries every
    configured API key, returning the first successful response.

    Returns:
        {"error": False, "content": <choices>} on the first 200 response,
        or {"error": True, "content": <message>} when every model/key
        combination fails.

    Raises:
        HTTPException(400) if any requested model name is empty.
    """
    # req.models is Optional: guard against None before iterating.
    models = req.models or []
    # An empty model name can never succeed upstream; reject it up front.
    if any(model == "" for model in models):
        raise HTTPException(400, detail="Empty model field")

    # Strip None-valued fields so unset optionals are omitted, not null.
    messages = [clean_message(m) for m in req.messages]

    # NOTE(review): only `model` and `messages` are forwarded; the other
    # ChatRequest sampling fields are currently ignored — confirm intended.
    for model in models:
        for key_id, key in enumerate(api_keys, start=1):
            resp = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {key}",
                },
                # Security fix: TLS certificate verification was disabled
                # (verify=False); use the safe default. Timeout added so a
                # stalled upstream call cannot hang the worker forever.
                json={"model": model, "messages": messages},
                timeout=60,
            )
            if resp.status_code == 200:
                print("Asked to", model, "with the key ID", str(key_id), ":", messages)
                return {"error": False, "content": resp.json()["choices"]}
            # Log the failure and fall through to the next key/model.
            print(resp.status_code, resp.text)

    return {"error": True, "content": "Aucun des modèles, ni des clés ne fonctionne, patientez ...."}