Commit bfe7166 · Parent: 3d2f665 · Update app.py

app.py CHANGED
@@ -1,20 +1,41 @@
 import gradio as gr
 import torch
+import importlib.util
 from tokenizers import Tokenizer
+from huggingface_hub import hf_hub_download
 import os
-from HROM_Trainer import HROM, CONFIG, SafetyManager
 
-… (eight removed lines: the old local load_latest_checkpoint() helper; its body is not preserved in this view)
+# Download and import model components from HF Hub
+model_repo = "TimurHromek/HROM-V1"
+
+# 1. Import trainer module components
+trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
+spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
+trainer_module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(trainer_module)
+HROM = trainer_module.HROM
+CONFIG = trainer_module.CONFIG
+SafetyManager = trainer_module.SafetyManager
+
+# 2. Load tokenizer
+tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
+tokenizer = Tokenizer.from_file(tokenizer_file)
+
+# 3. Load model checkpoint
+checkpoint_file = hf_hub_download(repo_id=model_repo, filename="checkpoints/HROM-V1.5.pt")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def load_model():
+    model = HROM().to(device)
+    checkpoint = torch.load(checkpoint_file, map_location=device)
     model.load_state_dict(checkpoint['model'])
+    model.eval()
     return model
 
+model = load_model()
+safety = SafetyManager(model, tokenizer)
+max_response_length = 200
+
 def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
     device = next(model.parameters()).device
     generated_ids = input_ids.copy()
@@ -31,15 +52,6 @@ def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
     generated_ids.append(next_token)
     return generated_ids[len(input_ids):]
 
-# Initialize components once
-tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = HROM().to(device)
-model = load_latest_checkpoint(model, device)
-model.eval()
-safety = SafetyManager(model, tokenizer)
-max_response_length = 200
-
 def process_message(user_input, chat_history, token_history):
     # Process user input
     user_turn = f"<user> {user_input} </s>"
@@ -80,7 +92,7 @@ def clear_history():
     return [], []
 
 with gr.Blocks() as demo:
-    gr.Markdown("# HROM Chatbot")
+    gr.Markdown("# HROM-V1 Chatbot")
     chatbot = gr.Chatbot(height=500)
     msg = gr.Textbox(label="Your Message")
     token_state = gr.State([])
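
Note on the dynamic import: the new code loads the trainer as a module from a downloaded .py file rather than from a package on sys.path. A minimal, self-contained sketch of that importlib pattern (the helper name and example path are illustrative, not from the commit):

import importlib.util

def import_from_path(module_name, file_path):
    # Build a spec for an arbitrary .py file, materialize a module
    # object from it, then execute the file to populate the module.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# e.g. trainer = import_from_path("HROM_Trainer", "/tmp/HROM-V1.5_Trainer.py")
# HROM, CONFIG = trainer.HROM, trainer.CONFIG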
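Note on the Hub downloads: hf_hub_download fetches a single file from a repo and returns its local cache path, so restarts reuse the cached copy instead of re-downloading. A sketch of the checkpoint step under the same assumptions as the diff (the repo id, file name, and the 'model' key all come from the code above):

import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="TimurHromek/HROM-V1",
                            filename="checkpoints/HROM-V1.5.pt")
# map_location lets a GPU-trained checkpoint load on a CPU-only Space.
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]  # weights stored under the 'model' key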
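Note on the conversation format: process_message wraps each user message as a role-tagged turn ending in </s> and keeps token_history so every generation sees the full dialogue. A hypothetical encoding of one turn, assuming the tokenizers API used above and that <user> and </s> are in the vocabulary:

user_input = "Hello!"
user_turn = f"<user> {user_input} </s>"
ids = tokenizer.encode(user_turn).ids  # Encoding.ids from the tokenizers library
token_history.extend(ids)              # accumulated across turns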
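Note on generation: the body of generate_response is collapsed in this view; only its first and last lines are visible. A generic greedy-decoding loop consistent with those visible lines (a sketch under assumptions, not the Space's actual code; the model's output shape and the eos id are assumed):

import torch

def greedy_generate(model, input_ids, eos_id, max_length=200):
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    with torch.no_grad():
        for _ in range(max_length):
            x = torch.tensor([generated_ids], device=device)
            logits = model(x)                      # assumed shape (1, seq, vocab)
            next_token = logits[0, -1].argmax().item()
            generated_ids.append(next_token)
            if next_token == eos_id:
                break
    return generated_ids[len(input_ids):]          # only the newly generated tokens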