import gradio as gr
import torch
import importlib.util
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# Download and import model components from HF Hub
model_repo = "elapt1c/hrom-testing"

# 1. Import trainer module components
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager

# 2. Load tokenizer
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Load model checkpoint
checkpoint_file = hf_hub_download(repo_id=model_repo, filename="hrom.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
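    """Instantiate the HROM architecture and restore the trained weights from the checkpoint."""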
    model = HROM().to(device)
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model

model = load_model()
safety = SafetyManager(model, tokenizer)
max_response_length = 200

def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
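    """Greedily decode up to max_length tokens, stopping at </s> or when the
    safety filter rejects the partially generated text."""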
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
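        # Greedy decoding: take the most likely token at the final position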
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == tokenizer.token_to_id("</s>"):
            break
        current_text = tokenizer.decode(generated_ids + [next_token])
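        # Stop if the text including the candidate token trips the content filter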
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    return generated_ids[len(input_ids):]

def process_message(user_input, chat_history, token_history):
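    """Append the user's turn to the running token history, generate a reply,
    and return the updated chat display and token history."""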
    # Encode the user's turn in the chat template and append it to the token history
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)
    
    # Prepare input sequence
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history
    
    # Truncate from the left so the prompt plus a full response fits in the context
    # window, keeping <s> at the start and the most recent conversation tokens
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        token_history = token_history[-(max_input_len - 1):]
        input_sequence = [tokenizer.token_to_id("<s>")] + token_history
    
    # Generate response
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)
    
    # Decode the assistant reply, stripping the leading <assistant> token and trailing </s> if present
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx+1])
            except ValueError:
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)
    
    chat_history.append((user_input, assistant_text))
    return chat_history, token_history

def clear_history():
    return [], []

with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
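    # Per-session token history, carried across turns via Gradio state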
    token_state = gr.State([])
    
    # On submit: generate a reply, then clear the textbox via the chained .then()
    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg
    )
    
    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()