prithivMLmods committed (verified)
Commit 1d74de7 · 1 Parent(s): bb36c83

Update app.py

Files changed (1): app.py (+15 -3)
app.py CHANGED
```diff
@@ -10,7 +10,7 @@ DESCRIPTION = """
 # QwQ Distill
 """
 
-css= '''
+css = '''
 h1 {
     text-align: center;
     display: block;
```
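For reference, the `css` string is consumed elsewhere in app.py, outside this diff, presumably by the Gradio UI constructor. A minimal sketch of that wiring; the `gr.Blocks` usage and demo layout here are assumptions, not code from this commit:

```python
import gradio as gr

# Assumed wiring: app.py presumably passes `css` to the UI constructor.
css = '''
h1 {
    text-align: center;
    display: block;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# QwQ Distill")

if __name__ == "__main__":
    demo.launch()
```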
```diff
@@ -40,6 +40,9 @@ model = AutoModelForCausalLM.from_pretrained(
 model.config.sliding_window = 4096
 model.eval()
 
+# Set the pad token ID if it's not already set
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
 
 @spaces.GPU(duration=120)
 def generate(
```
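Background on the new guard: many causal-LM tokenizers ship without a dedicated pad token, and `model.generate()` then warns, or fails outright when padding is actually needed. Reusing the EOS token ID is the standard fallback. A standalone sketch of the same check, with the checkpoint name as a placeholder:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; any causal-LM tokenizer shows the same behavior.
tokenizer = AutoTokenizer.from_pretrained("Qwen/QwQ-32B-Preview")

# Reuse EOS for padding when no dedicated pad token exists.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
```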
```diff
@@ -54,15 +57,23 @@ def generate(
     conversation = chat_history.copy()
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    # Apply chat template and get input_ids and attention_mask
+    inputs = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_dict=True, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
+
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+
     input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
+        attention_mask=attention_mask,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
```
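One note on the rewritten template call: `apply_chat_template` returns a bare `input_ids` tensor by default, so indexing the result with `inputs["input_ids"]` only works when `return_dict=True` is also passed (added in the hunk above); the dict form then carries the `attention_mask` the rest of the hunk relies on. A standalone sketch of the call, checkpoint name again a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/QwQ-32B-Preview")  # placeholder

conversation = [{"role": "user", "content": "Hello!"}]

# return_dict=True makes the call return a BatchEncoding holding both
# input_ids and attention_mask; without it, only input_ids comes back.
inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["attention_mask"].shape)
```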
```diff
@@ -71,6 +82,7 @@
         temperature=temperature,
         num_beams=1,
         repetition_penalty=repetition_penalty,
+        pad_token_id=tokenizer.pad_token_id,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
```
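The hunk ends at `t.start()`, so the rest of `generate()` is not part of this commit. The usual continuation drains the `TextIteratorStreamer`, which is itself an iterator that blocks until the background `model.generate()` call pushes decoded text. A sketch of that assumed tail:

```python
    # Assumed continuation of generate(); not shown in this diff.
    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yield the accumulated text so the Gradio UI streams partial output.
        yield "".join(outputs)
```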
 