TimurHromek committed
Commit bfe7166 · 1 Parent(s): 3d2f665

Update app.py

Files changed (1)
  1. app.py +31 -19
app.py CHANGED
@@ -1,20 +1,41 @@
  import gradio as gr
  import torch
+ import importlib.util
  from tokenizers import Tokenizer
+ from huggingface_hub import hf_hub_download
  import os
- from HROM_Trainer import HROM, CONFIG, SafetyManager

- def load_latest_checkpoint(model, device):
-     checkpoint_dir = "checkpoints"
-     checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
-     if not checkpoints:
-         raise FileNotFoundError("No checkpoints found.")
-     checkpoints = sorted(checkpoints, key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
-     latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
-     checkpoint = torch.load(latest_checkpoint, map_location=device)
+ # Download and import model components from HF Hub
+ model_repo = "TimurHromek/HROM-V1"
+
+ # 1. Import trainer module components
+ trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
+ spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
+ trainer_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(trainer_module)
+ HROM = trainer_module.HROM
+ CONFIG = trainer_module.CONFIG
+ SafetyManager = trainer_module.SafetyManager
+
+ # 2. Load tokenizer
+ tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
+ tokenizer = Tokenizer.from_file(tokenizer_file)
+
+ # 3. Load model checkpoint
+ checkpoint_file = hf_hub_download(repo_id=model_repo, filename="checkpoints/HROM-V1.5.pt")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def load_model():
+     model = HROM().to(device)
+     checkpoint = torch.load(checkpoint_file, map_location=device)
      model.load_state_dict(checkpoint['model'])
+     model.eval()
      return model

+ model = load_model()
+ safety = SafetyManager(model, tokenizer)
+ max_response_length = 200
+
  def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
      device = next(model.parameters()).device
      generated_ids = input_ids.copy()
@@ -31,15 +52,6 @@ def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
      generated_ids.append(next_token)
      return generated_ids[len(input_ids):]

- # Initialize components once
- tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = HROM().to(device)
- model = load_latest_checkpoint(model, device)
- model.eval()
- safety = SafetyManager(model, tokenizer)
- max_response_length = 200
-
  def process_message(user_input, chat_history, token_history):
      # Process user input
      user_turn = f"<user> {user_input} </s>"
@@ -80,7 +92,7 @@ def clear_history():
      return [], []

  with gr.Blocks() as demo:
-     gr.Markdown("# HROM Chatbot")
+     gr.Markdown("# HROM-V1 Chatbot")
      chatbot = gr.Chatbot(height=500)
      msg = gr.Textbox(label="Your Message")
      token_state = gr.State([])
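
Note on the change above: the commit replaces a scan of a local "checkpoints/" directory with artifacts fetched from the Hub at startup. The sketch below restates that loading pattern standalone, assuming the repo layout shown in the diff (trainer module, tokenizer JSON, and checkpoint in TimurHromek/HROM-V1); the import_module_from_hub helper is illustrative and not part of the app itself.

import importlib.util

import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

REPO = "TimurHromek/HROM-V1"

def import_module_from_hub(repo_id, filename, module_name):
    # Hypothetical helper: download a .py file from the Hub cache
    # and import it as a regular Python module.
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file's top-level code
    return module

trainer = import_module_from_hub(REPO, "HROM-V1.5_Trainer.py", "HROM_Trainer")

tokenizer = Tokenizer.from_file(
    hf_hub_download(repo_id=REPO, filename="tokenizer/hrom_tokenizer.json"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = trainer.HROM().to(device)
checkpoint = torch.load(
    hf_hub_download(repo_id=REPO, filename="checkpoints/HROM-V1.5.pt"),
    map_location=device)
model.load_state_dict(checkpoint["model"])
model.eval()

Because exec_module runs whatever top-level code the downloaded file contains, this pattern is only safe against a trusted repo; pinning a specific commit via hf_hub_download's revision argument is a reasonable hardening step.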