Pratham Bhat committed
Commit 921b6d2 · 1 Parent(s): 611c4ac

Loads model before starting server

Files changed (2):
  1. main.py +23 -3
  2. requirements.txt +1 -1
main.py CHANGED
@@ -14,7 +14,8 @@ from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import uvicorn
 import torch
-
+import sys
+# torch.mps.empty_cache()
 
 app = FastAPI()
 
@@ -35,13 +36,29 @@ def format_prompt(system, message, history):
     prompt += {"role": "user", "content": message}
     return prompt
 
-def generate(item: Item):
+def setup():
     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # if torch.backends.mps.is_available():
+    #     device = torch.device("mps")
+    #     x = torch.ones(1, device=device)
+    #     print (x)
+    # else:
+    #     device = "cpu"
+    #     print ("MPS device not found.")
+
+    # device = "auto"
+    # device = torch.device("cpu")
+
     model_path = "ibm-granite/granite-34b-code-instruct-8k"
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     # drop device_map if running on CPU
     model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
     model.eval()
+
+    return model, tokenizer, device
+
+def generate(item: Item, model, tokenizer, device):
     # change input text as desired
     chat = format_prompt(item.system_prompt, item.prompt, item.history)
     chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
@@ -56,9 +73,12 @@ def generate(item: Item):
     return output_text
 
 
+model, tokenizer, device = setup()
+
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return {"response": generate(item)}
+    print(item, file=sys.stderr)
+    return {"response": generate(item, model, tokenizer, device)}
 
 @app.get("/")
 async def generate_text_root(item: Item):
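
With this change the Granite checkpoint is loaded exactly once, when uvicorn imports main.py, and each request only runs tokenization and generation against the already-resident model. A minimal sketch of exercising the new endpoint from a client, assuming the server listens on localhost:8000 and that Item exposes the prompt, system_prompt, and history fields used above (their exact types are not shown in this diff):

    import requests  # any HTTP client works; requests is just for illustration

    payload = {
        "prompt": "Write a function that reverses a string.",
        "system_prompt": "You are a helpful coding assistant.",
        "history": [],  # assumed: a list of prior chat messages for format_prompt()
    }
    # POST to the /generate/ route defined in main.py
    resp = requests.post("http://localhost:8000/generate/", json=payload, timeout=600)
    print(resp.json()["response"])
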
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 fastapi
-uvicorn
+uvicorn[standard]
 huggingface_hub
 pydantic
 transformers
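
requirements.txt now pins uvicorn[standard], which installs uvicorn's optional extras (such as uvloop and httptools) for a faster event loop and HTTP parsing. How the server is actually launched is not shown in this diff; a minimal sketch under the conventional pattern, assuming main.py does not already end with something equivalent:

    import uvicorn

    if __name__ == "__main__":
        # Importing "main:app" runs the module-level setup() call above,
        # so the model is fully loaded before the server accepts traffic.
        uvicorn.run("main:app", host="0.0.0.0", port=8000)
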