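"""FastAPI app exposing an OpenAI-compatible chat completions API backed by a
GGUF quantization of DeepSeek-R1-Distill-Qwen-1.5B served with llama-cpp-python."""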
import os
import requests
import time
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from llama_cpp import Llama
from pydantic import BaseModel
import uvicorn
from typing import Generator
import threading
# Configuration
MODEL_URL = "https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf" # Changed to Q4 for faster inference
MODEL_NAME = "DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
# Create model directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)
# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    print(f"Downloading model from {MODEL_URL}...")
    response = requests.get(MODEL_URL, stream=True)
    if response.status_code == 200:
        with open(MODEL_PATH, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Model downloaded successfully!")
    else:
        raise RuntimeError(f"Failed to download model: HTTP {response.status_code}")
else:
    print("Model already exists. Skipping download.")
# Initialize FastAPI
app = FastAPI(
    title="DeepSeek-R1 OpenAI-Compatible API",
    description="Optimized OpenAI-compatible API with streaming support",
    version="2.0.0"
)
# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
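# Note: allow_origins=["*"] accepts browser requests from any origin; this is
# convenient for a public demo Space but should be narrowed for private deployments.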
# Global model loader with optimized settings
print("Loading model with optimized settings...")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=1024,      # Reduced context window for faster processing
        n_threads=8,     # Increased threads for better CPU utilization
        n_batch=512,     # Larger batch size for improved throughput
        n_gpu_layers=0,
        use_mlock=True,  # Prevent swapping to disk
        verbose=False
    )
    print("Model loaded with optimized settings!")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}")
# Streaming generator
def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float) -> Generator[str, None, None]:
    start_time = time.time()
    stream = llm.create_completion(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=["</s>"],
        stream=True
    )
    for chunk in stream:
        delta = chunk['choices'][0]['text']
        # Each server-sent event carries the raw token text
        # (not an OpenAI chat.completion.chunk JSON object)
        yield f"data: {delta}\n\n"
        # Early stopping if generation is taking too long
        if time.time() - start_time > 30:  # 30s timeout
            break
    # Conventional SSE termination sentinel
    yield "data: [DONE]\n\n"
# OpenAI-Compatible Request Schema
class ChatCompletionRequest(BaseModel):
    model: str = "DeepSeek-R1-Distill-Qwen-1.5B"
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False
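# Example request body accepted by this schema (every field except "messages"
# falls back to the defaults above):
# {
#   "messages": [{"role": "user", "content": "Hello"}],
#   "max_tokens": 128,
#   "temperature": 0.7,
#   "top_p": 0.9,
#   "stream": false
# }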
# Enhanced root endpoint with performance info
@app.get("/", response_class=HTMLResponse)
async def root():
return f"""
<html>
<head>
<title>DeepSeek-R1 Optimized API</title>
<style>
body {{ font-family: Arial, sans-serif; max-width: 800px; margin: 20px auto; padding: 0 20px; }}
.warning {{ color: #dc3545; background: #ffeef0; padding: 15px; border-radius: 5px; }}
.info {{ color: #0c5460; background: #d1ecf1; padding: 15px; border-radius: 5px; }}
a {{ color: #007bff; text-decoration: none; }}
code {{ background: #f8f9fa; padding: 2px 4px; border-radius: 4px; }}
</style>
</head>
<body>
<h1>DeepSeek-R1 Optimized API</h1>
<div class="warning">
<h3>⚠️ Important Notice</h3>
<p>For private use, please duplicate this space:<br>
1. Click your profile picture in the top-right<br>
2. Select "Duplicate Space"<br>
3. Set visibility to Private</p>
</div>
<div class="info">
<h3>⚡ Performance Optimizations</h3>
<ul>
<li>Quantization: Q4_K_M (optimized speed/quality balance)</li>
<li>Batch processing: 512 tokens/chunk</li>
<li>Streaming support with 30s timeout</li>
<li>8 CPU threads utilization</li>
</ul>
</div>
<h2>API Documentation</h2>
<ul>
<li><a href="/docs">Interactive Swagger Documentation</a></li>
<li><a href="/redoc">ReDoc Documentation</a></li>
</ul>
<h2>Example Streaming Request</h2>
<pre>
curl -N -X POST "{os.environ.get('SPACE_HOST', 'http://localhost:7860')}/v1/chat/completions" \\
-H "Content-Type: application/json" \\
-d '{{
"messages": [{{"role": "user", "content": "Explain quantum computing"}}],
"stream": true,
"max_tokens": 150
}}'
</pre>
</body>
</html>
"""
# Async endpoint handler
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
try:
prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in request.messages])
prompt += "\nassistant:"
if request.stream:
return StreamingResponse(
generate_stream(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p
),
media_type="text/event-stream"
)
# Non-streaming response
start_time = time.time()
response = llm(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
stop=["</s>"]
)
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response['choices'][0]['text'].strip()
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": len(prompt),
"completion_tokens": len(response['choices'][0]['text']),
"total_tokens": len(prompt) + len(response['choices'][0]['text'])
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
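# Example non-streaming call, as a sketch (assumes the server is reachable on
# localhost:7860; adjust the URL for a deployed Space):
#   import requests
#   r = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64},
#   )
#   print(r.json()["choices"][0]["message"]["content"])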
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": True,
"performance_settings": {
"n_threads": llm.params.n_threads,
"n_ctx": llm.params.n_ctx,
"n_batch": llm.params.n_batch
}
}
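# Quick health check (assuming the default port below):
#   curl http://localhost:7860/health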
if __name__ == "__main__":
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
timeout_keep_alive=300 # Keep alive for streaming connections
)