Update app.py
Browse files
app.py
CHANGED
@@ -38,9 +38,9 @@ model.to(device)
|
|
38 |
def generate_text(prompt):
|
39 |
# Prevent gradient calculation to speed up inference
|
40 |
with torch.no_grad():
|
41 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
42 |
outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1, do_sample=False)
|
43 |
-
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
44 |
return generated_text
|
45 |
|
46 |
# Set up the Gradio interface
|
|
|
38 |
def generate_text(prompt):
|
39 |
# Prevent gradient calculation to speed up inference
|
40 |
with torch.no_grad():
|
41 |
+
inputs = tokenizer("<|STARTOFTEXT|> <|USER|> " + prompt + " <|BOT|> ", return_tensors="pt").to(device)
|
42 |
outputs = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1, do_sample=False)
|
43 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
44 |
return generated_text
|
45 |
|
46 |
# Set up the Gradio interface
|