from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
import prompt_style  # local module providing the default system prompt text (prompt_style.data)
import os
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import time


model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3-GGUF"
filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf"
# model_path = hf_hub_download(repo_id=model_id, filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf", token=os.environ['HF_TOKEN'])
# model = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096, verbose=False)

model = Llama.from_pretrained(repo_id=model_id, filename=filename, n_gpu_layers=-1, token=os.environ['HF_TOKEN'], 
                              n_ctx=4096, verbose=False, attn_implementation="flash_attention_2")

class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.8
    max_new_tokens: int = 1024
    top_p: float = 0.95
    repetition_penalty: float = 1.0
    seed: int = 42

app = FastAPI()

def format_prompt(item: Item):
    # Build the chat message list; fall back to the default system prompt in
    # prompt_style.data when the request does not supply one.
    messages = [
        {"role": "system", "content": item.system_prompt or prompt_style.data},
    ]
    for user_msg, assistant_msg in item.history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": item.prompt})
    return messages

def generate(item: Item):
    # Run chat completion with the sampling parameters carried by the request.
    formatted_prompt = format_prompt(item)
    output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
                                          temperature=item.temperature, max_tokens=item.max_new_tokens,
                                          top_p=item.top_p, repeat_penalty=item.repetition_penalty)
    return output['choices'][0]['message']['content']

@app.post("/generate/")
async def generate_text(item: Item):
    t1 = time.time()
    ans = generate(item)
    print(ans)
    print(f"time: {time.time() - t1:.2f}s")
    return {"response": ans}


@app.get("/")
def read_root():
    # Simple health-check endpoint.
    return {"Hello": "World"}