from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from typing import Optional
from huggingface_hub import InferenceClient
import random
import os

# Configuration is read from environment variables.
API_URL = os.environ.get("API_URL")  # read for configuration but not used below
API_KEY = os.environ.get("API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME")

# Client for the Hugging Face Inference API, targeting MODEL_NAME.
client = InferenceClient(MODEL_NAME)
app = FastAPI()

# The API key is expected in an "api_key" request header; auto_error=False
# defers error handling to get_api_key so we can return a custom 401.
security = APIKeyHeader(name="api_key", auto_error=False)

def get_api_key(api_key: Optional[str] = Depends(security)):
    """Reject any request whose api_key header is missing or wrong."""
    if api_key is None or api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
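
# Usage note (a sketch, host/port assumed): clients authenticate by sending
# the key in an "api_key" header; a missing or wrong key yields 401, e.g.
#   curl -H "api_key: $API_KEY" http://localhost:8000/api/v1/generate_text ...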

def format_prompt(message, history):
    """Build a Mistral-instruct style prompt from (user, bot) history pairs."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
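
# Worked example (values assumed, not from the original): with one prior turn,
#   format_prompt("How are you?", [("Hi", "Hello!")])
# returns:
#   "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"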

@app.post("/api/v1/generate_text")
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key)
):
    prompt = body.get("prompt", "")
    sys_prompt = body.get("sysPrompt", "")  # accepted but not yet folded into the prompt
    temperature = body.get("temperature", 0.5)
    top_p = body.get("top_p", 0.95)
    max_new_tokens = body.get("max_new_tokens", 512)
    repetition_penalty = body.get("repetition_penalty", 1.0)
    print(f"temperature: {temperature}")
    history = []  # no conversation history is tracked per request yet
    formatted_prompt = format_prompt(prompt, history)

    # With stream=False, text_generation returns the completed string rather
    # than an iterator of tokens, so name the result accordingly.
    generated_text = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),
        stream=False,
        details=False,
        return_full_text=False
    )

    return generated_text
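
# Minimal local-run sketch (not in the original file): assumes uvicorn is
# installed and API_KEY / MODEL_NAME are set in the environment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the endpoint above (hypothetical values):
#   curl -X POST http://localhost:8000/api/v1/generate_text \
#        -H "api_key: $API_KEY" -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 64}'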