Canstralian commited on
Commit
940bdb8
·
verified ·
1 Parent(s): 0315757

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -17,19 +17,18 @@ logging.basicConfig(
17
  def load_model():
18
  """
19
  Loads and caches the pre-trained language model and tokenizer.
20
-
21
  Returns:
22
  model: Pre-trained language model.
23
  tokenizer: Tokenizer for the model.
24
  """
25
  model_path = "Canstralian/pentest_ai"
26
  try:
 
27
  model = AutoModelForCausalLM.from_pretrained(
28
  model_path,
29
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
30
- device_map="auto",
31
- load_in_4bit=False,
32
- load_in_8bit=True,
33
  trust_remote_code=True,
34
  )
35
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -46,7 +45,6 @@ def sanitize_input(text):
46
 
47
  Args:
48
  text (str): User input text.
49
-
50
  Returns:
51
  str: Sanitized text.
52
  """
@@ -59,19 +57,18 @@ def sanitize_input(text):
59
  def generate_text(model, tokenizer, instruction):
60
  """
61
  Generates text based on the provided instruction using the loaded model.
62
-
63
  Args:
64
  model: The language model.
65
  tokenizer: Tokenizer for encoding/decoding.
66
  instruction (str): Instruction text for the model.
67
-
68
  Returns:
69
  str: Generated text response from the model.
70
  """
71
  try:
72
  # Validate and sanitize instruction input
73
  instruction = sanitize_input(instruction)
74
- tokens = tokenizer.encode(instruction, return_tensors='pt').to('cuda')
 
75
  generated_tokens = model.generate(
76
  tokens,
77
  max_length=1024,
@@ -90,7 +87,6 @@ def generate_text(model, tokenizer, instruction):
90
  def load_json_data():
91
  """
92
  Loads JSON data, simulating the loading process with a sample list.
93
-
94
  Returns:
95
  list: A list of dictionaries with sample user data.
96
  """
@@ -138,4 +134,3 @@ for user in user_data:
138
  st.write(f"**Country:** {user['country']}")
139
  st.write(f"**Company:** {user['company']}")
140
  st.write("---")
141
-
 
17
  def load_model():
18
  """
19
  Loads and caches the pre-trained language model and tokenizer.
 
20
  Returns:
21
  model: Pre-trained language model.
22
  tokenizer: Tokenizer for the model.
23
  """
24
  model_path = "Canstralian/pentest_ai"
25
  try:
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
  model = AutoModelForCausalLM.from_pretrained(
28
  model_path,
29
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
30
+ device_map={"": device}, # This will specify CPU or GPU explicitly
31
+ load_in_8bit=False, # Disabled for stability
 
32
  trust_remote_code=True,
33
  )
34
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
45
 
46
  Args:
47
  text (str): User input text.
 
48
  Returns:
49
  str: Sanitized text.
50
  """
 
57
  def generate_text(model, tokenizer, instruction):
58
  """
59
  Generates text based on the provided instruction using the loaded model.
 
60
  Args:
61
  model: The language model.
62
  tokenizer: Tokenizer for encoding/decoding.
63
  instruction (str): Instruction text for the model.
 
64
  Returns:
65
  str: Generated text response from the model.
66
  """
67
  try:
68
  # Validate and sanitize instruction input
69
  instruction = sanitize_input(instruction)
70
+ device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ tokens = tokenizer.encode(instruction, return_tensors='pt').to(device)
72
  generated_tokens = model.generate(
73
  tokens,
74
  max_length=1024,
 
87
  def load_json_data():
88
  """
89
  Loads JSON data, simulating the loading process with a sample list.
 
90
  Returns:
91
  list: A list of dictionaries with sample user data.
92
  """
 
134
  st.write(f"**Country:** {user['country']}")
135
  st.write(f"**Company:** {user['company']}")
136
  st.write("---")