miracFence committed on
Commit
d5c8018
·
verified ·
1 Parent(s): e3a0c35

Update app.py

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -25,6 +25,7 @@ model = AutoModelForCausalLM.from_pretrained(model_name,
                                              quantization_config=quantization_config,
                                              device_map="auto")
 model.eval()
+max_token_length = 4096
 
 @spaces.GPU(duration=90)
 def generate(
@@ -47,9 +48,9 @@ def generate(
     conversation.append({"role": "user", "content": message})
 
     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > 4096:
-        input_ids = input_ids[:, -4096:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {4096} tokens.")
+    if input_ids.shape[1] > max_token_length:
+        input_ids = input_ids[:, -max_token_length:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
     input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
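
For context, a minimal, self-contained sketch of the trimming behaviour this commit parameterises. The trim_input_ids helper and the fake_ids tensor below are illustrative stand-ins, not part of app.py; only the slice and the 4096 limit come from the diff itself:

import torch

# Assumed constant, matching the value the commit hoists out of the literal.
max_token_length = 4096

def trim_input_ids(input_ids: torch.Tensor, limit: int = max_token_length) -> torch.Tensor:
    # Keep only the newest `limit` tokens; older conversation turns are
    # dropped from the left, so the latest user message and the generation
    # prompt appended by apply_chat_template survive.
    if input_ids.shape[1] > limit:
        input_ids = input_ids[:, -limit:]
    return input_ids

# Usage: a (1, 5000)-token prompt is cut down to its last 4096 positions.
fake_ids = torch.arange(5000).unsqueeze(0)
trimmed = trim_input_ids(fake_ids)
assert trimmed.shape == (1, 4096)
assert trimmed[0, -1].item() == 4999  # the most recent tokens are kept

Trimming from the left matches the diff's slice input_ids[:, -max_token_length:], which favours recency over the start of the conversation when the history exceeds the limit.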