# Scraped from a Hugging Face Spaces file view (Space status: Sleeping; file size: 1,854 bytes).
# Blame commit hashes shown on the page: a438652 944b573 3b80834 944b573 3d03ce4 944b573 a438652
# 944b573 a438652 644ccbc a438652 944b573 cf9f7eb a438652 3d03ce4 a438652 3d03ce4 a438652 944b573
# a438652 e9448a1 944b573 3d03ce4 944b573 8fa65f5 a438652 bf69a80 e9448a1 a438652
# (The page's line-number gutter has been dropped.)
import logging
import os
import random
import secrets
from typing import Optional

from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from huggingface_hub import InferenceClient
# Runtime configuration pulled from environment variables; each is None when unset.
# NOTE(review): API_URL appears unused by the endpoints in this file — confirm before removing.
API_URL, API_KEY, MODEL_NAME = (
    os.getenv(env_name) for env_name in ("API_URL", "API_KEY", "MODEL_NAME")
)

# Inference client bound to the configured model, plus the ASGI application.
client = InferenceClient(MODEL_NAME)
app = FastAPI()

# Extracts the "api_key" request header. auto_error=False means a missing header
# is handed to the dependency as None instead of being rejected automatically.
security = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(security)):
    """Validate the ``api_key`` request header against the configured secret.

    Used as a FastAPI dependency. ``security`` is built with
    ``auto_error=False``, so a missing header arrives here as ``None``.

    Args:
        api_key: Header value supplied by the client, or ``None``.

    Returns:
        The validated key.

    Raises:
        HTTPException: 401 if the header is missing, the server has no
            ``API_KEY`` configured, or the supplied key does not match.
    """
    # Fail closed: an absent header or an unset/empty server key never authorizes.
    if not api_key or not API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes because compare_digest raises TypeError
    # for str arguments containing non-ASCII characters.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
def format_prompt(message, history):
    """Render a conversation into a single ``[INST]``-tagged prompt string.

    Args:
        message: The new user message to append as the final instruction.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        ``"<s>"``, then each past turn as ``"[INST] user [/INST] reply</s> "``,
        then ``"[INST] message [/INST]"``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {reply}</s> " for user_turn, reply in history
    )
    return "<s>" + past_turns + f"[INST] {message} [/INST]"
def _read_param(body, key, default, cast):
    """Return ``body[key]`` (or *default*) converted with *cast*.

    ``None`` (an explicit JSON null) is passed through unchanged, matching what
    the endpoint previously forwarded to the client.

    Raises:
        HTTPException: 422 if the value cannot be converted by *cast*.
    """
    value = body.get(key, default)
    if value is None:
        return None
    try:
        return cast(value)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=422,
            detail=f"Invalid value for '{key}': {value!r}",
        ) from None


@app.post("/api/v1/generate_text")
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key),
):
    """Generate a completion for ``body["prompt"]`` with the configured model.

    Recognised JSON body fields (all optional):
        prompt: user message, default ``""``.
        sysPrompt: system instructions, default ``""``. When non-empty they are
            prepended to the user message, because ``format_prompt`` only has
            ``[INST]`` user/response slots and no separate system slot.
        temperature: float, default 0.5.
        top_p: float, default 0.95.
        max_new_tokens: int, default 512.
        repetition_penalty: float, default 1.0.

    Args:
        request: Incoming request (unused; kept for signature compatibility).
        body: Parsed JSON object.
        api_key: Result of the ``get_api_key`` dependency.

    Returns:
        The value of ``client.text_generation(..., stream=False, details=False)``,
        presumably the generated text — confirm against the installed
        huggingface_hub version.

    Raises:
        HTTPException: 422 if a numeric field cannot be converted.
    """
    prompt = body.get("prompt", "")
    sys_prompt = body.get("sysPrompt", "")
    temperature = _read_param(body, "temperature", 0.5, float)
    top_p = _read_param(body, "top_p", 0.95, float)
    max_new_tokens = _read_param(body, "max_new_tokens", 512, int)
    repetition_penalty = _read_param(body, "repetition_penalty", 1.0, float)
    logging.getLogger(__name__).debug("temperature=%s", temperature)

    # sysPrompt used to be read and then silently dropped; fold it into the
    # user message so it actually reaches the model. No-op when it is empty.
    message = f"{sys_prompt}\n\n{prompt}" if sys_prompt else prompt

    history = []  # Always empty: this endpoint does not accept conversation history.
    formatted_prompt = format_prompt(message, history)
    generated = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fresh random seed per request so repeated prompts are sampled independently.
        seed=random.randint(0, 10**7),
        stream=False,
        details=False,
        return_full_text=False,
    )
    return generated