Spaces:
Sleeping
Sleeping
```python
import requests | |
import json | |
# Build model mapping
# Upstream model IDs grouped by vendor. ``original_models`` is these groups
# flattened in insertion order (dicts preserve it), so the resulting list has
# exactly the same entries in the same sequence as a hand-written flat list.
_MODEL_GROUPS = {
    "OpenAI": (
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-202201",
        "gpt-4o",
        "gpt-4o-2024-05-13",
        "o1-preview",
    ),
    "Claude": (
        "claude",
        "claude-3-5-sonnet",
        "claude-sonnet-3.5",
        "claude-3-5-sonnet-20240620",
    ),
    "Meta/LLaMA": (
        "@cf/meta/llama-2-7b-chat-fp16",
        "@cf/meta/llama-2-7b-chat-int8",
        "@cf/meta/llama-3-8b-instruct",
        "@cf/meta/llama-3.1-8b-instruct",
        "@cf/meta-llama/llama-2-7b-chat-hf-lora",
        "llama-3.1-405b",
        "llama-3.1-70b",
        "llama-3.1-8b",
        "meta-llama/Llama-2-7b-chat-hf",
        "meta-llama/Llama-3.1-70B-Instruct",
        "meta-llama/Llama-3.1-8B-Instruct",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "meta-llama/Llama-3.2-1B-Instruct",
        "meta-llama/Llama-3.2-3B-Instruct",
        "meta-llama/Llama-3.2-90B-Vision-Instruct",
        "meta-llama/Llama-Guard-3-8B",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    ),
    "Mistral": (
        "mistral",
        "mistral-large",
        "@cf/mistral/mistral-7b-instruct-v0.1",
        "@cf/mistral/mistral-7b-instruct-v0.2-lora",
        "@hf/mistralai/mistral-7b-instruct-v0.2",
        "mistralai/Mistral-7B-Instruct-v0.2",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "mistralai/Mixtral-8x22B-Instruct-v0.1",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
    ),
    "Qwen": (
        "@cf/qwen/qwen1.5-0.5b-chat",
        "@cf/qwen/qwen1.5-1.8b-chat",
        "@cf/qwen/qwen1.5-7b-chat-awq",
        "@cf/qwen/qwen1.5-14b-chat-awq",
        "Qwen/Qwen2.5-3B-Instruct",
        "Qwen/Qwen2.5-72B-Instruct",
        "Qwen/Qwen2.5-Coder-32B-Instruct",
    ),
    "Google/Gemini": (
        "@cf/google/gemma-2b-it-lora",
        "@cf/google/gemma-7b-it-lora",
        "@hf/google/gemma-7b-it",
        "google/gemma-1.1-2b-it",
        "google/gemma-1.1-7b-it",
        "gemini-pro",
        "gemini-1.5-pro",
        "gemini-1.5-pro-latest",
        "gemini-1.5-flash",
    ),
    "Cohere": (
        "c4ai-aya-23-35b",
        "c4ai-aya-23-8b",
        "command",
        "command-light",
        "command-light-nightly",
        "command-nightly",
        "command-r",
        "command-r-08-2024",
        "command-r-plus",
        "command-r-plus-08-2024",
        "rerank-english-v2.0",
        "rerank-english-v3.0",
        "rerank-multilingual-v2.0",
        "rerank-multilingual-v3.0",
    ),
    "Microsoft": (
        "@cf/microsoft/phi-2",
        "microsoft/DialoGPT-medium",
        "microsoft/Phi-3-medium-4k-instruct",
        "microsoft/Phi-3-mini-4k-instruct",
        "microsoft/Phi-3.5-mini-instruct",
        "microsoft/WizardLM-2-8x22B",
    ),
    "Yi": (
        "01-ai/Yi-1.5-34B-Chat",
        "01-ai/Yi-34B-Chat",
    ),
}
original_models = [model_id for group in _MODEL_GROUPS.values() for model_id in group]
# Create mapping from simplified model names to original model names.
def _build_model_mapping(models):
    """Index full upstream model IDs by their simplified name.

    The simplified name is the text after the last "/" of an ID (the whole ID
    when it contains no "/"). When two IDs share a simplified name, the first
    one wins; every later one is reported on stdout and left out.

    Args:
        models: Iterable of full model ID strings.

    Returns:
        ``(mapping, names)`` where ``mapping`` maps simplified name -> full ID
        and ``names`` lists the accepted simplified names in first-seen order.
    """
    mapping = {}
    names = []
    for full_id in models:
        short_name = full_id.split('/')[-1]
        if short_name in mapping:
            # Conflict detected: keep the first ID, exclude this one.
            print(f"Conflict detected for model name '{short_name}'. Excluding '{full_id}' from available models.")
            continue
        mapping[short_name] = full_id
        names.append(short_name)
    return mapping, names


# Built once at import time; the loop temporaries stay local to the helper
# instead of leaking into the module namespace.
model_mapping, simplified_models = _build_model_mapping(original_models)
def generate(
    model,
    messages,
    temperature=0.7,
    top_p=1.0,
    n=1,
    stream=False,
    stop=None,
    max_tokens=None,
    presence_penalty=0.0,
    frequency_penalty=0.0,
    logit_bias=None,
    user=None,
    timeout=30,
):
    """
    Generate a chat completion through TypeGPT's OpenAI-compatible endpoint.

    Args:
        model: Simplified model name, i.e. a key of ``model_mapping``. It is
            translated to the full upstream model ID before sending.
        messages: Chat messages, forwarded unchanged as ``messages``.
        temperature, top_p, presence_penalty, frequency_penalty: Always sent.
        n: Number of choices. Sent only when it is neither None nor 1, so the
            default request body carries no ``n`` field.
        stream: If true, request a streamed response and return a generator.
        stop, max_tokens, logit_bias, user: Sent only when not None.
        timeout: ``requests`` timeout in seconds for the POST.

    Returns:
        If ``stream`` is false, the decoded JSON body. Otherwise a generator
        that yields one decoded JSON object per non-empty line (a leading
        ``data: `` prefix is stripped), stops at ``[DONE]``, and skips lines
        that are not valid JSON. Once iteration has started, the HTTP
        response and session are closed when it ends or the generator is
        closed; a generator that is never iterated does not release them.

    Raises:
        ValueError: ``model`` is not a known simplified model name.
        Exception: The upstream status code is 400 or above
            (``response.ok`` is false).
        requests.RequestException: Network errors and timeouts propagate.
    """
    # The str check keeps unhashable input (e.g. a JSON list) on the
    # ValueError path instead of a TypeError from the dict lookup.
    if not isinstance(model, str) or model not in model_mapping:
        raise ValueError(f"Invalid model: {model}. Choose from: {', '.join(simplified_models)}")
    # Map simplified model name to original model name
    original_model = model_mapping[model]

    api_endpoint = "https://chat.typegpt.net/api/openai/v1/chat/completions"
    headers = {
        "authority": "chat.typegpt.net",
        "accept": "application/json, text/event-stream",
        "accept-language": "en-US,en;q=0.9",
        "content-type": "application/json",
        "origin": "https://chat.typegpt.net",
        "referer": "https://chat.typegpt.net/",
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
    }

    payload = {
        "messages": messages,
        "stream": stream,
        "model": original_model,
        "temperature": temperature,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "top_p": top_p,
    }
    # Optional fields are only included when the caller supplied them.
    if n is not None and n != 1:
        payload["n"] = n
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens
    if stop is not None:
        payload["stop"] = stop
    if logit_bias is not None:
        payload["logit_bias"] = logit_bias
    if user is not None:
        payload["user"] = user

    session = requests.Session()
    try:
        response = session.post(
            api_endpoint, headers=headers, json=payload, stream=stream, timeout=timeout
        )
        if not response.ok:
            raise Exception(f"Failed to generate response - ({response.status_code}, {response.reason}) - {response.text}")
    except BaseException:
        # Nothing will consume the response, so release the session now.
        session.close()
        raise

    if not stream:
        # With stream=False requests has already read the whole body, so the
        # session can be closed as soon as it has been decoded.
        try:
            return response.json()
        finally:
            session.close()

    def stream_response():
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                if line.startswith("data: "):
                    line = line[6:]  # Remove "data: " prefix
                if line.strip() == "[DONE]":
                    break
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # not a JSON payload; skip it
                yield data
        finally:
            # Runs on exhaustion, on [DONE], on error, and when the consumer
            # closes the generator early.
            response.close()
            session.close()

    return stream_response()
if __name__ == "__main__":
    # Example usage: ask one model a question and print the reply.
    # Other simplified names that work here include
    # "claude-3-5-sonnet-20240620", "qwen1.5-0.5b-chat" and "llama-2-7b-chat-fp16".
    model = "gpt-3.5-turbo"
    use_stream = True  # set to False for a single, non-streamed response
    messages = [
        {"role": "system", "content": "Be Detailed"},
        {"role": "user", "content": "What is the knowledge cut off? Be specific and also specify the month, year and date. If not sure, then provide approximate."}
    ]
    try:
        response = generate(
            model=model,
            messages=messages,
            temperature=0.5,
            max_tokens=4000,
            stream=use_stream,
        )
        if use_stream:
            printed_any = False
            for data in response:
                choices = data.get("choices") if isinstance(data, dict) else None
                if not choices:
                    continue
                # A chunk's delta may lack "content" or carry None (e.g. a
                # role-only or final chunk in OpenAI-style streams), so look
                # it up with .get rather than indexing.
                content = (choices[0].get("delta") or {}).get("content")
                if content:
                    print(content, end="", flush=True)
                    printed_any = True
            print()
            if not printed_any:
                print("No response received.")
        elif 'choices' in response:
            print(response['choices'][0]['message']['content'])
        else:
            print("No response received.")
    except Exception as e:
        print(e)
```
```python
from fastapi import FastAPI, Request, Response | |
from fastapi.responses import JSONResponse, StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
import asyncio | |
import json | |
import requests | |
from TYPEGPT.typegpt_api import generate, model_mapping, simplified_models | |
from api_info import developer_info | |
# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site make credentialed cross-origin calls; narrow allow_origins if the
# service ever relies on cookies or auth headers.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
@app.get("/health_check")
async def health_check():
    """Liveness probe: always answers with a fixed OK status."""
    status_payload = {"status": "OK"}
    return status_payload
@app.get("/models")
async def get_models():
    """Proxy the upstream TypeGPT model list.

    Returns the upstream JSON body with the upstream status code. Any failure
    (network error, timeout, non-JSON body) becomes ``{"error": ...}`` with
    status 500.
    """
    api_endpoint = "https://chat.typegpt.net/api/openai/v1/models"
    try:
        # requests is blocking: run it in the loop's default executor so the
        # event loop keeps serving other requests, and bound the wait so a
        # stalled upstream cannot hang this request indefinitely.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.get(api_endpoint, timeout=30)
        )
        return JSONResponse(content=response.json(), status_code=response.status_code)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completions endpoint backed by ``generate``.

    Reads the JSON body, forwards the OpenAI parameters to ``generate`` and
    returns either the upstream JSON (``stream`` false) or a server-sent-event
    stream of ``data: <chunk>`` lines terminated by ``data: [DONE]``.

    Errors are returned as ``{"error": ...}``: 400 for an unparsable or
    non-object body, a missing ``model``/``messages``, or an unknown model;
    500 for anything ``generate`` raises.
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)
    if not isinstance(body, dict):
        # e.g. a JSON array: body.get() below would raise AttributeError.
        return JSONResponse(content={"error": "The request body must be a JSON object."}, status_code=400)

    model = body.get("model")
    messages = body.get("messages")
    stream = bool(body.get("stream", False))

    # Validate required parameters
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)
    if not isinstance(model, str) or model not in model_mapping:
        # A client error, so report it as 400 rather than letting generate()
        # raise and surface as a 500.
        return JSONResponse(
            content={"error": f"Invalid model: {model}. Choose from: {', '.join(simplified_models)}"},
            status_code=400,
        )

    generate_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stream": stream,
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        "timeout": 30,  # or set based on your preference
    }

    try:
        # generate() does blocking network I/O through requests, so run it in
        # the loop's default executor. In streaming mode it performs the
        # upstream POST and status check before returning its chunk
        # generator. Failures therefore land in this except as a JSON 500
        # instead of breaking a stream whose 200 headers were already sent.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: generate(**generate_kwargs))
        if not stream:
            return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)

    def generate_stream():
        # A plain (sync) generator: StreamingResponse iterates non-async
        # iterators in a worker thread, keeping the blocking upstream reads
        # off the event loop.
        for chunk in result:
            yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked"
        }
    )
@app.get("/developer_info")
async def get_developer_info():
    """Serve the ``developer_info`` object imported from ``api_info`` as JSON."""
    info_payload = developer_info
    return JSONResponse(content=info_payload)
if __name__ == "__main__":
    # Serve the app directly with uvicorn on port 8000, bound to all interfaces.
    uvicorn.run(app, port=8000, host="0.0.0.0")
```