import os

from fastapi import FastAPI
from huggingface_hub import InferenceClient
from pydantic import BaseModel
import uvicorn

# Local module expected to expose the system prompt text as `prompt_style.data`.
import prompt_style


model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3"
client = InferenceClient(model_id, token=os.environ['HF_TOKEN'])

class Item(BaseModel):
    """Request body for /generate/."""

    prompt: str
    history: list  # prior turns as (user, assistant) pairs
    system_prompt: str  # accepted for compatibility; the served system prompt comes from prompt_style.data
    temperature: float = 0.8
    max_new_tokens: int = 1024
    top_p: float = 0.95
    repetition_penalty: float = 1.0
    seed: int = 42

app = FastAPI()

def format_prompt(item: Item):
    # client.text_generation expects a plain string, so render the conversation
    # with the Llama-3-Instruct chat template instead of passing message dicts.
    prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    prompt += f"{prompt_style.data}<|eot_id|>"
    for user_msg, assistant_msg in item.history:
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
    prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{item.prompt}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt

def generate(item: Item):
    # The inference backend requires a strictly positive temperature, so clamp
    # near-zero values to a small floor.
    temperature = float(item.temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(item.top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=top_p,
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=item.seed,
    )

    formatted_prompt = format_prompt(item)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    # Accumulate the streamed tokens into a single response string.
    output = ""
    for response in stream:
        output += response.token.text
    return output
    
    # Previous implementation using a local chat-completion model, kept for reference:
    # output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
    #                                       temperature=item.temperature,
    #                                       max_tokens=item.max_new_tokens)
    # out = output['choices'][0]['message']['content']
    # return out

@app.post("/generate/")
async def generate_text(item: Item):
    ans = generate(item)
    return {"response": ans}


@app.get("/")
def read_root():
    
    return {"Hello": "Worlds"}