# from https://huggingface.co/spaces/iiced/mixtral-46.7b-fastapi/blob/main/main.py
# example of use:
# curl -X POST \
#   -H "Content-Type: application/json" \
#   -d '{
#         "prompt": "What is the capital of France?",
#         "history": [],
#         "system_prompt": "You are a very powerful AI assistant."
#       }' \
#   https://phk0-bai.hf.space/generate/
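# For reference, an equivalent client call from Python (a minimal sketch; the URL
# above is the deployed Space and may differ for your own deployment):
#
#   import requests
#   payload = {
#       "prompt": "What is the capital of France?",
#       "history": [],
#       "system_prompt": "You are a very powerful AI assistant.",
#   }
#   resp = requests.post("https://phk0-bai.hf.space/generate/", json=payload)
#   print(resp.json()["response"])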

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import uvicorn
import torch
# torch.mps.empty_cache()
# torch.set_num_threads(1)

app = FastAPI()

class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.0
    max_new_tokens: int = 900
    top_p: float = 0.15
    repetition_penalty: float = 1.0

def format_prompt(system, message, history):
    # Build the chat message list expected by tokenizer.apply_chat_template:
    # system prompt first, then alternating user/assistant turns, then the new message.
    prompt = [{"role": "system", "content": system}]
    for user_prompt, bot_response in history:
        prompt.append({"role": "user", "content": user_prompt})
        prompt.append({"role": "assistant", "content": bot_response})
    prompt.append({"role": "user", "content": message})
    return prompt
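
# Example (illustrative values):
#   format_prompt("You are helpful.", "And its population?",
#                 [("What is the capital of France?", "Paris.")])
# returns
#   [{"role": "system", "content": "You are helpful."},
#    {"role": "user", "content": "What is the capital of France?"},
#    {"role": "assistant", "content": "Paris."},
#    {"role": "user", "content": "And its population?"}]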

# def setup():
#     device = "cuda" if torch.cuda.is_available() else "cpu"

#     # if torch.backends.mps.is_available():
#     #     device = torch.device("mps")
#     #     x = torch.ones(1, device=device)
#     #     print (x)
#     # else:
#     #     device="cpu"
#     #     print ("MPS device not found.")
    
#     # device = "auto"
#     # device=torch.device("cpu")
    
#     model_path = "ibm-granite/granite-34b-code-instruct-8k"
#     tokenizer = AutoTokenizer.from_pretrained(model_path)
#     # drop device_map if running on CPU
#     model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
#     model.eval()
    
#     return model, tokenizer, device

def generate(item: Item):
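    # NOTE: the model and tokenizer are re-loaded on every request here; the
    # commented-out setup() above sketches loading them once at startup instead.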
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_path = "ibm-granite/granite-34b-code-instruct-8k"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # drop device_map if running on CPU
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
    model.eval()
    # build the chat messages from the request fields
    chat = format_prompt(item.system_prompt, item.prompt, item.history)
    chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # tokenize the text
    input_tokens = tokenizer(chat, return_tensors="pt")
    # transfer tokenized inputs to the device
    for i in input_tokens:
        input_tokens[i] = input_tokens[i].to(device)
    # generate output tokens, capped by the request's max_new_tokens
    output = model.generate(**input_tokens, max_new_tokens=item.max_new_tokens)
    output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return output_text


# model, tokenizer, device = setup()

@app.post("/generate/")
async def generate_text(item: Item):
    return {"response": generate(item)}
    # return {"response": generate(item, model, tokenizer, device)}

@app.get("/")
async def generate_text_root():
    return {"response": "try entry point: /generate/"}