import gradio as gr
import torch
from tokenizers import Tokenizer
import os
from HROM_Trainer import HROM, CONFIG, SafetyManager

def load_latest_checkpoint(model, device):
    """Load model weights from the most recently modified checkpoint."""
    checkpoint_dir = "checkpoints"
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    if not checkpoints:
        raise FileNotFoundError("No checkpoints found.")
    # Newest checkpoint by file modification time.
    latest_checkpoint = max(
        (os.path.join(checkpoint_dir, f) for f in checkpoints),
        key=os.path.getmtime,
    )
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    return model

def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily decode up to max_length tokens, stopping at </s> or on a safety block."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Greedy decoding: take the highest-probability token at the last position.
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == tokenizer.token_to_id("</s>"):
            # Keep the end-of-sequence token so callers can find the turn boundary.
            generated_ids.append(next_token)
            break
        # Re-decode the running text and stop before emitting unsafe content.
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens (prompt excluded).
    return generated_ids[len(input_ids):]
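
# Not part of the original app: a minimal sketch of a sampling-based alternative
# to the greedy argmax above, assuming the same logits layout HROM produces.
# The temperature and top_k values are illustrative, not tuned for this model.
def sample_next_token(last_logits, temperature=0.8, top_k=50):
    """Sample a token id from temperature-scaled, top-k-filtered logits."""
    # last_logits: 1-D tensor of vocabulary logits for the final position,
    # e.g. logits[0, -1] from the forward pass in generate_response.
    top_values, top_indices = torch.topk(last_logits / temperature, top_k)
    probs = torch.softmax(top_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_indices[choice].item()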

# Initialize components once
tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HROM().to(device)
model = load_latest_checkpoint(model, device)
model.eval()
safety = SafetyManager(model, tokenizer)
max_response_length = 200
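
# Conversation state is a flat list of token ids (token_state below). Given the
# turn formatting in process_message, an illustrative two-turn stream looks like:
#
#   <s> <user> hi there </s> <assistant> hello! </s> <user> and you? </s>
#
# <s> is prepended once per model call; each turn is closed by </s>.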

def process_message(user_input, chat_history, token_history):
    # Process user input
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)
    
    # Prepare input sequence
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history
    
    # Truncate if needed, reserving room for the response. Re-prepend <s> so the
    # truncated sequence still starts with the BOS token.
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        token_history = token_history[-(max_input_len - 1):]
        input_sequence = [tokenizer.token_to_id("<s>")] + token_history
    
    # Generate response
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)
    
    # Process assistant response
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                # Strip the leading <assistant> token and cut at the closing </s>.
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx+1])
            except ValueError:
                # No </s> found: generation hit max_length or was stopped early.
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)
    
    chat_history.append((user_input, assistant_text))
    return chat_history, token_history

def clear_history():
    return [], []

with gr.Blocks() as demo:
    gr.Markdown("# HROM Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])
    
    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg  # clear the input textbox after submission
    )
    
    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )
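
# Deployment note (assumption, not from the original script): launch() also
# accepts options such as share=True for a temporary public URL, or
# server_port=7860 to pin the port, e.g. demo.launch(share=True).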

demo.launch()