Update handler.py
handler.py  CHANGED  (+24 -13)
@@ -2,7 +2,6 @@ import os
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Global variables for model, tokenizer, and device
 model = None
 tokenizer = None
 device = None
@@ -13,44 +12,56 @@ def init():
     """
     global model, tokenizer, device
 
-    #
+    # Replace this with your model repository ID
     model_name_or_path = "0xroyce/NazareAI-Senior-Marketing-Strategist"
 
+    # Load the tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    # Load the model
     model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         low_cpu_mem_usage=True
     )
 
+    # Set up the device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-    model.eval()
+    model.eval()
+
+    # Store in global variables
+    globals()["model"] = model
+    globals()["tokenizer"] = tokenizer
+    globals()["device"] = device
+
 
 def inference(model_inputs: dict) -> dict:
     """
-    This function is called for every request
-    The input is a dictionary
+    This function is called for every request.
+    The input is a dictionary with a 'prompt' key.
+    The output is a dictionary with 'generated_text'.
     """
     global model, tokenizer, device
 
-    #
+    # Get the prompt from the input
    prompt = model_inputs.get("prompt", "")
     if not prompt:
         return {"error": "No prompt provided."}
 
-    # Tokenize
+    # Tokenize the prompt
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
-    #
-    # You can adjust parameters like max_new_tokens, temperature, or top_p as needed
+    # Run generation
     output_ids = model.generate(
-        **inputs,
-        max_new_tokens=200,
-        do_sample=True,
-        top_p=0.9,
+        **inputs,
+        max_new_tokens=200,
+        do_sample=True,
+        top_p=0.9,
         temperature=0.7
     )
+
+    # Decode the output
     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
     return {"generated_text": output_text}
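For reference, a minimal local smoke test of the updated handler might look like the sketch below. This is an illustrative assumption, not part of the commit: it assumes handler.py is importable as "handler" on the local machine, and that the serving framework ordinarily calls init() once at startup and inference() per request.

# Illustrative local test only (not part of this commit); assumes handler.py
# is on the import path and the model can be loaded on this machine.
import handler

if __name__ == "__main__":
    handler.init()  # loads the tokenizer/model once and moves it to CUDA if available

    # inference() takes a dict with a "prompt" key and returns either
    # {"generated_text": ...} or {"error": ...}
    result = handler.inference({"prompt": "Outline a go-to-market plan for a new SaaS product."})
    print(result.get("generated_text", result.get("error")))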