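"""Gradio chat demo for the HROM-V1 model.

Downloads the trainer module, tokenizer, and trained checkpoint from the
Hugging Face Hub, then serves a chat interface with temperature-controlled
sampling, a safety filter, and a bounded conversation context.
"""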
import gradio as gr
import torch
import importlib.util
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os
# Download and import model components from HF Hub
model_repo = "TimurHromek/HROM-V1"
# 1. Import trainer module components
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
HROM = trainer_module.HROM
CONFIG = trainer_module.CONFIG
SafetyManager = trainer_module.SafetyManager
# 2. Load tokenizer
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)
# 3. Load model checkpoint
checkpoint_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trained-Model/HROM-V1.5.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Instantiate HROM and load the trained weights from the downloaded checkpoint."""
    model = HROM().to(device)
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model
model = load_model()
safety = SafetyManager(model, tokenizer)
max_response_length = 200
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200, temperature=1.0):
    """Autoregressively sample up to max_length tokens, stopping at </s> or on a safety violation."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()

    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)

        # Take the logits at the last position and apply temperature scaling
        next_token_logits = logits[0, -1, :]
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature
        probs = torch.softmax(next_token_logits, dim=-1)

        # Sample the next token from the resulting distribution
        next_token = torch.multinomial(probs, num_samples=1).item()

        # Stop if the end-of-sequence token is generated
        if next_token == tokenizer.token_to_id("</s>"):
            break

        # Safety check on the decoded text so far; stop if the filter rejects it
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break

        generated_ids.append(next_token)

    # Return only the newly generated tokens (exclude the prompt)
    return generated_ids[len(input_ids):]
def process_message(user_input, chat_history, token_history, temperature, max_context_length):
    # Encode the user turn and append it to the running token history
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Build the model input: <s> plus the token history, truncated to the context budget
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history
    if len(input_sequence) > max_context_length:
        # Keep the most recent tokens and re-prepend <s> so the prompt stays well-formed
        token_history = token_history[-(max_context_length - 1):]
        input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Generate the assistant response with the chosen temperature
    response_ids = generate_response(model, tokenizer, input_sequence, safety,
                                     max_response_length, temperature)

    # Decode the assistant response, handling a missing </s> terminator
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    return [], []
with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])

    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1,
                                label="Temperature (higher = more random)")
        max_context = gr.Slider(100, CONFIG["max_seq_len"] - max_response_length,
                                value=CONFIG["max_seq_len"] - max_response_length, step=1,
                                label="Max Context Length")

    # Process the message on submit, then clear the input textbox
    msg.submit(
        process_message,
        [msg, chatbot, token_state, temperature, max_context],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()
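# Note: when running this script outside of Spaces, demo.launch(share=True)
# can be used instead to expose a temporary public URL for the app.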