File size: 2,359 Bytes
2e605bf
 
 
 
 
7198503
 
2e605bf
7198503
2e605bf
 
 
 
 
7198503
2e605bf
7198503
 
2e605bf
 
7198503
 
2e605bf
7198503
 
2e605bf
 
 
7198503
 
2e605bf
 
7198503
 
 
 
2e605bf
 
 
 
 
 
7198503
 
2e605bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7198503
2e605bf
 
7198503
2e605bf
7198503
 
 
 
2e605bf
 
 
 
 
 
7198503
2e605bf
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import functools
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceGenerationAlgorithm,
    HuggingFaceSeq2SeqGenerator,
)
from transformers import AutoTokenizer

# Module-level logger. A NullHandler is attached (the usual convention for
# importable modules) so records from this logger are not routed to logging's
# last-resort handler when the application has not configured logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())

@functools.lru_cache(maxsize=8)
def _get_tokenizer(name: str = "t5-small"):
    """Load the Hugging Face tokenizer *name* once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(name)


def run_inference(
    model_name_or_path: str,
    prefix: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Generate text with a Text-chem-T5 model and return the cleaned output.

    Args:
        model_name_or_path: algorithm version passed to
            ``HuggingFaceSeq2SeqGenerator`` (the demo offers a fixed list of
            text-chem-t5 versions).
        prefix: task-specific prefix passed to the generator.
        prompt: text prompt passed to the generator.
        num_beams: number of beams passed to the generator; cast to ``int``
            in case the UI delivers it as a float.

    Returns:
        The first generated sample, cut at the first EOS token, with padding
        tokens removed and surrounding whitespace stripped.

    Raises:
        RuntimeError: if the model yields no samples.
    """
    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=prefix,
        prompt=prompt,
        num_beams=int(num_beams),
    )
    model = HuggingFaceGenerationAlgorithm(config)
    # The tokenizer is only used for its special-token strings; reuse a cached
    # instance instead of calling from_pretrained on every request.
    tokenizer = _get_tokenizer()

    samples = list(model.sample(1))
    if not samples:
        raise RuntimeError(f"Model {model_name_or_path!r} produced no output.")
    text = samples[0]

    # Drop everything after the end-of-sequence marker and any padding tokens.
    text = text.split(tokenizer.eos_token)[0]
    text = text.replace(tokenizer.pad_token, "")
    return text.strip()


if __name__ == "__main__":

    # Model versions offered in the UI (a fixed list, not discovered at runtime).
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Metadata (examples, article, description) lives next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # examples.csv has no header row; empty cells are turned into "" instead of
    # NaN before being handed to the interface. Computed once and reused.
    examples = (
        pd.read_csv(metadata_root.joinpath("examples.csv"), header=None)
        .fillna("")
        .values.tolist()
    )
    print("Examples: ", examples)

    # Explicit encoding so the markdown decodes identically on every platform,
    # independent of the locale's default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Text-chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Textbox(
                label="Prefix", placeholder="A task-specific prefix", lines=1
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)