import importlib.util

import gradio as gr
import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Download and import model components from the HF Hub
model_repo = "elapt1c/hrom-testing"

# 1. Import trainer module components
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager
# 2. Load the tokenizer
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Load the model checkpoint
checkpoint_file = hf_hub_download(repo_id=model_repo, filename="hrom.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
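
# The checkpoint stores the weights under the 'model' key (see load_state_dict
# below); map_location lets a CPU-only Space load a checkpoint saved on a GPU.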
def load_model():
    model = HROM().to(device)
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model
model = load_model()
safety = SafetyManager(model, tokenizer)

# Tokens reserved for the model's reply when truncating the prompt
max_response_length = 200
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily decode up to max_length tokens, stopping at </s> or on a safety veto."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Greedy decoding: take the highest-probability token at the last position
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == tokenizer.token_to_id("</s>"):
            break
        # Stop early if the safety filter rejects the extended text
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens
    return generated_ids[len(input_ids):]
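
# A minimal sketch of a sampling-based alternative to the greedy loop above,
# assuming the same model/tokenizer/safety interfaces. The function name and
# the temperature parameter are illustrative; it is not wired into the app.
def generate_response_sampled(model, tokenizer, input_ids, safety_manager,
                              max_length=200, temperature=0.8):
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(torch.tensor([generated_ids], device=device))
        # Sample from the temperature-scaled distribution at the last position
        probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        if next_token == tokenizer.token_to_id("</s>"):
            break
        if not safety_manager.content_filter(tokenizer.decode(generated_ids + [next_token])):
            break
        generated_ids.append(next_token)
    return generated_ids[len(input_ids):]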
def process_message(user_input, chat_history, token_history):
    # Encode the user turn and append it to the running token history
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Prepare the input sequence with a <s> BOS token
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Truncate from the left so the prompt plus response fits in max_seq_len
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        input_sequence = input_sequence[-max_input_len:]
        token_history = input_sequence[1:]

    # Generate the assistant response
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    # Decode the response, handling a missing <assistant> prefix or </s> terminator
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                # No </s> was generated; keep everything after <assistant>
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    return [], []
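
# Gradio UI: the Chatbot widget shows (user, assistant) text pairs, while
# token_state (a gr.State) carries the raw token ids that process_message
# feeds back into the model on every turn.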
with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])

    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg  # clear the textbox after each submission
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()