# Source: GPT-2-test / app.py
# Author: lachine
# Last commit: 8629bfd ("Update app.py")
# Size: 2.43 kB
import gradio as gr
from transformers import pipeline, set_seed
from random import randint
# Build the GPT-2 text-generation pipeline once, when the module is imported;
# generate_text reuses this single instance for every request.
generator = pipeline(task='text-generation', model='gpt2')
def generate_text(text, max_length, amount, seed=None):
    """
    Generates text using the GPT-2 model.

    :param text: Input text to generate from.
    :param max_length: Maximum length of each generated sequence, in tokens.
        Passed to the pipeline as ``max_length``, which (per the transformers
        generation docs) counts the prompt tokens as well.
    :param amount: Number of texts to generate (``num_return_sequences``).
    :param seed: Optional integer seed. Pass a fixed value to get reproducible
        output; when ``None`` (the default) a fresh random seed is drawn, so
        every call produces different text.
    :return: A single string containing every generated text, with the texts
        separated by a line reading ``end of text``.
    """
    if seed is None:
        # A new seed per call keeps repeated requests with the same prompt from
        # returning identical samples. set_seed also seeds numpy, whose seed
        # must lie in [0, 2**32 - 1].
        seed = randint(0, 2**32 - 1)
    set_seed(seed)

    # Gradio sliders may deliver their values as floats; the generation
    # arguments below are expected to be ints, so normalise them here.
    results = generator(
        text,
        max_length=int(max_length),
        num_return_sequences=int(amount),
    )
    return '\nend of text\n'.join(item['generated_text'] for item in results)
# Define the inputs. Their order must match generate_text's positional
# parameters: (text, max_length, amount).
# NOTE(review): gr.inputs / gr.outputs are deprecated since Gradio 3.0 and
# removed in 4.0; migrate to gr.Textbox / gr.Slider(value=...) when upgrading.
text_input = gr.inputs.Textbox(lines=5, label='Input Text')
max_length_slider = gr.inputs.Slider(
    minimum=10, maximum=500, step=1, default=100, label='max_length'
)
amount_slider = gr.inputs.Slider(
    minimum=1, maximum=5, step=1, default=1, label='num_return_sequences (Amount)'
)
# Define the output: a single textbox showing the joined generated texts.
output_textbox = gr.outputs.Textbox(label='Output Text')
# Create the interface that wires the inputs through generate_text.
interface = gr.Interface(
    fn=generate_text,
    inputs=[text_input, max_length_slider, amount_slider],
    outputs=output_textbox,
    title='Minimal GPT-2 Demo',
    description='Generate text using GPT-2',
)
# Request a vertical page layout for the interface.
# NOTE(review): this is assigned after gr.Interface has been constructed;
# confirm the installed Gradio version still reads it at this point.
interface.layout = 'vertical'
# API documentation describing generate_text's inputs and its output.
api_docs = {
    'generate_text': {
        'description': 'Generates text using the GPT-2 model.',
        'input': [
            {
                'name': 'text',
                'type': 'str',
                'description': 'Input text to generate from.',
            },
            {
                'name': 'max_length',
                'type': 'int',
                'description': 'Maximum length of generated text.',
            },
            {
                'name': 'num_return_sequences',
                'type': 'int',
                'description': 'Number of texts to generate.',
            },
        ],
        'output': {
            'type': 'str',
            'description': r'The text(s), separated by "\nend of text\n".',
        },
    },
}
# NOTE(review): gr.Interface does not appear to expose an `api` attribute in
# Gradio 2.x/3.x; an unconditional `interface.api.docs = ...` would then raise
# AttributeError at import time and the app would never launch. Attach the
# docs only when that attribute is actually present.
_interface_api = getattr(interface, 'api', None)
if _interface_api is not None:
    _interface_api.docs = api_docs
# Start the Gradio server; share=True additionally requests a public share link.
launch_options = {'share': True}
interface.launch(**launch_options)