Last commit not found
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") | |
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-15000/") | |
def text_processing(text): | |
inputs = [text] | |
# Tokenize and prepare the inputs for model | |
input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids | |
attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask | |
# Generate prediction | |
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512) | |
# Decode the prediction | |
decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output] | |
return decoded_output[0] | |
examples = [ | |
["猪笼草原产于热带和亚热带地区 现主要分布在东南亚一带 中国广东 广西等地有分布 猪笼草喜欢湿润和温暖半阴的生长环境 不耐寒 怕积水 怕强光 怕干燥 喜欢疏松 肥沃和透气的腐叶土和泥炭土 对光照要求较为严格 猪笼草的繁殖方式包括扦插繁殖 压条繁殖和播种繁殖"], | |
["都什么年代了 还在抽传统香烟"] | |
] | |
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")] | |
iface = gr.Interface( | |
fn=text_processing, | |
inputs=[gr.inputs.Textbox(default=examples[0][0], label="输入文本")], | |
outputs='text', | |
title='Punctuation Mark Prediction', | |
description='本模型主要用于语言识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。', | |
examples=examples | |
) | |
iface.launch(inline=False) |