Update handler.py
Browse files- handler.py +18 -0
handler.py
CHANGED
@@ -2,12 +2,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
2 |
import torch
|
3 |
import subprocess
|
4 |
|
|
|
5 |
subprocess.run('pip install flash-attn', shell=True)
|
6 |
|
7 |
class CustomModelHandler:
|
8 |
def __init__(self, model_name_or_path: str):
|
9 |
self.model_name_or_path = model_name_or_path
|
10 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
11 |
self.load_model()
|
12 |
|
13 |
def load_model(self):
|
@@ -21,6 +23,16 @@ class CustomModelHandler:
|
|
21 |
)
|
22 |
self.model.to(self.device)
|
23 |
print(f"Model loaded and moved to {self.device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
except Exception as e:
|
25 |
print(f"An error occurred while loading the model: {e}")
|
26 |
raise
|
@@ -41,3 +53,9 @@ handler = CustomModelHandler("microsoft/Phi-3-vision-128k-instruct")
|
|
41 |
# Example prediction function
|
42 |
def predict(input_text):
|
43 |
return handler.predict(input_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import subprocess
|
4 |
|
5 |
+
# Install flash-attn at import time (best-effort: failures are not raised,
# matching the original's lack of check=True).
# Fixed: the original used shell=True with a single command string, which is
# the injection-prone form; the argument-list form with the default
# shell=False runs the same fixed command without involving a shell.
subprocess.run(["pip", "install", "flash-attn"])
|
7 |
|
8 |
class CustomModelHandler:
|
9 |
def __init__(self, model_name_or_path: str):
    """Remember the model identifier, select a compute device, and load the model.

    Args:
        model_name_or_path: Hugging Face hub id or local path of the model.
    """
    self.model_name_or_path = model_name_or_path
    # Prefer the GPU when one is visible; otherwise fall back to CPU.
    if torch.cuda.is_available():
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")
    print(f"Using device: {self.device}")
    # Eagerly load the model so the handler is ready immediately after construction.
    self.load_model()
|
14 |
|
15 |
def load_model(self):
|
|
|
23 |
)
|
24 |
self.model.to(self.device)
|
25 |
print(f"Model loaded and moved to {self.device}")
|
26 |
+
|
27 |
+
# Check if the model parameters are on the GPU
|
28 |
+
all_on_gpu = all(param.device.type == 'cuda' for param in self.model.parameters())
|
29 |
+
if not all_on_gpu:
|
30 |
+
print("Warning: Not all model parameters are on the GPU!")
|
31 |
+
else:
|
32 |
+
print("All model parameters are on the GPU.")
|
33 |
+
|
34 |
+
# Confirm model device
|
35 |
+
print(f"Model is on device: {self.model.device}")
|
36 |
except Exception as e:
|
37 |
print(f"An error occurred while loading the model: {e}")
|
38 |
raise
|
|
|
53 |
# Example prediction function
def predict(input_text):
    """Forward *input_text* to the module-level handler and return its prediction."""
    result = handler.predict(input_text)
    return result
|
56 |
+
|
57 |
+
# Example usage
if __name__ == "__main__":
    sample = "Hello, how are you?"
    output = predict(sample)
    print("Predictions:", output)