TimurHromek committed on
Commit
c17825d
·
verified ·
1 Parent(s): 76b9d13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  model_repo = "TimurHromek/HROM-V1"
10
 
11
  # 1. Import trainer module components
12
- trainer_file = hf_hub_download(repo_id=model_repo, filename="trainer-v1.6.py")
13
  spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
14
  trainer_module = importlib.util.module_from_spec(spec)
15
  spec.loader.exec_module(trainer_module)
@@ -18,11 +18,11 @@ CONFIG = trainer_module.CONFIG
18
  SafetyManager = trainer_module.SafetyManager
19
 
20
  # 2. Load tokenizer
21
- tokenizer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.6/tokenizer/hrom_tokenizer.json")
22
  tokenizer = Tokenizer.from_file(tokenizer_file)
23
 
24
  # 3. Load model checkpoint
25
- checkpoint_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.6/HROM-V1.6.pt")
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
  def load_model():
@@ -33,7 +33,7 @@ def load_model():
33
  return model
34
 
35
  model = load_model()
36
- safety = SafetyManager(model, tokenizer, device)
37
  max_response_length = 200
38
 
39
  def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
 
9
  model_repo = "TimurHromek/HROM-V1"
10
 
11
  # 1. Import trainer module components
12
+ trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
13
  spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
14
  trainer_module = importlib.util.module_from_spec(spec)
15
  spec.loader.exec_module(trainer_module)
 
18
  SafetyManager = trainer_module.SafetyManager
19
 
20
  # 2. Load tokenizer
21
+ tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
22
  tokenizer = Tokenizer.from_file(tokenizer_file)
23
 
24
  # 3. Load model checkpoint
25
+ checkpoint_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trained-Model/HROM-V1.5.pt")
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
  def load_model():
 
33
  return model
34
 
35
  model = load_model()
36
+ safety = SafetyManager(model, tokenizer)
37
  max_response_length = 200
38
 
39
  def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):