File size: 1,559 Bytes
6d3e512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from tow.model_byt5.tokenizer import Tokenizer_byt5
from tow.model_byt5.model import Transformer_byt5
import json
import torch
from huggingface_hub import hf_hub_download

# Hugging Face Hub repo holding the byt5-base ALiBi MT checkpoint.
_HF_REPO_ID = "df-h/byt5-base-alibi-mt"

# Fetch (or reuse from local HF cache) the model weights and its config file.
model_weights_path = hf_hub_download(repo_id=_HF_REPO_ID, filename="pytorch_model.bin")
model_config_path = hf_hub_download(repo_id=_HF_REPO_ID, filename="config.json")

# Lazy singleton cache for the loaded model/tokenizer, so the (expensive)
# weight load happens once per process rather than once per request.
_ASSETS: dict = {}

# Fixed sequence length the model is fed/generated with.
_MAX_LENGTH = 512


def _load_assets():
    """Load and cache the byt5 model (eval mode, CPU) and tokenizer.

    Returns:
        tuple: (model, tokenizer). Subsequent calls return the cached pair.
    """
    if "model" not in _ASSETS:
        with open(model_config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        model = Transformer_byt5(config=config)
        state_dict = torch.load(model_weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        _ASSETS["model"] = model.eval()
        _ASSETS["tokenizer"] = Tokenizer_byt5()
    return _ASSETS["model"], _ASSETS["tokenizer"]


def translate(inputs):
    """Translate `inputs` (a task-prefixed string, e.g. "zh2en:...") with the
    cached byt5 model and return the decoded output text.

    Args:
        inputs: Source text, prefixed with a task tag such as "en2ja:".

    Returns:
        str: The generated translation.
    """
    model, tokenizer = _load_assets()
    ids = tokenizer(inputs, max_length=_MAX_LENGTH)
    # Right-pad with 0 (the byt5 pad id) up to the fixed sequence length;
    # max(0, ...) is a no-op when the input already fills the window.
    ids = ids + [0] * max(0, _MAX_LENGTH - len(ids))
    input_tensor = torch.tensor([ids]).to(torch.device("cpu"))
    outputs = model.generate(input_tensor, max_length=_MAX_LENGTH)
    return tokenizer.ids2text(outputs.tolist()[0])

# Single text input, pre-filled with an example zh->en translation task.
_input_box = gr.components.Textbox(
    label="input",
    value="zh2en:一个描述实际事物的函数,其中的高频信息往往对应着很小的 “振幅”, 否则整个函数会很奇怪是个压扁的 “弹簧” ,不具实际意义。",
)

# Wire the translate() function into a minimal Gradio UI.
demo = gr.Interface(
    fn=translate,
    inputs=[_input_box],
    outputs=["text"],
    cache_examples=False,
    title="Translation",
    description="Support tasks: en2es, en2ja, en2zh, ja2zh, es2zh, es2ja",
)

# Serve on all interfaces with a public share link; debug keeps the process
# attached and prints errors to the console.
demo.launch(debug=True, share=True, server_name="0.0.0.0")