# from https://huggingface.co/spaces/iiced/mixtral-46.7b-fastapi/blob/main/main.py
# example of use:
# curl -X POST \
# -H "Content-Type: application/json" \
# -d '{
# "prompt": "What is the capital of France?",
# "history": [],
# "system_prompt": "You are a very powerful AI assistant."
# }' \
# https://phk0-bai.hf.space/generate/
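# Equivalent Python client sketch (illustrative only; assumes the `requests`
# package is installed and the space above is reachable -- adjust the URL for
# your own deployment):
# import requests
# resp = requests.post(
#     "https://phk0-bai.hf.space/generate/",
#     json={
#         "prompt": "What is the capital of France?",
#         "history": [],
#         "system_prompt": "You are a very powerful AI assistant.",
#     },
# )
# print(resp.json()["response"])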
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import uvicorn
import torch
# torch.mps.empty_cache()
# torch.set_num_threads(1)
app = FastAPI()
class Item(BaseModel):
    prompt: str
    history: list
    system_prompt: str
    temperature: float = 0.0
    max_new_tokens: int = 900
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(system, message, history):
    # Build the chat message list expected by tokenizer.apply_chat_template.
    prompt = [{"role": "system", "content": system}]
    for user_prompt, bot_response in history:
        prompt.append({"role": "user", "content": user_prompt})
        prompt.append({"role": "assistant", "content": bot_response})
    prompt.append({"role": "user", "content": message})
    return prompt
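# Illustrative example of the structure format_prompt returns:
# format_prompt("You are helpful.", "And Spain?",
#               [("What is the capital of France?", "Paris.")]) ->
# [{"role": "system", "content": "You are helpful."},
#  {"role": "user", "content": "What is the capital of France?"},
#  {"role": "assistant", "content": "Paris."},
#  {"role": "user", "content": "And Spain?"}]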
# def setup():
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     # if torch.backends.mps.is_available():
#     #     device = torch.device("mps")
#     #     x = torch.ones(1, device=device)
#     #     print(x)
#     # else:
#     #     device = "cpu"
#     #     print("MPS device not found.")
#     # device = "auto"
#     # device = torch.device("cpu")
#     model_path = "ibm-granite/granite-34b-code-instruct-8k"
#     tokenizer = AutoTokenizer.from_pretrained(model_path)
#     # drop device_map if running on CPU
#     model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
#     model.eval()
#     return model, tokenizer, device
def generate(item: Item):
    # Note: the model and tokenizer are reloaded on every request here; see the
    # commented-out setup() above for loading them once at startup instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "ibm-granite/granite-34b-code-instruct-8k"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # drop device_map if running on CPU
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
    model.eval()
    # build the chat prompt from the request
    chat = format_prompt(item.system_prompt, item.prompt, item.history)
    chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # tokenize the text
    input_tokens = tokenizer(chat, return_tensors="pt")
    # transfer tokenized inputs to the device
    for i in input_tokens:
        input_tokens[i] = input_tokens[i].to(device)
    # generate output tokens
    output = model.generate(**input_tokens, max_new_tokens=item.max_new_tokens)
    output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return output_text
# model, tokenizer, device = setup()
@app.post("/generate/")
async def generate_text(item: Item):
    return {"response": generate(item)}
    # return {"response": generate(item, model, tokenizer, device)}
@app.get("/")
async def generate_text_root():
    # GET requests carry no body, so this endpoint takes no Item payload.
    return {"response": "try entry point: /generate/"}
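# Minimal entry point sketch so the file can also be run directly with
# `python main.py`; the port (7860) is an assumption based on the usual
# Hugging Face Spaces default and may need to match your deployment.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)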