# Hugging Face Spaces app for the HROM-V1 chatbot (Space status when captured: Sleeping).
import importlib.util
import os
import sys

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# Download and import model components from the HF Hub.
model_repo = "TimurHromek/HROM-V1"

# 1. Import trainer module components.
# SECURITY NOTE(review): this downloads a Python file from the Hub and executes it
# in-process, so anyone able to push to `model_repo` can run code in this app.
# Consider pinning `revision=` in hf_hub_download to a reviewed commit.
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
if spec is None or spec.loader is None:
    raise ImportError(f"Could not build an import spec for {trainer_file!r}")
trainer_module = importlib.util.module_from_spec(spec)
# Register the module before executing it, as in the importlib "import a source
# file directly" recipe; machinery such as dataclasses or pickle resolves classes
# through sys.modules and fails for an unregistered module.
sys.modules[spec.name] = trainer_module
spec.loader.exec_module(trainer_module)

# Public names this app uses from the trainer script.
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager
# 2. Load the tokenizer shipped with the model repo.
tokenizer_file = hf_hub_download(
    repo_id=model_repo,
    filename="tokenizer/hrom_tokenizer.json",
)
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Fetch the trained checkpoint (loaded later by load_model()).
checkpoint_file = hf_hub_download(
    repo_id=model_repo,
    filename="HROM-V1.5_Trained-Model/HROM-V1.5.pt",
)

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model():
    """Build an HROM model on ``device`` and restore its trained weights.

    Returns:
        The HROM instance, switched to eval mode, with parameters loaded from
        the ``'model'`` entry of ``checkpoint_file``.
    """
    net = HROM().to(device)
    # NOTE(review): torch.load unpickles the file. It comes from the same
    # `model_repo` whose trainer script this app already executes, so it adds no
    # new trust boundary; weights_only=True is worth considering if the
    # checkpoint holds only tensors/primitive containers.
    state = torch.load(checkpoint_file, map_location=device)
    net.load_state_dict(state['model'])
    # Module.eval() returns the module itself.
    return net.eval()
# Upper bound on tokens generated for one assistant reply; also reserved out of
# the context window when truncating the prompt.
max_response_length = 200

# Build the model once at import time and attach the trainer's SafetyManager,
# whose content_filter() gates each generated token.
model = load_model()
safety = SafetyManager(model, tokenizer)
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily decode up to ``max_length`` new tokens after ``input_ids``.

    Decoding stops early when the model predicts ``</s>`` (not included in the
    result), or when ``safety_manager.content_filter`` rejects the text decoded
    from the prompt plus the candidate token (the candidate is discarded).

    Args:
        model: Module called on a ``(1, seq_len)`` id tensor; position ``[:, -1]``
            of its output is read as the next-token logits (argmax over the last dim).
        tokenizer: Object providing ``token_to_id`` and ``decode``.
        input_ids: Prompt token ids; not modified.
        safety_manager: Object with ``content_filter(text)``; a falsy result stops
            generation.
        max_length: Maximum number of tokens to generate.

    Returns:
        List of newly generated token ids (prompt and ``</s>`` excluded).
    """
    device = next(model.parameters()).device
    # Loop-invariant: resolve the end-of-sequence id once instead of every step.
    eos_id = tokenizer.token_to_id("</s>")
    generated_ids = list(input_ids)

    with torch.no_grad():
        for _ in range(max_length):
            # Explicit integer dtype: torch.tensor infers float for an empty list.
            input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
            logits = model(input_tensor)
            # Only the last position predicts the next token; skip argmax elsewhere.
            next_token = logits[:, -1].argmax(-1).item()
            if next_token == eos_id:
                break
            # The filter sees the prompt too, so a flagged prompt stops generation
            # before any token is accepted.
            candidate_text = tokenizer.decode(generated_ids + [next_token])
            if not safety_manager.content_filter(candidate_text):
                break
            generated_ids.append(next_token)

    return generated_ids[len(input_ids):]
def process_message(user_input, chat_history, token_history):
    """Handle one chat turn: encode the user's message, generate a reply, update state.

    Args:
        user_input: Text submitted from the message box.
        chat_history: List of ``(user_text, assistant_text)`` pairs shown in the
            Chatbot; ``None`` is treated as an empty history.
        token_history: Token ids of the conversation so far, without the leading
            ``<s>`` (kept in a ``gr.State``); ``None`` is treated as empty.

    Returns:
        ``(chat_history, token_history)`` with this turn appended.
    """
    chat_history = [] if chat_history is None else chat_history
    token_history = [] if token_history is None else token_history

    # A blank submission would only prompt the model with an empty user turn.
    if not user_input or not user_input.strip():
        return chat_history, token_history

    bos_id = tokenizer.token_to_id("<s>")
    eos_id = tokenizer.token_to_id("</s>")
    assistant_id = tokenizer.token_to_id("<assistant>")

    # Record the user's turn.
    user_turn = f"<user> {user_input} </s>"
    token_history.extend(tokenizer.encode(user_turn).ids)

    # Truncate so prompt + reply fits in CONFIG["max_seq_len"]. One slot is
    # reserved for <s> so the BOS marker survives truncation (trimming
    # "[<s>] + history" from the right would cut <s> off the model input).
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    history_budget = max_input_len - 1
    if len(token_history) > history_budget:
        # A non-positive budget must keep nothing: history[-0:] would keep everything.
        token_history = token_history[-history_budget:] if history_budget > 0 else []
    input_sequence = [bos_id] + token_history

    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        # generate_response stops *before* emitting </s>, so usually there is none;
        # if one is present, drop everything from it onward.
        end = response_ids.index(eos_id) if eos_id in response_ids else len(response_ids)
        turn_ids = response_ids[:end]
        # Hide a leading <assistant> marker from the displayed text only.
        if turn_ids and turn_ids[0] == assistant_id:
            text_ids = turn_ids[1:]
        else:
            text_ids = turn_ids
        assistant_text = tokenizer.decode(text_ids)
        token_history.extend(turn_ids)
        # Close the assistant turn with </s>, matching the "<user> ... </s>" turns,
        # so the next user turn does not run on from the reply.
        if eos_id is not None:
            token_history.append(eos_id)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    """Reset the conversation shown in the UI and the token-level state.

    Returns:
        Two new, independent empty lists: ``(chat_history, token_history)``.
    """
    fresh_chat = list()
    fresh_tokens = list()
    return fresh_chat, fresh_tokens
# Custom CSS for styling.
# NOTE(review): the `.gr-*` selectors follow Gradio 3.x class names; newer Gradio
# releases may not emit these classes, so confirm they still match.
custom_css = """
body {
    background: linear-gradient(to bottom, #1a1a1a, #2a2a2a);
    font-family: 'Roboto', sans-serif;
    color: white;
    margin: 0;
    padding: 0;
}
.container {
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
}
.gr-chatbot {
    font-size: 16px;
    border: none;
    background-color: #1e1e1e;
    border-radius: 8px;
    padding: 10px;
}
.gr-chatbot .bubble.user {
    background-color: #2d2d2d !important;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
.gr-chatbot .bubble.assistant {
    background-color: #3d3d3d !important;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 12px 24px;
    font-size: 16px;
    border-radius: 4px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.gr-button:hover {
    background-color: #45a049;
}
.gr-text-input input,
.gr-text-input textarea {
    background-color: #2d2d2d;
    color: white;
    border: 1px solid #4CAF50;
    border-radius: 4px;
    padding: 10px;
    font-size: 16px;
}
.header {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 20px 0;
    text-align: center;
}
.header img {
    width: 60px;
    height: 60px;
    margin-right: 15px;
}
.footer {
    text-align: center;
    padding: 20px;
    font-size: 14px;
    color: #ccc;
    margin-top: 30px;
}
.title {
    font-size: 28px;
    font-weight: bold;
    color: #ffffff;
    margin: 0;
}
.subtitle {
    font-size: 16px;
    color: #cccccc;
    margin: 5px 0 0 0;
}
"""
with gr.Blocks(
    # "dark" is not one of Gradio's built-in theme names. Gradio treats an unknown
    # string as a Hub theme id and falls back to the default theme if that fails.
    # Ask for the default theme directly; the dark look comes from `custom_css`.
    # Gradio's own dark mode can be forced with the `?__theme=dark` URL parameter.
    theme=gr.themes.Default(),
    css=custom_css,
) as demo:
    with gr.Column(elem_classes=["container"]):
        # Header
        with gr.Row(elem_classes=["header"]):
            gr.Image(
                value="https://huggingface.co/TimurHromek/HROM-V1/resolve/main/hrom_icon.png",
                interactive=False,
                width=60,
                height=60,
                show_label=False,
            )
            with gr.Column():
                gr.Markdown("<div class='title'>HROM-V1 Chatbot</div>")
                gr.Markdown("<div class='subtitle'>Powered by Gradio and Hugging Face</div>")

        # Chatbot
        chatbot = gr.Chatbot(
            height=500,
            # avatar_images takes a (user, assistant) pair of image paths/URLs,
            # not (url, label) tuples.
            avatar_images=(
                "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/test_image.png",
                "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/favicon.png",
            ),
            bubble_full_width=False,
        )

        # Input
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=2,
        )

        # Buttons
        clear_btn = gr.Button("Clear Chat History")

        # State: token ids of the whole conversation (see process_message).
        token_state = gr.State([])

        # Event handlers.
        # Generation can be slow, so it goes through Gradio's queue (no queue=False).
        # The queue limits concurrent runs on the shared model and avoids
        # long-running plain HTTP requests.
        msg.submit(
            process_message,
            [msg, chatbot, token_state],
            [chatbot, token_state],
        ).then(
            lambda: "", None, msg
        )
        # Clearing is instant, so it can safely bypass the queue.
        clear_btn.click(
            clear_history,
            outputs=[chatbot, token_state],
            queue=False,
        )

        # Footer
        gr.Markdown("<div class='footer'>© 2025 HROM-V1 | Model by Timur Hromek | Assisted by Elapt1c</div>")
if __name__ == "__main__":
    # Running this file directly starts the Gradio server for `demo`.
    demo.launch()