Gokulram2710 committed (verified)
Commit dd96f01 · 1 Parent(s): 528b129

Update handler.py

Files changed (1): handler.py (+26 -25)
handler.py CHANGED
@@ -1,26 +1,27 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
 class CustomModelHandler:
     def __init__(self, model_name_or_path: str):
         self.model_name_or_path = model_name_or_path
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.load_model()
 
     def load_model(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True, torch_dtype="auto",
+                                                          use_flash_attn=False)
         self.model.to(self.device)
 
     def predict(self, inputs):
         inputs = self.tokenizer(inputs, return_tensors="pt").to(self.device)
         outputs = self.model.generate(**inputs)
         predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
         return predictions
 
 # Initialize the handler with the model path
 handler = CustomModelHandler("microsoft/Phi-3-vision-128k-instruct")
 
 # Example prediction function
 def predict(input_text):
     return handler.predict(input_text)
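
The substantive change is in load_model(): torch_dtype="auto" loads the checkpoint weights in their native precision (bfloat16 for this model) instead of upcasting them to float32, and use_flash_attn=False is intended to disable flash attention. Note that use_flash_attn is not a standard from_pretrained() argument, so whether it takes effect depends on this model's remote code; the Phi-3-vision model card instead documents _attn_implementation as the switch for flash attention. A minimal sketch of the same load under that assumption, with "eager" selected to turn flash attention off:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    torch_dtype="auto",            # keep the checkpoint's native dtype
    _attn_implementation="eager",  # model-card-documented way to disable flash attention
)

As before, the handler pairs the model with a bare AutoTokenizer, so predict() only covers text-only prompts, e.g. predict("Hello, who are you?"); for image inputs the model card prepares prompts and pixel values with AutoProcessor instead.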