osiria commited on
Commit
9136be2
·
1 Parent(s): 53f2dd4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ def install(package):
6
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
+
8
+ install("numpy")
9
+ install("torch")
10
+ install("transformers")
11
+ install("unidecode")
12
+ install("gradio_client==0.2.7")
13
+
14
+ import numpy as np
15
+ import torch
16
+ from transformers import DebertaV2TokenizerFast, DebertaV2ForQuestionAnswering
17
+ import re
18
+ import string
19
+ from transformers.pipelines import QuestionAnsweringPipeline
20
+ from transformers import pipeline
21
+ from collections import Counter
22
+ from unidecode import unidecode
23
+ import gradio as gr
24
+
25
+ tokenizer = DebertaV2TokenizerFast.from_pretrained("osiria/deberta-italian-question-answering")
26
+ model = DebertaV2ForQuestionAnswering.from_pretrained("osiria/deberta-italian-question-answering")
27
+
28
+ class OsiriaQA(QuestionAnsweringPipeline):
29
+
30
+ def __init__(self, punctuation = ',;.:!?()[\]{}', **kwargs):
31
+
32
+ QuestionAnsweringPipeline.__init__(self, **kwargs)
33
+ self.post_regex_left = "^[\s" + punctuation + "]+"
34
+ self.post_regex_right = "[\s" + punctuation + "]+$"
35
+
36
+ def postprocess(self, output):
37
+
38
+ output = QuestionAnsweringPipeline.postprocess(self, model_outputs=output)
39
+ output_length = len(output["answer"])
40
+ output["answer"] = re.sub(self.post_regex_left, "", output["answer"])
41
+ output["start"] = output["start"] + (output_length - len(output["answer"]))
42
+ output_length = len(output["answer"])
43
+ output["answer"] = re.sub(self.post_regex_right, "", output["answer"])
44
+ output["end"] = output["end"] - (output_length - len(output["answer"]))
45
+
46
+ return output
47
+
48
+
49
+ device = torch.device("cpu")
50
+ model = model.to(device)
51
+ model.eval()
52
+
53
+
54
+ pipeline_qa = OsiriaQA(model = model, tokenizer = tokenizer)
55
+
56
+
57
+ header = '''--------------------------------------------------------------------------------------------------
58
+ <style>
59
+ .vertical-text {
60
+ writing-mode: vertical-lr;
61
+ text-orientation: upright;
62
+ background-color:red;
63
+ }
64
+ </style>
65
+ <center>
66
+ <body>
67
+ <span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
68
+ <span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
69
+ <span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
70
+ <span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
71
+ <span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
72
+ <span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
73
+ </body>
74
+ </center>
75
+ <br>
76
+ <center>(BETA)</center>
77
+ '''
78
+
79
+ def extract(question, context):
80
+
81
+ res = pipeline_qa(context = context,
82
+ question = question)
83
+
84
+ out_text = context[0:res["start"]] + '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴀɴs </b> ' + context[res["start"]:res["end"]] + '</span>' + context[res["end"]:]
85
+
86
+ return out_text
87
+
88
+
89
+ init_question= "Cos'è l'Agenzia Spaziale Italiana?"
90
+
91
+ init_context = '''L'Agenzia Spaziale Italiana (ASI) è un ente governativo italiano, istituito nel 1988, che ha il compito di predisporre e attuare la politica aerospaziale italiana. Dipende e utilizza i fondi ricevuti dal Governo italiano per finanziare il progetto, lo sviluppo e la gestione operativa di missioni spaziali, con obiettivi scientifici e applicativi.
92
+
93
+ Gestisce missioni spaziali in proprio e in collaborazione con i maggiori organismi spaziali internazionali, prima tra tutte l'Agenzia Spaziale Europea (dove l'Italia è il terzo maggior contribuente dopo Francia e Germania, e a cui l'ASI corrisponde una parte del proprio budget), quindi la NASA e le altre agenzie spaziali nazionali. Per la realizzazione di satelliti e strumenti scientifici, l'ASI stipula contratti con le imprese, italiane e non, operanti nel settore aerospaziale.
94
+
95
+ Ha la sede principale a Roma e centri operativi a Matera (sede del Centro di geodesia spaziale Giuseppe Colombo) e Malindi, Kenya (sede del Centro spaziale Luigi Broglio). Il centro di Trapani-Milo, usato per i lanci di palloni stratosferici dal 1975, non è più operativo dal 2010.'''
96
+
97
+ init_output = extract(question = init_question, context = init_context)
98
+
99
+
100
+ with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
101
+
102
+ with gr.Row():
103
+ gr.Markdown(header)
104
+ with gr.Row():
105
+ context = gr.Text(label="Context", lines = 10, value = init_context)
106
+ with gr.Row():
107
+ question = gr.Text(label="Question", lines = 1, value = init_question)
108
+ with gr.Row():
109
+ gr.Examples([["Cosa fa l'Agenzia Spaziale Italiana?"],
110
+ ["Qual è la sigla dell'Agenzia Spaziale Italiana?"],
111
+ ["Quando è stata fondata l'ASI?"],
112
+ ["Chi finanzia l'ASI?"],
113
+ ["Chi altro contribuisce all'Agenzia Spaziale Europea oltre all'Italia?"],
114
+ ["Dove ha sede l'Agenzia Spaziale Italiana?"],
115
+ ["Dove si trova il centro spaziale Giuseppe Colombo?"],
116
+ ["Dove si trova il centro spaziale Luigi Broglio?"],
117
+ ["Il centro di Trapani-Milo è ancora in funzione?"]],
118
+ inputs=[question])
119
+ with gr.Row():
120
+ with gr.Column():
121
+ button = gr.Button("Ask").style(full_width=False)
122
+ with gr.Row():
123
+ with gr.Column():
124
+ output = gr.Markdown(init_output)
125
+
126
+ with gr.Row():
127
+ with gr.Column():
128
+ gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>")
129
+
130
+ button.click(extract, inputs=[question, context], outputs = [output])
131
+
132
+
133
+ interface.launch()