Spaces:
Sleeping
Sleeping
File size: 4,246 Bytes
501c69f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS
import json
from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers
# Flask application instance shared by every route below.
app = Flask(__name__)

# Open CORS policy: any origin may call any route with any method/header,
# and credentials are allowed.
CORS(
    app,
    resources={
        r"/*": {
            "origins": "*",
            "allow_credentials": True,
            "methods": ["*"],
            "headers": ["*"],
        }
    },
)
@app.route("/health_check", methods=['GET'])
def health_check():
    """Liveness probe: always answers 200 with a static OK payload."""
    return jsonify({"status": "OK"})
@app.route("/models", methods=['GET'])
def get_models():
    """List every model from every provider in OpenAI /models list format.

    Each entry carries the model id plus the provider name and the
    provider's description taken from `model_providers`.
    """
    try:
        catalog = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return jsonify({"object": "list", "data": catalog})
    except Exception as e:
        # Boundary handler: malformed provider metadata surfaces as a JSON 500.
        return jsonify({"error": str(e)}), 500
@app.route("/chat/completions", methods=['POST'])
def chat_completions():
    """Handle an OpenAI-style chat completion request.

    Expects a JSON body containing at least 'model' and 'messages'; all
    other generation parameters are optional and forwarded to `generate`.
    Returns either a single JSON response or, when 'stream' is true, a
    Server-Sent Events stream terminated by a '[DONE]' sentinel.
    """
    # silent=True yields None on a missing/invalid JSON body instead of
    # raising, so a bad payload gets a clean 400 rather than leaking a 500
    # from the AttributeError that body.get(...) on None would cause.
    body = request.get_json(silent=True)
    if body is None:
        return jsonify({"error": "Invalid JSON payload"}), 400

    model = body.get("model")
    messages = body.get("messages")

    # Validate required parameters before doing any other work.
    if not model:
        return jsonify({"error": "The 'model' parameter is required."}), 400
    if not messages:
        return jsonify({"error": "The 'messages' parameter is required."}), 400

    stream = body.get("stream", False)

    # One shared kwargs dict so the streaming and non-streaming calls to
    # `generate` cannot drift out of sync (previously 12 args were duplicated).
    gen_kwargs = dict(
        model=model,
        messages=messages,
        temperature=body.get("temperature", 0.7),
        top_p=body.get("top_p", 1.0),
        n=body.get("n", 1),
        stop=body.get("stop"),
        max_tokens=body.get("max_tokens"),
        presence_penalty=body.get("presence_penalty", 0.0),
        frequency_penalty=body.get("frequency_penalty", 0.0),
        logit_bias=body.get("logit_bias"),
        user=body.get("user"),
        timeout=30,  # fixed upstream request timeout, in seconds
    )

    try:
        if stream:
            def generate_stream():
                # Emit one SSE "data:" frame per upstream chunk, then the
                # OpenAI-compatible [DONE] terminator.
                for chunk in generate(stream=True, **gen_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return Response(
                stream_with_context(generate_stream()),
                mimetype="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )

        response = generate(stream=False, **gen_kwargs)
        return jsonify(response)
    except Exception as e:
        # Boundary handler: any upstream failure becomes a JSON 500.
        return jsonify({"error": str(e)}), 500
@app.route("/developer_info", methods=['GET'])
def get_developer_info():
    """Expose the static developer metadata imported from `api_info`."""
    return jsonify(developer_info)
if __name__ == "__main__":
    # Development entry point: listen on all interfaces, port 8000.
    # (Removed a stray trailing "|" artifact that made this line a syntax error.)
    app.run(host="0.0.0.0", port=8000)