crystal99 committed (verified)
Commit efe6226 · Parent(s): 05eaf03

Update app.py

Files changed (1)
  1. app.py +2 -2
app.py CHANGED
@@ -38,9 +38,9 @@ model.to(device)
 def generate_text(prompt):
     # Prevent gradient calculation to speed up inference
     with torch.no_grad():
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        inputs = tokenizer("<|STARTOFTEXT|> <|USER|> " + prompt + " <|BOT|> ", return_tensors="pt").to(device)
         outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1, do_sample=False)
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
         return generated_text
 
 # Set up the Gradio interface
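
The two changed lines wrap the user prompt in chat-style special tokens before tokenization and keep special tokens when decoding, so the model sees the same turn markers it was trained on and the <|BOT|> marker remains visible in the output. Below is a minimal sketch of how the updated generate_text could fit into the surrounding app.py; the checkpoint name, the device setup, and the Gradio wiring are assumptions for illustration, not part of this commit.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint name; the actual model loaded in app.py is not shown in this diff
model_name = "crystal99/example-model"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

def generate_text(prompt):
    # Prevent gradient calculation to speed up inference
    with torch.no_grad():
        # Wrap the user prompt in the chat-style special tokens the model expects
        inputs = tokenizer("<|STARTOFTEXT|> <|USER|> " + prompt + " <|BOT|> ",
                           return_tensors="pt").to(device)
        outputs = model.generate(inputs['input_ids'], max_length=100,
                                 num_return_sequences=1, do_sample=False)
        # skip_special_tokens=False keeps the turn markers in the decoded text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        return generated_text

# Set up the Gradio interface (assumed wiring, consistent with the comment in the diff context)
demo = gr.Interface(fn=generate_text, inputs="text", outputs="text")
demo.launch()

Keeping skip_special_tokens=False means the raw markers such as <|USER|> and <|BOT|> appear in the returned string; if only the reply text is wanted, the app would need to strip them after decoding.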