File size: 3,649 Bytes
343c758
 
 
 
 
 
 
689776f
 
343c758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689776f
 
 
 
 
 
 
 
343c758
689776f
 
 
 
 
 
 
 
 
 
 
 
 
 
343c758
689776f
 
343c758
 
 
 
 
b2fc61a
343c758
 
689776f
 
 
343c758
689776f
343c758
 
689776f
343c758
 
ab68772
b2fc61a
343c758
 
b2fc61a
343c758
 
689776f
343c758
689776f
343c758
 
689776f
b2fc61a
343c758
 
b2fc61a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
import requests
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Literal, Union

# Pull environment variables from a local .env file before reading them.
load_dotenv()

app = FastAPI()

# Every environment variable named GROQ_<number> holds one Groq API key;
# collect all of their values for round-robin use by the /chat endpoint.
api_keys = [value for name, value in os.environ.items() if re.match(r'^GROQ_\d+$', name)]

# Open the API to browser clients from any origin (GET/POST only).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)

class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat conversation, as accepted by the Groq API."""
    role: Literal["system", "user", "assistant", "tool"]
    content: Optional[str]  # May be null for some messages (e.g. tool calls)
    name: Optional[str] = None
    function_call: Optional[dict] = None  # Deprecated
    tool_call_id: Optional[str] = None
    tool_calls: Optional[List[dict]] = None

class ChatRequest(BaseModel):
    """Body of a POST /chat request: the models to try plus OpenAI-style
    chat-completion parameters forwarded to the Groq API."""
    # default_factory avoids sharing one mutable [] default across instances
    # (pydantic copies defaults, but the factory form is the idiomatic fix).
    models: Optional[List[str]] = Field(default_factory=list)
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[dict] = None
    user: Optional[str] = None
    tools: Optional[List[dict]] = None
    tool_choice: Optional[Union[str, dict]] = None

def clean_message(msg: ChatMessage) -> dict:
    """Serialize a ChatMessage to a plain dict, dropping fields left as None
    so the upstream API only sees keys that were actually set."""
    dumped = msg.model_dump()
    return {field: value for field, value in dumped.items() if value is not None}

@app.get("/")
def main_page():
    """Health-check endpoint: reports that the service is up."""
    payload = {"status": "ok"}
    return payload

@app.post("/chat")
def ask_groq_llm(req: ChatRequest):
    """Forward a chat-completion request to the Groq API.

    Tries every requested model in order and, for each model, every
    configured API key, returning the first successful response.

    Returns:
        {"error": False, "content": <choices list>} on the first HTTP 200, or
        {"error": True, "content": <message>} when every model/key pair fails
        (including when ``req.models`` is empty).

    Raises:
        HTTPException(400): when the single provided model name is empty.
    """
    models = req.models
    if len(models) == 1 and models[0] == "":
        raise HTTPException(400, detail="Empty model field")
    messages = [clean_message(m) for m in req.messages]
    # The original code duplicated this loop for the single-model case;
    # iterating a one-element list is identical, so one loop covers both.
    for model in models:
        for key_id, key in enumerate(api_keys, start=1):
            # NOTE(review): verify=False disables TLS certificate checking —
            # kept to preserve behavior, but this should be removed or made
            # configurable; requests to api.groq.com are open to MITM as-is.
            resp = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                verify=False,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {key}",
                },
                data=json.dumps({"model": model, "messages": messages}),
            )
            if resp.status_code == 200:
                respJson = resp.json()
                print("Asked to", model, "with the key ID", str(key_id), ":", messages)
                return {"error": False, "content": respJson["choices"]}
            # Log the failure and fall through to the next key/model.
            print(resp.status_code, resp.text)
    return {"error": True, "content": "Aucun des modèles, ni des clés ne fonctionne, patientez ...."}