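"""Gradio demo that generates question/answer pairs from free-form text.

philipp-zettl/t5-small-qg proposes questions for the entered context and
philipp-zettl/t5-small-long-qa answers each of them against that same context.
Results are rendered in a table and can be downloaded as a headerless TSV.
"""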
import ast
import itertools

import gradio as gr
import pandas as pd
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# Question-generation and (long-form) question-answering models.
# Both models use the google/flan-t5-small tokenizer.
qa_model = AutoModelForSeq2SeqLM.from_pretrained('philipp-zettl/t5-small-long-qa')
qg_model = AutoModelForSeq2SeqLM.from_pretrained('philipp-zettl/t5-small-qg')
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')

# Move both models to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qa_model = qa_model.to(device)
qg_model = qg_model.to(device)

# Default number of questions/answers offered in the settings below
max_questions = 1
max_answers = 1


def run_model(inputs, tokenizer, model, temperature=0.5, num_return_sequences=1):
    """Generate num_return_sequences decoded outputs for each input text."""
    all_outputs = []
    for input_text in inputs:
        model_inputs = tokenizer(
            [input_text], max_length=512, padding=True, truncation=True, return_tensors='pt'
        )
        input_ids = model_inputs['input_ids'].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_length=85,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                low_memory=True,
                num_beams=max(2, num_return_sequences),
                use_cache=True,
            )
        # One decoded string per returned sequence for this input.
        all_outputs.append([
            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
        ])
    return all_outputs
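# Illustrative only; outputs vary with sampling (hypothetical example):
#   run_model(['context: The Eiffel Tower is in Paris.'], tokenizer, qg_model)
#   -> [['Where is the Eiffel Tower located?']]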


# On a ZeroGPU Space, @spaces.GPU requests a GPU for the duration of this call.
@spaces.GPU
def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_qg=1, num_return_sequences_qa=1):
    """Generate question/answer pairs for `content`.

    The QG model first proposes questions for the given context; the QA model
    then answers every generated question against that same context.
    """
    # gr.Number delivers floats by default; generation expects integer counts.
    num_return_sequences_qg = int(num_return_sequences_qg)
    num_return_sequences_qa = int(num_return_sequences_qa)

    inputs = [f'context: {content}']
    questions = run_model(inputs, tokenizer, qg_model, temperature_qg, num_return_sequences_qg)

    # Build one QA prompt per generated question, keeping the original context.
    qa_inputs = list(itertools.chain.from_iterable(
        [f'question: {q} {inputs[0]}' for q in q_set] for q_set in questions
    ))
    answers = run_model(qa_inputs, tokenizer, qa_model, temperature_qa, num_return_sequences_qa)

    questions = list(itertools.chain.from_iterable(questions))
    answers = list(itertools.chain.from_iterable(answers))

    # Every question yields num_return_sequences_qa answers, so answer idx
    # belongs to question idx // num_return_sequences_qa.
    return [
        {'question': questions[idx // num_return_sequences_qa], 'answer': ans}
        for idx, ans in enumerate(answers)
    ]
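# Shape of the result (question/answer text depends on the models and sampling):
#   gen('Some context ...', num_return_sequences_qg=2)
#   -> [{'question': '...', 'answer': '...'}, {'question': '...', 'answer': '...'}]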


def variable_outputs(k, max_elems=10):
    """Show the first k text boxes, hide the rest (unused by the UI below)."""
    k = int(k)
    return [gr.Text(visible=True)] * k + [gr.Text(visible=False)] * (max(max_elems, 10) - k)


def set_outputs(content, max_elems=10):
    """Fill text boxes from a serialized Python list (unused by the UI below)."""
    # ast.literal_eval parses Python literals only, unlike the unsafe eval().
    c = ast.literal_eval(content)
    print('received content: ', c)
    return [gr.Text(value=t, visible=True) for t in c] + [gr.Text(visible=False)] * (max(max_elems, 10) - len(c))


def create_file_download(qnas):
    """Write the Q&A pairs to a headerless TSV file and return its path."""
    with open('qnas.tsv', 'w', encoding='utf-8') as f:
        for idx, qna in qnas.iterrows():
            f.write(qna['Question'] + '\t' + qna['Answer'])
            if idx < len(qnas) - 1:
                f.write('\n')
    return 'qnas.tsv'
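# Note: fields are written verbatim, so a tab or newline inside a question or
# answer would break the TSV layout; the csv module would handle quoting.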


with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        with gr.Group("Content"):
            content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
        with gr.Group("Settings"):
            temperature_qg = gr.Slider(label='Temperature QG', value=0.5, minimum=0, maximum=1, step=0.01)
            temperature_qa = gr.Slider(label='Temperature QA', value=0.75, minimum=0, maximum=1, step=0.01)
            num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, 10))
            num_return_sequences_qa = gr.Number(label='Number Answers', value=max_answers, minimum=1, step=1, maximum=max(max_answers, 10))

    with gr.Row():
        gen_btn = gr.Button("Generate")

    # Re-render the results table and download button on each click of "Generate".
    @gr.render(inputs=[content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa], triggers=[gen_btn.click])
    def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa):
        qnas = gen(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa)
        # dict_values views are turned into lists so Gradio and pandas can consume them.
        df = gr.Dataframe(
            value=[list(u.values()) for u in qnas],
            headers=['Question', 'Answer'],
            col_count=2,
            wrap=True
        )
        pd_df = pd.DataFrame([list(u.values()) for u in qnas], columns=['Question', 'Answer'])

        download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))


demo.launch()
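# When run directly (python app.py) outside a ZeroGPU Space, @spaces.GPU is
# expected to be a no-op and inference happens on the device detected above.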