# HROM-V1 / app.py
# Hugging Face Hub page header (scrape residue, not part of the program):
# author TimurHromek; commit "Update app.py" (5f4b8bc, verified); size 7.01 kB.
import gradio as gr
import torch
import importlib.util
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os
# Everything the app needs (trainer script, tokenizer, checkpoint) is pulled
# from this Hugging Face Hub repository at import time.
model_repo = "TimurHromek/HROM-V1"


def _fetch_from_hub(filename):
    """Download ``filename`` from ``model_repo`` and return the local path."""
    return hf_hub_download(repo_id=model_repo, filename=filename)


# 1. Trainer module: loaded dynamically from the downloaded file.
# NOTE: exec_module runs the downloaded Python file, so this step executes
# code fetched from the Hub repository.
trainer_file = _fetch_from_hub("HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM, CONFIG, SafetyManager = (
    trainer_module.HROM,
    trainer_module.CONFIG,
    trainer_module.SafetyManager,
)

# 2. Tokenizer: built from the JSON file stored in the repository.
tokenizer_file = _fetch_from_hub("tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# 3. Checkpoint path and target device (CUDA when available, otherwise CPU).
checkpoint_file = _fetch_from_hub("HROM-V1.5_Trained-Model/HROM-V1.5.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Build an ``HROM`` instance, restore its weights and return it in eval mode.

    The checkpoint at ``checkpoint_file`` is read with ``torch.load`` (mapped
    onto ``device``) and its ``'model'`` entry is passed to ``load_state_dict``.

    NOTE(review): ``torch.load`` unpickles the file; the checkpoint comes from
    the Hub download above, so it should only point at a trusted repository.
    """
    network = HROM().to(device)
    saved_state = torch.load(checkpoint_file, map_location=device)
    weights = saved_state['model']
    network.load_state_dict(weights)
    # Switch to evaluation mode before handing the model to the app.
    network.eval()
    return network
# Upper bound on the number of tokens generated for one assistant reply; it is
# also subtracted from CONFIG["max_seq_len"] when the prompt is truncated.
max_response_length = 200

# Module-level singletons shared by every chat request.
model = load_model()
safety = SafetyManager(model, tokenizer)
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily decode up to ``max_length`` new tokens that follow ``input_ids``.

    At every step the whole running sequence is fed to ``model`` and the
    highest-scoring token at the last position is selected. Generation stops
    when that token is ``</s>``, when ``safety_manager.content_filter`` returns
    a falsy value for the decoded sequence including the candidate token, or
    after ``max_length`` steps.

    Args:
        model: Callable network exposing ``parameters()``; it is called with an
            integer tensor of shape ``(1, seq_len)``.
            Assumes it returns logits shaped ``(batch, seq_len, vocab)`` --
            TODO confirm against the trainer module.
        tokenizer: Object providing ``token_to_id(str)`` and ``decode(ids)``.
        input_ids: Sequence of prompt token ids; it is never mutated.
        safety_manager: Object providing ``content_filter(text)``.
        max_length: Maximum number of tokens to generate.

    Returns:
        A list with only the newly generated token ids. The terminating
        ``</s>`` token and any token rejected by the filter are NOT included.
    """
    device = next(model.parameters()).device
    # Loop-invariant: resolve the end-of-sequence id once, not on every step.
    eos_id = tokenizer.token_to_id("</s>")
    prompt_len = len(input_ids)
    # Copy so the caller's sequence stays untouched (also accepts tuples).
    generated_ids = list(input_ids)

    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Only the last position matters, so take the argmax over that row
        # instead of over every position in the sequence.
        next_token = logits[:, -1].argmax(dim=-1).item()
        if next_token == eos_id:
            break
        candidate_ids = generated_ids + [next_token]
        # The filter sees the full decoded sequence (prompt + reply so far).
        if not safety_manager.content_filter(tokenizer.decode(candidate_ids)):
            break
        generated_ids = candidate_ids

    return generated_ids[prompt_len:]
def process_message(user_input, chat_history, token_history):
    """Run one chat turn: encode the user message, generate a reply, update state.

    The user message is wrapped as ``<user> ... </s>``, appended to the running
    token history, prefixed with ``<s>`` and passed to ``generate_response``.

    Fixes over the previous version:
      * Truncation keeps the leading ``<s>`` token. Slicing the tail of the
        full sequence used to discard it and then drop one extra history token.
      * Each assistant turn stored in the history is terminated with ``</s>``.
        ``generate_response`` stops before returning that token, so it has to
        be added here for the history to mirror the user-turn format.
      * The incoming lists are copied, so a failure during generation cannot
        leave half-updated session state behind.
      * ``None`` state values and blank messages are tolerated.

    Args:
        user_input: Text typed by the user.
        chat_history: List of ``(user_text, assistant_text)`` pairs shown in
            the chatbot; may be ``None``.
        token_history: Flat list of token ids for the conversation so far,
            without the leading ``<s>``; may be ``None``.

    Returns:
        ``(chat_history, token_history)`` as new lists including this turn.
    """
    # Work on copies so the caller's state is only replaced via the return value.
    chat_history = list(chat_history or [])
    token_history = list(token_history or [])

    # Nothing to answer: leave the conversation unchanged.
    if not user_input or not user_input.strip():
        return chat_history, token_history

    bos_id = tokenizer.token_to_id("<s>")
    eos_id = tokenizer.token_to_id("</s>")
    assistant_id = tokenizer.token_to_id("<assistant>")

    # Process user input
    user_turn = f"<user> {user_input} </s>"
    token_history.extend(tokenizer.encode(user_turn).ids)

    # Truncate if needed: keep the most recent tokens and reserve one slot for
    # <s>, leaving max_response_length positions free for the reply.
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    history_budget = max(max_input_len - 1, 0)
    if len(token_history) > history_budget:
        # A "[-0:]" slice would keep everything, hence the explicit empty case.
        token_history = token_history[-history_budget:] if history_budget else []
    input_sequence = [bos_id] + token_history

    # Generate response
    response_ids = list(
        generate_response(model, tokenizer, input_sequence, safety, max_response_length)
    )

    # Defensive: drop an end marker (and anything after it) if one is present.
    if eos_id in response_ids:
        response_ids = response_ids[:response_ids.index(eos_id)]

    # Process assistant response
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        # Hide the leading <assistant> role marker from the displayed text.
        if response_ids[0] == assistant_id:
            visible_ids = response_ids[1:]
        else:
            visible_ids = response_ids
        decoded = tokenizer.decode(visible_ids)
        if decoded.strip():
            assistant_text = decoded
        # Store the complete turn, closed with </s> like the user turn.
        token_history.extend(response_ids)
        if eos_id is not None:
            token_history.append(eos_id)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    """Return two fresh empty lists: one for the chatbot, one for the token state."""
    empty_chat = []
    empty_tokens = []
    return empty_chat, empty_tokens
# Custom CSS for styling.
# This stylesheet is handed to gr.Blocks through its `css` argument. It styles
# the page body, the "container"/"header" element classes used by the layout,
# and the "title"/"subtitle"/"footer" divs emitted via gr.Markdown.
# NOTE(review): the .gr-chatbot / .gr-button / .gr-text-input selectors only
# take effect if the installed Gradio version emits those class names; confirm.
custom_css = """
body {
background: linear-gradient(to bottom, #1a1a1a, #2a2a2a);
font-family: 'Roboto', sans-serif;
color: white;
margin: 0;
padding: 0;
}
.container {
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
.gr-chatbot {
font-size: 16px;
border: none;
background-color: #1e1e1e;
border-radius: 8px;
padding: 10px;
}
.gr-chatbot .bubble.user {
background-color: #2d2d2d !important;
border-radius: 8px;
padding: 12px;
margin: 8px 0;
}
.gr-chatbot .bubble.assistant {
background-color: #3d3d3d !important;
border-radius: 8px;
padding: 12px;
margin: 8px 0;
}
.gr-button {
background-color: #4CAF50;
color: white;
border: none;
padding: 12px 24px;
font-size: 16px;
border-radius: 4px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #45a049;
}
.gr-text-input input {
background-color: #2d2d2d;
color: white;
border: 1px solid #4CAF50;
border-radius: 4px;
padding: 10px;
font-size: 16px;
}
.header {
display: flex;
align-items: center;
justify-content: center;
padding: 20px 0;
text-align: center;
}
.header img {
width: 60px;
height: 60px;
margin-right: 15px;
}
.footer {
text-align: center;
padding: 20px;
font-size: 14px;
color: #ccc;
margin-top: 30px;
}
.title {
font-size: 28px;
font-weight: bold;
color: #ffffff;
margin: 0;
}
.subtitle {
font-size: 16px;
color: #cccccc;
margin: 5px 0 0 0;
}
"""
# Build the chat UI. Every component and event handler is declared inside the
# Blocks context so that Gradio registers it on `demo`.
# NOTE(review): "dark" does not appear among Gradio's documented built-in theme
# names; confirm which theme is actually applied at runtime.
with gr.Blocks(
    theme="dark",
    css=custom_css
) as demo:
    with gr.Column(elem_classes=["container"]):
        # Header: icon next to the title and subtitle.
        with gr.Row(elem_classes=["header"]):
            gr.Image(
                value="https://huggingface.co/TimurHromek/HROM-V1/resolve/main/hrom_icon.png",
                interactive=False,
                width=60,
                height=60,
                show_label=False
            )
            with gr.Column():
                gr.Markdown("<div class='title'>HROM-V1 Chatbot</div>")
                gr.Markdown("<div class='subtitle'>Powered by Gradio and Hugging Face</div>")

        # Chatbot
        # avatar_images takes a (user, bot) pair of image paths/URLs. The
        # previous list of (url, label) tuples does not match that contract.
        chatbot = gr.Chatbot(
            height=500,
            avatar_images=(
                "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/test_image.png",
                "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/favicon.png",
            ),
            bubble_full_width=False
        )

        # Input
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=2
        )

        # Buttons
        clear_btn = gr.Button("Clear Chat History")

        # State: per-session list of token ids for the running conversation.
        token_state = gr.State([])

        # Event handlers
        # Submitting the textbox runs one chat turn, then empties the textbox.
        msg.submit(
            process_message,
            [msg, chatbot, token_state],
            [chatbot, token_state],
            queue=False
        ).then(
            lambda: "", None, msg
        )
        # Clearing resets both the visible chat and the token history.
        clear_btn.click(
            clear_history,
            outputs=[chatbot, token_state],
            queue=False
        )

        # Footer
        gr.Markdown("<div class='footer'>© 2025 HROM-V1 | Model by Timur Hromek | Assisted by Elapt1c</div>")

if __name__ == "__main__":
    demo.launch()