TimurHromek committed
Commit 1966659 · verified · 1 Parent(s): 77b6e8a

Update app.py

Files changed (1)
  1. app.py +30 -9
app.py CHANGED
@@ -36,39 +36,53 @@ model = load_model()
 safety = SafetyManager(model, tokenizer)
 max_response_length = 200
 
-def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
+def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200, temperature=1.0):
     device = next(model.parameters()).device
     generated_ids = input_ids.copy()
     for _ in range(max_length):
         input_tensor = torch.tensor([generated_ids], device=device)
         with torch.no_grad():
             logits = model(input_tensor)
-        next_token = logits.argmax(-1)[:, -1].item()
+
+        # Get last token logits and apply temperature
+        next_token_logits = logits[0, -1, :]
+        if temperature != 1.0:
+            next_token_logits = next_token_logits / temperature
+        probs = torch.softmax(next_token_logits, dim=-1)
+
+        # Sample next token
+        next_token = torch.multinomial(probs, num_samples=1).item()
+
+        # Stop if end token is generated
         if next_token == tokenizer.token_to_id("</s>"):
             break
+
+        # Safety check
         current_text = tokenizer.decode(generated_ids + [next_token])
         if not safety_manager.content_filter(current_text):
             break
+
         generated_ids.append(next_token)
     return generated_ids[len(input_ids):]
 
-def process_message(user_input, chat_history, token_history):
+def process_message(user_input, chat_history, token_history, temperature, max_context_length):
     # Process user input
     user_turn = f"<user> {user_input} </s>"
     user_tokens = tokenizer.encode(user_turn).ids
     token_history.extend(user_tokens)
 
-    # Prepare input sequence
+    # Prepare input sequence with context limit
     input_sequence = [tokenizer.token_to_id("<s>")] + token_history
 
-    # Truncate if needed
-    max_input_len = CONFIG["max_seq_len"] - max_response_length
+    # Truncate based on max context length
+    max_input_len = max_context_length
     if len(input_sequence) > max_input_len:
         input_sequence = input_sequence[-max_input_len:]
         token_history = input_sequence[1:]
 
-    # Generate response
-    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)
+    # Generate response with temperature
+    response_ids = generate_response(model, tokenizer, input_sequence, safety,
+                                     max_response_length, temperature)
 
     # Process assistant response
     assistant_text = "I couldn't generate a proper response."
@@ -97,9 +111,16 @@ with gr.Blocks() as demo:
     msg = gr.Textbox(label="Your Message")
     token_state = gr.State([])
 
+    with gr.Row():
+        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1,
+                                label="Temperature (higher = more random)")
+        max_context = gr.Slider(100, CONFIG["max_seq_len"] - max_response_length,
+                                value=CONFIG["max_seq_len"] - max_response_length, step=1,
+                                label="Max Context Length")
+
     msg.submit(
         process_message,
-        [msg, chatbot, token_state],
+        [msg, chatbot, token_state, temperature, max_context],
         [chatbot, token_state],
         queue=False
     ).then(
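
In short, the commit swaps greedy argmax decoding for temperature-scaled multinomial sampling inside generate_response, threads temperature and max_context_length arguments through process_message, and exposes both as Gradio sliders. The context slider is capped at CONFIG["max_seq_len"] - max_response_length so that the prompt plus up to max_response_length generated tokens still fit the model's sequence limit. Below is a minimal, self-contained sketch of just the new sampling step; the sample_next_token name and the fake_logits tensor of shape (batch=1, seq_len, vocab_size) are illustrative assumptions, not part of app.py.

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    # Take the logits at the last position of the single batch item.
    next_token_logits = logits[0, -1, :]
    # Temperature < 1 sharpens the distribution (closer to greedy argmax),
    # temperature > 1 flattens it (more random), 1.0 leaves it untouched.
    if temperature != 1.0:
        next_token_logits = next_token_logits / temperature
    probs = torch.softmax(next_token_logits, dim=-1)
    # Draw one token id from the resulting distribution.
    return torch.multinomial(probs, num_samples=1).item()

# Toy usage with random logits standing in for the model output.
fake_logits = torch.randn(1, 5, 32)  # (batch, seq_len, vocab_size)
print(sample_next_token(fake_logits, temperature=0.7))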