ajsbsd commited on
Commit
7666164
·
verified ·
1 Parent(s): 4d692df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -159,21 +159,18 @@ def generate_response_and_audio(
159
  input_text += f"Assistant: {item['content']}\n"
160
  input_text += f"User: {message}\nAssistant:"
161
 
162
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)
163
-
164
- with torch.no_grad():
165
- output_ids = llm_model.generate(
166
- input_ids,
167
- max_new_tokens=MAX_NEW_TOKENS,
168
- do_sample=DO_SAMPLE,
169
- temperature=TEMPERATURE,
170
- top_k=TOP_K,
171
- top_p=TOP_P,
172
- pad_token_id=tokenizer.eos_token_id
173
- )
174
-
175
- generated_token_ids = output_ids[0][input_ids.shape[-1]:]
176
- generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
177
 
178
  # --- 2. Generate Audio from Response (TTS) ---
179
  audio_path = None
 
159
  input_text += f"Assistant: {item['content']}\n"
160
  input_text += f"User: {message}\nAssistant:"
161
 
162
+ input_ids = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
163
+ with torch.no_grad():
164
+ output_ids = llm_model.generate(
165
+ input_ids["input_ids"],
166
+ attention_mask=input_ids["attention_mask"], # <-- Add this line
167
+ max_new_tokens=MAX_NEW_TOKENS,
168
+ do_sample=DO_SAMPLE,
169
+ temperature=TEMPERATURE,
170
+ top_k=TOP_K,
171
+ top_p=TOP_P,
172
+ pad_token_id=tokenizer.eos_token_id
173
+ )
 
 
 
174
 
175
  # --- 2. Generate Audio from Response (TTS) ---
176
  audio_path = None