Update app.py
app.py
CHANGED
@@ -36,39 +36,53 @@ model = load_model()
 safety = SafetyManager(model, tokenizer)
 max_response_length = 200
 
-def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
+def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200, temperature=1.0):
     device = next(model.parameters()).device
     generated_ids = input_ids.copy()
     for _ in range(max_length):
         input_tensor = torch.tensor([generated_ids], device=device)
         with torch.no_grad():
             logits = model(input_tensor)
-
+
+        # Get last token logits and apply temperature
+        next_token_logits = logits[0, -1, :]
+        if temperature != 1.0:
+            next_token_logits = next_token_logits / temperature
+        probs = torch.softmax(next_token_logits, dim=-1)
+
+        # Sample next token
+        next_token = torch.multinomial(probs, num_samples=1).item()
+
+        # Stop if end token is generated
         if next_token == tokenizer.token_to_id("</s>"):
             break
+
+        # Safety check
         current_text = tokenizer.decode(generated_ids + [next_token])
         if not safety_manager.content_filter(current_text):
             break
+
         generated_ids.append(next_token)
     return generated_ids[len(input_ids):]
 
-def process_message(user_input, chat_history, token_history):
+def process_message(user_input, chat_history, token_history, temperature, max_context_length):
     # Process user input
     user_turn = f"<user> {user_input} </s>"
     user_tokens = tokenizer.encode(user_turn).ids
     token_history.extend(user_tokens)
 
-    # Prepare input sequence
+    # Prepare input sequence with context limit
     input_sequence = [tokenizer.token_to_id("<s>")] + token_history
 
-    # Truncate
-    max_input_len =
+    # Truncate based on max context length
+    max_input_len = max_context_length
     if len(input_sequence) > max_input_len:
         input_sequence = input_sequence[-max_input_len:]
         token_history = input_sequence[1:]
 
-    # Generate response
-    response_ids = generate_response(model, tokenizer, input_sequence, safety,
+    # Generate response with temperature
+    response_ids = generate_response(model, tokenizer, input_sequence, safety,
+                                     max_response_length, temperature)
 
     # Process assistant response
     assistant_text = "I couldn't generate a proper response."
@@ -97,9 +111,16 @@ with gr.Blocks() as demo:
     msg = gr.Textbox(label="Your Message")
     token_state = gr.State([])
 
+    with gr.Row():
+        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1,
+                                label="Temperature (higher = more random)")
+        max_context = gr.Slider(100, CONFIG["max_seq_len"] - max_response_length,
+                                value=CONFIG["max_seq_len"] - max_response_length, step=1,
+                                label="Max Context Length")
+
     msg.submit(
         process_message,
-        [msg, chatbot, token_state],
+        [msg, chatbot, token_state, temperature, max_context],
         [chatbot, token_state],
         queue=False
     ).then(
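
For reference, the heart of this change is temperature-scaled multinomial sampling over the last-position logits. A minimal, self-contained sketch of that step follows; the random logits and the sample_next_token helper are illustrative stand-ins, not part of app.py:

import torch

def sample_next_token(next_token_logits, temperature=1.0):
    # Temperature < 1.0 sharpens the distribution (closer to greedy decoding);
    # temperature > 1.0 flattens it (more random choices).
    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
    probs = torch.softmax(next_token_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Example: pretend these are the logits for the last position over a 10-token vocabulary.
logits = torch.randn(10)
print(sample_next_token(logits, temperature=0.5))  # usually picks a high-probability token
print(sample_next_token(logits, temperature=1.5))  # more varied picks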