1inkusFace commited on
Commit
1e482e7
·
verified ·
1 Parent(s): 90a7060

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -196,7 +196,8 @@ def filter_text(text,phraseC):
196
  if match:
197
  filtered_text = match.group(2)
198
  for phrase in phraseC:
199
- filtered_text = re.sub(phrase, "", filtered_text, flags=re.DOTALL)
 
200
  return filtered_text
201
  else:
202
  return filtered_text
@@ -300,6 +301,7 @@ def expand_prompt(prompt):
300
  do_sample=True,
301
  )
302
  enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
303
  input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced_prompt}"
304
  encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
305
  input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
@@ -315,7 +317,6 @@ def expand_prompt(prompt):
315
  # Use the encoded tensor 'text_inputs' here
316
  enhanced_prompt_2 = txt_tokenizer.decode(outputs_2[0], skip_special_tokens=True)
317
  print('-- generated prompt --')
318
- enhanced_prompt = filter_text(enhanced_prompt,prompt)
319
  enhanced_prompt_2 = filter_text(enhanced_prompt_2,prompt)
320
  print('-- filtered prompt --')
321
  print(enhanced_prompt)
 
196
  if match:
197
  filtered_text = match.group(2)
198
  for phrase in phraseC:
199
+ for wrd in phrase
200
+ filtered_text = re.sub(wrd, "", filtered_text, flags=re.DOTALL)
201
  return filtered_text
202
  else:
203
  return filtered_text
 
301
  do_sample=True,
302
  )
303
  enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
304
+ enhanced_prompt = filter_text(enhanced_prompt,prompt)
305
  input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced_prompt}"
306
  encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
307
  input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
 
317
  # Use the encoded tensor 'text_inputs' here
318
  enhanced_prompt_2 = txt_tokenizer.decode(outputs_2[0], skip_special_tokens=True)
319
  print('-- generated prompt --')
 
320
  enhanced_prompt_2 = filter_text(enhanced_prompt_2,prompt)
321
  print('-- filtered prompt --')
322
  print(enhanced_prompt)