import gradio as gr
import tokenizers
import torch
import transformers
from transformers import BertTokenizerFast, GPT2Config, GPT2LMHeadModel, pipeline
# Gradio on Hugging Face Spaces: https://huggingface.co/docs/hub/spaces-sdks-gradio
# Tokenizer: the bert-base-chinese vocabulary plus GPT-style special tokens.
# ``model_max_length`` is the current name of the deprecated ``max_len`` kwarg.
tokenizer_bert = BertTokenizerFast.from_pretrained(
    "bert-base-chinese",
    additional_special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    pad_token="<pad>",
    model_max_length=512,
)

# An 8-layer GPT-2 built from a config; its weights come from the local
# checkpoint below. NOTE(review): vocab_size=25000 presumably matches the
# embedding size the checkpoint was trained with — load_state_dict (strict by
# default) raises on a size mismatch if it does not.
configuration = GPT2Config(vocab_size=25000, n_layer=8)
model = GPT2LMHeadModel(config=configuration)

path2pytorch_model = "pytorch_model.bin"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
# torch.load unpickles the file, so only point it at checkpoints you trust.
model.load_state_dict(torch.load(path2pytorch_model, map_location="cpu"))
# A model constructed directly (not via from_pretrained) starts in training
# mode; switch to eval so dropout is disabled during generation.
model.eval()

generator = pipeline("text-generation", model=model, tokenizer=tokenizer_bert)
def generate(prompt):
    """Continue *prompt* with the GPT-2 model and return the best continuation.

    Runs beam search (10 beams, total length capped at 30 tokens including the
    prompt, repetition penalty 1.5) through the module-level ``generator``
    pipeline.

    Args:
        prompt: Text entered in the Gradio textbox.

    Returns:
        The highest-scoring generated text (prompt included), or ``""`` when
        the prompt is empty, whitespace-only, or None.
    """
    # A blank prompt gives the model nothing meaningful to condition on;
    # answer immediately instead of running (or failing) generation.
    if not prompt or not prompt.strip():
        return ""
    # Only the top beam is returned, so request a single sequence instead of
    # decoding 5 and discarding 4 (the best beam is the same either way).
    # top_p was dropped: it only applies when sampling (do_sample=True) and
    # is ignored by plain beam search.
    outputs = generator(
        prompt,
        max_length=30,
        num_beams=10,
        num_return_sequences=1,
        repetition_penalty=1.5,
    )
    return outputs[0]["generated_text"]
# Text-in / text-out Gradio UI wrapping generate().
iface = gr.Interface(fn=generate, inputs="text", outputs="text")

# Start the web server only when this file is run as a script, so importing
# the module (e.g. for tests or Gradio's reload mode) has no side effect.
if __name__ == "__main__":
    iface.launch()