Gokulram2710 committed
Commit 3333487 · verified · 1 Parent(s): a8458ec

Update handler.py

Files changed (1)
  1. handler.py +18 -0
handler.py CHANGED
@@ -2,12 +2,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import subprocess
 
+# Install flash-attn
 subprocess.run('pip install flash-attn', shell=True)
 
 class CustomModelHandler:
     def __init__(self, model_name_or_path: str):
         self.model_name_or_path = model_name_or_path
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
         self.load_model()
 
     def load_model(self):
@@ -21,6 +23,16 @@ class CustomModelHandler:
             )
             self.model.to(self.device)
             print(f"Model loaded and moved to {self.device}")
+
+            # Check if the model parameters are on the GPU
+            all_on_gpu = all(param.device.type == 'cuda' for param in self.model.parameters())
+            if not all_on_gpu:
+                print("Warning: Not all model parameters are on the GPU!")
+            else:
+                print("All model parameters are on the GPU.")
+
+            # Confirm model device
+            print(f"Model is on device: {self.model.device}")
         except Exception as e:
             print(f"An error occurred while loading the model: {e}")
             raise
@@ -41,3 +53,9 @@ handler = CustomModelHandler("microsoft/Phi-3-vision-128k-instruct")
 # Example prediction function
 def predict(input_text):
     return handler.predict(input_text)
+
+# Example usage
+if __name__ == "__main__":
+    input_text = "Hello, how are you?"
+    predictions = predict(input_text)
+    print("Predictions:", predictions)