import gradio as gr
import torch
import importlib.util
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os

# Download and import model components from HF Hub
model_repo = "elapt1c/hrom-testing"

# 1. Import trainer module components (model class, config, and safety manager
#    are loaded dynamically from the trainer script stored in the repo)
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager

# 2. Load tokenizer
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Load model checkpoint
checkpoint_file = hf_hub_download(repo_id=model_repo, filename="hrom.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    model = HROM().to(device)
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


model = load_model()
safety = SafetyManager(model, tokenizer)
max_response_length = 200

# NOTE: the special-token strings used below (<s>, </s>, <user>, <assistant>)
# are assumed to match the vocabulary of hrom_tokenizer.json; adjust them if
# the tokenizer uses different markers.


def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedy decoding, re-checking the safety filter at every step."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Take the most likely token at the last position
        next_token = logits.argmax(-1)[:, -1].item()
        # Stop at the end-of-sequence token
        if next_token == tokenizer.token_to_id("</s>"):
            break
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens
    return generated_ids[len(input_ids):]


def process_message(user_input, chat_history, token_history):
    # Process user input
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Prepare input sequence
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Truncate if needed
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        input_sequence = input_sequence[-max_input_len:]
        token_history = input_sequence[1:]

    # Generate response
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    # Process assistant response
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history


def clear_history():
    return [], []


with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])

    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "",
        None,
        msg
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()