Merge branch 'main' into Ling
Files changed:
- .gitignore +1 -1
- Dockerfile +3 -1
- README.md +1 -1
- api.py +51 -0
- model/vocab.pkl +0 -0
- requirements.txt +11 -3
- src/api.py +54 -10
- src/dataloader.py +56 -37
- src/fine_tune_t5.py +204 -0
- src/inference.py +16 -23
- src/inference_t5.py +5 -5
- src/model.py +40 -38
- src/script.py +0 -90
- src/train.py +75 -6
- templates/index.html.jinja +30 -9
.gitignore
CHANGED
@@ -1,6 +1,6 @@
.venv/**
data/**
src/__pycache__
-model
+model/model.pt
html5up-helios/**
**/__pycache__/**
Dockerfile
CHANGED
@@ -8,4 +8,6 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY . .

-CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
+
+#CMD python3 -m uvicorn --app-dir ./src api:app --host 0.0.0.0 --port 3001
README.md
CHANGED
@@ -1,7 +1,7 @@
---
title: SummaryProject
sdk: docker
-app_file: app.py
+app_file: src/app.py
pinned: false
---
# Initialisation
api.py
ADDED
@@ -0,0 +1,51 @@
+import uvicorn
+from fastapi import FastAPI, Form, Request
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+
+from inference import inferenceAPI
+
+
+# ------ MODELE --------------------------------------------------------------
+# appel de la fonction inference, adaptee pour une entree txt
+def summarize(text: str):
+    return " ".join(inferenceAPI(text))
+
+
+# ----------------------------------------------------------------------------------
+
+
+# -------- API ---------------------------------------------------------------------
+app = FastAPI()
+
+# static files pour envoi du css au navigateur
+templates = Jinja2Templates(directory="templates")
+app.mount("/", StaticFiles(directory="templates", html=True), name="templates")
+
+
+@app.get("/")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+
+# retourner le texte, les predictions et message d'erreur si formulaire envoye vide
+@app.post("/")
+async def prediction(request: Request, text: str = Form(None)):
+    if not text:
+        error = "Merci de saisir votre texte."
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": error}
+        )
+    else:
+        summary = summarize(text)
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": text, "summary": summary}
+        )
+
+
+# ------------------------------------------------------------------------------------
+
+
+# lancer le serveur et le recharger a chaque modification sauvegardee
+# if __name__ == "__main__":
+#     uvicorn.run("api:app", port=8000, reload=True)
model/vocab.pkl
ADDED
Binary file (63.4 kB)
requirements.txt
CHANGED
@@ -6,7 +6,6 @@ anyascii==0.3.1
anyio==3.6.2
async-timeout==4.0.2
attrs==22.2.0
-autopep8==2.0.2
banal==1.0.6
blis==0.7.9
catalogue==2.0.8
@@ -33,6 +32,11 @@ fsspec==2023.3.0
greenlet==2.0.2
h11==0.14.0
huggingface-hub==0.12.1
+certifi==2022.12.7
+charset-normalizer==3.1.0
+click==8.1.3
+fastapi==0.92.0
+filelock==3.9.0
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.12.0
@@ -47,7 +51,6 @@ mccabe==0.7.0
multidict==6.0.4
multiprocess==0.70.14
murmurhash==1.0.9
-nltk==3.8.1
numpy==1.24.2
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
@@ -77,7 +80,6 @@ scikit-learn==1.2.1
scipy==1.10.0
sentencepiece==0.1.97
six==1.16.0
-sklearn==0.0.post1
smart-open==6.3.0
sniffio==1.3.0
spacy==3.5.0
@@ -98,6 +100,12 @@ transformers==4.26.1
typer==0.7.0
typing-extensions==4.4.0
urllib3==1.26.14
+starlette==0.25.0
+tokenizers==0.13.2
+torch==1.13.1
+tqdm==4.65.0
+typing_extensions==4.5.0
+urllib3==1.26.15
uvicorn==0.20.0
wasabi==1.1.1
xxhash==3.2.0
src/api.py
CHANGED
@@ -2,17 +2,33 @@ import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
+import re

-from inference import inferenceAPI
+from src.inference import inferenceAPI
+from src.inference_t5 import inferenceAPI_t5


-# ------ MODELE --------------------------------------------------------------
+# ------ INFERENCE MODEL --------------------------------------------------------------
# appel de la fonction inference, adaptee pour une entree txt
def summarize(text: str):
-    return " ".join(inferenceAPI(text))
+    if choisir_modele.var == 'lstm' :
+        return " ".join(inferenceAPI(text))
+    elif choisir_modele.var == "fineTunedT5":
+        text = inferenceAPI_t5(text)
+
# ----------------------------------------------------------------------------------


+def choisir_modele(choixModele):
+    print("ON A RECUP LE CHOIX MODELE")
+    if choixModele == "lstm" :
+        choisir_modele.var ='lstm'
+    elif choixModele == "fineTunedT5":
+        choisir_modele.var = "fineTunedT5"
+    else :
+        "le modele n'est pas defini"
+
+
# -------- API ---------------------------------------------------------------------
app = FastAPI()

@@ -20,26 +36,54 @@ app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")

+
@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse("index.html.jinja", {"request": request})

+@app.get("/model")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+@app.get("/predict")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+
+@app.post("/model")
+async def choix_model(request: Request, choixModel:str = Form(None)):
+    print(choixModel)
+    if not choixModel:
+        erreur_modele = "Merci de saisir un modèle."
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": erreur_modele}
+        )
+    else :
+        choisir_modele(choixModel)
+        print("C'est bon on utilise le modèle demandé")
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request}
+        )
+
+
# retourner le texte, les predictions et message d'erreur si formulaire envoye vide
-@app.post("/")
+@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
-    if not text:
+    if not text:
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
-            "index.html.jinja", {"request": request, "text": error}
-        )
-    else:
+            "index.html.jinja", {"request": request, "text": error}
+        )
+    else:
        summary = summarize(text)
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": text, "summary": summary}
        )
+
+
# ------------------------------------------------------------------------------------


# lancer le serveur et le recharger a chaque modification sauvegardee
-if __name__ == "__main__":
-    uvicorn.run("api:app", port=8000, reload=True)
+# if __name__ == "__main__":
+#     uvicorn.run("api:app", port=8000, reload=True)
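For reference, here is a minimal client-side sketch (not part of the commit) of the two-step flow the new routes imply: POST the model choice to /model, then POST the text to /predict. It assumes the app is running locally on port 7860 as in the Dockerfile CMD, and it uses the requests package, which is not pinned in requirements.txt; the field names choixModel and text come from the form handlers above.

# Illustrative sketch only: exercises the /model and /predict form endpoints from src/api.py.
# Assumes the server is up at http://localhost:7860 and that `requests` is installed.
import requests

BASE_URL = "http://localhost:7860"

# 1) pick the model; the handler expects a form field named "choixModel"
requests.post(f"{BASE_URL}/model", data={"choixModel": "lstm"})

# 2) send the text to summarize; the handler expects a form field named "text"
response = requests.post(f"{BASE_URL}/predict", data={"text": "Texte à résumer ..."})
print(response.status_code)  # the summary itself is rendered into index.html.jinja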
src/dataloader.py
CHANGED
@@ -11,17 +11,15 @@
    Création d'un Vectoriser à partir du vocabulaire :

"""
+import pickle
import string
from collections import Counter

import pandas as pd
import torch
-from nltk import word_tokenize

-# nltk.download('punkt')

-
-class Data:
+class Data(torch.utils.data.Dataset):
    """
    A class used to get data from file
    ...
@@ -44,8 +42,27 @@ class Data:
    create a dataset with cleaned data
    """

-    def __init__(self, path: str) -> None:
+    def __init__(self, path: str, transform=None) -> None:
        self.path = path
+        self.data = pd.read_json(path_or_buf=self.path, lines=True)
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        row = self.data.iloc[idx]
+        text = row["text"].translate(str.maketrans("", "", string.punctuation)).split()
+        summary = (
+            row["summary"].translate(str.maketrans("", "", string.punctuation)).split()
+        )
+        summary = ["<start>", *summary, "<end>"]
+        sample = {"text": text, "summary": summary}
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        return sample

    def open(self) -> pd.DataFrame:
        """
@@ -85,26 +102,13 @@ class Data:
        # - s'occuper des noms propres (mots commençant par une majuscule qui se suivent)
        for text in texts:
            text = text.translate(str.maketrans("", "", string.punctuation))
-            text =
+            text = text.split()
            tokenized_texts.append(text)

        if text_type == "summary":
            return [["<start>", *summary, "<end>"] for summary in tokenized_texts]
        return tokenized_texts

-    def pad_sequence(self):
-        """
-        pad summary with empty token
-        """
-        texts = self.clean_data("text")
-        summaries = self.clean_data("summary")
-        padded_summary = []
-        for text, summary in zip(texts, summaries):
-            if len(summary) != len(text):
-                summary += ["<empty>"] * (len(text) - len(summary))
-            padded_summary.append(summary)
-        return texts, padded_summary
-
    def get_words(self) -> list:
        """
        Create a dictionnary of the data vocabulary
@@ -114,15 +118,20 @@ class Data:
        summary_words = [word for text in summaries for word in text]
        return text_words + summary_words

-
-
-
-
-
-
-
-
-
+
+def pad_collate(data):
+    text_batch = [element[0] for element in data]
+    summary_batch = [element[1] for element in data]
+    max_len = max([len(element) for element in summary_batch + text_batch])
+    text_batch = [
+        torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
+        for element in text_batch
+    ]
+    summary_batch = [
+        torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
+        for element in summary_batch
+    ]
+    return text_batch, summary_batch


class Vectoriser:
@@ -146,12 +155,25 @@ class Vectoriser:
    encode an entire row from the dataset
    """

-    def __init__(self, vocab) -> None:
+    def __init__(self, vocab=None) -> None:
        self.vocab = vocab
        self.word_count = Counter(word.lower().strip(",.\\-") for word in self.vocab)
        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

+    def load(self, path):
+        with open(path, "rb") as file:
+            self.vocab = pickle.load(file)
+        self.word_count = Counter(
+            word.lower().strip(",.\\-") for word in self.vocab
+        )
+        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
+        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
+
+    def save(self, path):
+        with open(path, "wb") as file:
+            pickle.dump(self.vocab, file)
+
    def encode(self, tokens) -> torch.tensor:
        """
        Encode une phrase selon les mots qu'elle contient
@@ -165,7 +187,7 @@ class Vectoriser:
        :return: words_idx : tensor
            Un tensor contenant les index des mots de la phrase
        """
-        if
+        if isinstance(tokens, list):
            words_idx = torch.tensor(
                [
                    self.token_to_idx.get(t.lower(), len(self.token_to_idx))
@@ -175,7 +197,7 @@ class Vectoriser:
            )

        # Permet d'encoder mots par mots
-        elif
+        elif isinstance(tokens, str):
            words_idx = torch.tensor(self.token_to_idx.get(tokens.lower()))

        return words_idx
@@ -184,9 +206,9 @@ class Vectoriser:
        """
        Decode une phrase selon le procédé inverse que la fonction encode
        """
-
+
        idxs = words_idx_tensor.tolist()
-        if
+        if isinstance(idxs, int):
            words = [self.idx_to_token[idxs]]
        else:
            words = []
@@ -195,10 +217,7 @@ class Vectoriser:
            words.append(self.idx_to_token[idx])
        return words

-    def
-        pass
-
-    def vectorize(self, row) -> torch.tensor:
+    def __call__(self, row) -> torch.tensor:
        """
        Encode les données d'une ligne du dataframe
        ----------
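As a quick illustration of the vocabulary persistence added to Vectoriser (save/load via pickle), here is a minimal sketch, not part of the diff, of the round trip that train.py and inference.py now rely on. It assumes the data/*.jsonl files are present locally and that the repository root is on PYTHONPATH; the pickle path mirrors the model/vocab.pkl binary added in this commit.

# Sketch of the vocab round trip enabled by this commit:
# build the vocabulary once, pickle it, then rebuild a Vectoriser from the pickle.
import pickle

from src import dataloader

data = dataloader.Data("data/train_extract.jsonl")   # assumes the jsonl file exists
vectoriser = dataloader.Vectoriser(data.get_words())
vectoriser.save("model/vocab.pkl")                    # same path as the binary file added above

# later (e.g. at inference time), reload the saved vocabulary
with open("model/vocab.pkl", "rb") as f:
    words = pickle.load(f)
vectoriser = dataloader.Vectoriser(words)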
src/fine_tune_t5.py
ADDED
@@ -0,0 +1,204 @@
+import torch
+import datasets
+from datasets import Dataset, DatasetDict
+import pandas as pd
+from tqdm import tqdm
+import re
+import os
+import nltk
+import string
+import contractions
+from transformers import pipeline
+import evaluate
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer,AutoConfig
+from transformers import Seq2SeqTrainingArguments ,Seq2SeqTrainer
+from transformers import DataCollatorForSeq2Seq
+
+# cuda out of memory
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"
+
+nltk.download('stopwords')
+nltk.download('punkt')
+
+
+def clean_data(texts):
+    texts = texts.lower()
+    texts = contractions.fix(texts)
+    texts = texts.translate(str.maketrans("", "", string.punctuation))
+    texts = re.sub(r'\n',' ',texts)
+    return texts
+
+def datasetmaker (path=str):
+    data = pd.read_json(path, lines=True)
+    df = data.drop(['url','archive','title','date','compression','coverage','density','compression_bin','coverage_bin','density_bin'],axis=1)
+    tqdm.pandas()
+    df['text'] = df.text.apply(lambda texts : clean_data(texts))
+    df['summary'] = df.summary.apply(lambda summary : clean_data(summary))
+    # df['text'] = df['text'].map(str)
+    # df['summary'] = df['summary'].map(str)
+    dataset = Dataset.from_dict(df)
+    return dataset
+
+#voir si le model par hasard esr déjà bien
+
+# test_text = dataset['text'][0]
+# pipe = pipeline('summarization',model = model_ckpt)
+# pipe_out = pipe(test_text)
+# print (pipe_out[0]['summary_text'].replace('.<n>','.\n'))
+# print(dataset['summary'][0])
+
+def generate_batch_sized_chunks(list_elements, batch_size):
+    """split the dataset into smaller batches that we can process simultaneously
+    Yield successive batch-sized chunks from list_of_elements."""
+    for i in range(0, len(list_elements), batch_size):
+        yield list_elements[i : i + batch_size]
+
+def calculate_metric(dataset, metric, model, tokenizer,
+                     batch_size, device,
+                     column_text='text',
+                     column_summary='summary'):
+    article_batches = list(str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
+    target_batches = list(str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
+
+    for article_batch, target_batch in tqdm(
+        zip(article_batches, target_batches), total=len(article_batches)):
+
+        inputs = tokenizer(article_batch, max_length=1024, truncation=True,
+                           padding="max_length", return_tensors="pt")
+
+        summaries = model.generate(input_ids=inputs["input_ids"].to(device),
+                                   attention_mask=inputs["attention_mask"].to(device),
+                                   length_penalty=0.8, num_beams=8, max_length=128)
+        ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''
+
+        # Décode les textes
+        # renplacer les tokens, ajouter des textes décodés avec les rédéfences vers la métrique.
+        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
+                                              clean_up_tokenization_spaces=True)
+                             for s in summaries]
+
+        decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
+
+
+        metric.add_batch(predictions=decoded_summaries, references=target_batch)
+
+    #compute et return les ROUGE scores.
+    results = metric.compute()
+    rouge_names = ['rouge1','rouge2','rougeL','rougeLsum']
+    rouge_dict = dict((rn, results[rn] ) for rn in rouge_names )
+    return pd.DataFrame(rouge_dict, index = ['T5'])
+
+
+def convert_ex_to_features(example_batch):
+    input_encodings = tokenizer(example_batch['text'],max_length = 1024,truncation = True)
+
+    labels =tokenizer(example_batch['summary'], max_length = 128, truncation = True )
+
+    return {
+        'input_ids' : input_encodings['input_ids'],
+        'attention_mask': input_encodings['attention_mask'],
+        'labels': labels['input_ids']
+    }
+
+if __name__=='__main__':
+
+    train_dataset = datasetmaker('data/train_extract_100.jsonl')
+
+    dev_dataset = datasetmaker('data/dev_extract_100.jsonl')
+
+    test_dataset = datasetmaker('data/test_extract_100.jsonl')
+
+    dataset = datasets.DatasetDict({'train':train_dataset,'dev':dev_dataset ,'test':test_dataset})
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+    mt5_config = AutoConfig.from_pretrained(
+        "google/mt5-small",
+        max_length=128,
+        length_penalty=0.6,
+        no_repeat_ngram_size=2,
+        num_beams=15,
+    )
+    model = (AutoModelForSeq2SeqLM
+             .from_pretrained("google/mt5-small", config=mt5_config)
+             .to(device))
+
+    dataset_pt= dataset.map(convert_ex_to_features,remove_columns=["summary", "text"],batched = True,batch_size=128)
+
+    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model,return_tensors="pt")
+
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir = "mt5_sum",
+        log_level = "error",
+        num_train_epochs = 10,
+        learning_rate = 5e-4,
+        # lr_scheduler_type = "linear",
+        warmup_steps = 0,
+        optim = "adafactor",
+        weight_decay = 0.01,
+        per_device_train_batch_size = 2,
+        per_device_eval_batch_size = 1,
+        gradient_accumulation_steps = 16,
+        evaluation_strategy = "steps",
+        eval_steps = 100,
+        predict_with_generate=True,
+        generation_max_length = 128,
+        save_steps = 500,
+        logging_steps = 10,
+        # push_to_hub = True
+    )
+
+
+    trainer = Seq2SeqTrainer(
+        model = model,
+        args = training_args,
+        data_collator = data_collator,
+        # compute_metrics = calculate_metric,
+        train_dataset=dataset_pt['train'],
+        eval_dataset=dataset_pt['dev'].select(range(10)),
+        tokenizer = tokenizer,
+    )
+
+    trainer.train()
+    rouge_metric = evaluate.load("rouge")
+
+    score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
+                             batch_size=2, device=device,
+                             column_text='text',
+                             column_summary='summary')
+    print (score)
+
+
+    #Fine Tuning terminés et à sauvgarder
+
+
+    # save fine-tuned model in local
+    os.makedirs("./summarization_t5", exist_ok=True)
+    if hasattr(trainer.model, "module"):
+        trainer.model.module.save_pretrained("./summarization_t5")
+    else:
+        trainer.model.save_pretrained("./summarization_t5")
+    tokenizer.save_pretrained("./summarization_t5")
+    # load local model
+    model = (AutoModelForSeq2SeqLM
+             .from_pretrained("./summarization_t5")
+             .to(device))
+
+
+    # mettre en usage : TEST
+
+
+    # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
+    # sample_text = dataset["test"][0]["text"]
+    # reference = dataset["test"][0]["summary"]
+    # pipe = pipeline("summarization", model='./summarization_t5')
+
+    # print("Text:")
+    # print(sample_text)
+    # print("\nReference Summary:")
+    # print(reference)
+    # print("\nModel Summary:")
+    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
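A small usage sketch of the checkpoint this script writes to ./summarization_t5, based on the commented-out test block at the end of the file; the generation kwargs are the ones suggested there, and the directory only exists after a training run has completed.

# Sketch: reload the locally saved fine-tuned model and summarize one example,
# following the commented-out test at the end of fine_tune_t5.py.
from transformers import pipeline

gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
pipe = pipeline("summarization", model="./summarization_t5")  # written by save_pretrained above

sample_text = "Some long article text to summarize ..."
print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])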
src/inference.py
CHANGED
@@ -1,22 +1,16 @@
"""
Allows to predict the summary for a given entry text
"""
-import
-from nltk import word_tokenize
+import pickle

-import
-from model import Decoder, Encoder, EncoderDecoderModel
+import torch

-
-
-data1 = dataloader.Data("data/train_extract.jsonl")
-data2 = dataloader.Data("data/dev_extract.jsonl")
-train_dataset = data1.make_dataset()
-dev_dataset = data2.make_dataset()
-words = data1.get_words()
+from src import dataloader
+from src.model import Decoder, Encoder, EncoderDecoderModel

+with open("model/vocab.pkl", "rb") as vocab:
+    words = pickle.load(vocab)
vectoriser = dataloader.Vectoriser(words)
-word_counts = vectoriser.word_count


def inferenceAPI(text: str) -> str:
@@ -30,22 +24,20 @@ def inferenceAPI(text: str) -> str:
    str
        The summary for the input text
    """
-    text =
+    text = text.split()
    # On défini les paramètres d'entrée pour le modèle
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
-
-    )
-    decoder
-    device
-    )
+    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    encoder.to(device)
+    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    decoder.to(device)

    # On instancie le modèle
-    model = EncoderDecoderModel(encoder, decoder, device)
+    model = EncoderDecoderModel(encoder, decoder, vectoriser, device)

-    model.load_state_dict(torch.load("model/model.pt", map_location=device))
-    model.eval()
-    model.to(device)
+    # model.load_state_dict(torch.load("model/model.pt", map_location=device))
+    # model.eval()
+    # model.to(device)

    # On vectorise le texte
    source = vectoriser.encode(text)
@@ -55,6 +47,7 @@ def inferenceAPI(text: str) -> str:
    with torch.no_grad():
        output = model(source).to(device)
        output.to(device)
+        output = output.argmax(dim=-1)
    return vectoriser.decode(output)

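The FastAPI layer reaches this function through summarize(); as a minimal sketch of the same call outside the server, assuming the repository root is on PYTHONPATH and model/vocab.pkl is present:

# Sketch: direct use of the LSTM inference path, as src/api.py's summarize() does.
from src.inference import inferenceAPI

text = "Texte à résumer ..."
print(" ".join(inferenceAPI(text)))  # inferenceAPI returns the decoded token list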
src/inference_t5.py
CHANGED
@@ -27,6 +27,7 @@ def inferenceAPI(text: str) -> str:
    str
        The summary for the input text
    """
+
    # On défini les paramètres d'entrée pour le modèle
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,6 +36,7 @@ def inferenceAPI(text: str) -> str:
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("Linggg/t5_summary")
             .to(device))
+
    text_encoding = tokenizer(
        text,
        max_length=1024,
@@ -60,8 +62,6 @@ def inferenceAPI(text: str) -> str:
    return "".join(preds)


-if __name__ == "__main__":
-
-
-    text = input('Entrez votre phrase à résumer : ')
-    print('summary:', inferenceAPI(text))
+# if __name__ == "__main__":
+#     text = input('Entrez votre phrase à résumer : ')
+#     print('summary:', inferenceAPI(text))
src/model.py
CHANGED
@@ -6,14 +6,8 @@ import logging

import torch

-import dataloader
-
logging.basicConfig(level=logging.DEBUG)

-data1 = dataloader.Data("data/train_extract.jsonl")
-words = data1.get_words()
-vectoriser = dataloader.Vectoriser(words)
-

class Encoder(torch.nn.Module):
    def __init__(
@@ -86,51 +80,59 @@


class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, encoder, decoder, device):
+    def __init__(self, encoder, decoder, vectoriser, device):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
+        self.vectoriser = vectoriser
        self.device = device

-    def forward(self, source, num_beams=3):
-
-
-
-
-
-
-
-
-
-
-
+    def forward(self, source, num_beams=3, summary_len=0.2):
+        """
+        :param source: tensor
+            the input text
+        :param num_beams: int
+            the number of outputs to iterate on for beam_search
+        :param summary_len: int
+            length ratio of the summary compared to the text
+        """
+        # The ratio must be inferior to 1 to allow text compression
+        assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
+
+        target_len = int(
+            summary_len * source.shape[0]
+        )  # Expected summary length (in words)
+        target_vocab_size = self.decoder.vocab_size  # Word Embedding length
+
+        # Output of the right format (expected summmary length x word embedding length)
+        # filled with zeros. On each iteration, we will replace one of the row of this
+        # matrix with the choosen word embedding
+        outputs = torch.zeros(target_len, target_vocab_size)
+
+        # put the tensors on the device (useless if CPU bus very useful in case of GPU)
+        outputs.to(self.device)
+        source.to(self.device)

        # last hidden state of the encoder is used as the initial hidden state of the decoder
-
-
-
-
-
-
-            )
-            cell.to(
-
-
-
-        # first input to the decoder is the <start> token.
-        input = vectoriser.encode("<start>")  # Mot de départ du MOdèle
-        input.to(self.device)  # idiosyncrasie de torch pour mmettre sur GPU
-
-        ### DÉBUT DE L'INSTANCIATION TEST ###
+        hidden, cell = self.encoder(source)  # Encode the input text
+        input = self.vectoriser.encode(
+            "<start>"
+        )  # Encode the first word of the summary
+
+        # put the tensors on the device
+        hidden.to(self.device)
+        cell.to(self.device)
+        input.to(self.device)
+
+        ### BEAM SEARCH ###
        # If you wonder, b stands for better
        values = None
        b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
        b_outputs.to(self.device)

-        for i in range(
-
-        ):  # On va déterminer autant de mot que la taille du texte souhaité
+        for i in range(1, target_len):
+            # On va déterminer autant de mot que la taille du texte souhaité
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states.

src/script.py
DELETED
@@ -1,90 +0,0 @@
-"""
-DONE :
-    - Separer la partie vectoriser du Classifeur
-    - Ajouter un LSTM au Classifieur
-    - entrainer le Classifieur
-TO DO :
-    - Améliorer les résultats du modèle
-"""
-import logging
-import random
-from typing import Sequence
-
-import torch
-
-import dataloader
-from model import Decoder, Encoder, EncoderDecoderModel
-from train import train_network
-
-# logging INFO, WARNING, ERROR, CRITICAL, DEBUG
-logging.basicConfig(level=logging.INFO)
-logging.disable(level=10)
-
-import os
-
-os.environ[
-    "CUBLAS_WORKSPACE_CONFIG"
-] = ":16:8"  # pour que ça marche en deterministe sur mon pc boulot
-# variable environnement dans git bash export CUBLAS_WORKSPACE_CONFIG=:16:8
-# from datasets import load_dataset
-
-### OPEN DATASET###
-# dataset = load_dataset("newsroom", data_dir=DATA_PATH, data_files="data/train.jsonl")
-
-data1 = dataloader.Data("data/train_extract.jsonl")
-data2 = dataloader.Data("data/dev_extract.jsonl")
-train_dataset = data1.make_dataset()
-dev_dataset = data2.make_dataset()
-words = data1.get_words()
-
-vectoriser = dataloader.Vectoriser(words)
-word_counts = vectoriser.word_count
-
-
-def predict(model, tokens: Sequence[str]) -> Sequence[str]:
-    """Predict the POS for a tokenized sequence"""
-    words_idx = vectoriser.encode(tokens).to(device)
-    # Pas de calcul de gradient ici : c'est juste pour les prédictions
-    with torch.no_grad():
-        # equivalent to model(input) when called out of class
-        out = model(words_idx).to(device)
-    out_predictions = out.to(device)
-    return vectoriser.decode(out_predictions)
-
-
-if __name__ == "__main__":
-    ### NEURAL NETWORK ###
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print("Device check. You are using:", device)
-
-    ### RÉSEAU ENTRAÎNÉ ###
-    # Pour s'assurer que les résultats seront les mêmes à chaque run du notebook
-    torch.use_deterministic_algorithms(True)
-    torch.manual_seed(0)
-    random.seed(0)
-
-    # On peut également entraîner encoder séparemment
-    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
-    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
-    # S'ils sont entraînés, on peut les sauvegarder
-    torch.save(encoder.state_dict(), "model/encoder.pt")
-    torch.save(encoder.state_dict(), "model/encoder.pt")
-
-    trained_classifier = EncoderDecoderModel(encoder, decoder, device).to(device)
-
-    print(next(trained_classifier.parameters()).device)
-    # print(train_dataset.is_cuda)
-
-    train_network(
-        trained_classifier,
-        [vectoriser.vectorize(row) for index, row in train_dataset.iterrows()],
-        [vectoriser.vectorize(row) for index, row in dev_dataset.iterrows()],
-        5,
-    )
-
-    torch.save(trained_classifier.state_dict(), "model/model.pt")
-
-    print(f'test text : {dev_dataset.iloc[6]["summary"]}')
-    print(
-        f'test prediction : {predict(trained_classifier, dev_dataset.iloc[6]["text"])}'
-    )
src/train.py
CHANGED
@@ -3,21 +3,19 @@ Training the network
"""
import datetime
import logging
+import random
import time
from typing import Sequence, Tuple

import torch

import dataloader
+from model import Decoder, Encoder, EncoderDecoderModel

# logging INFO, WARNING, ERROR, CRITICAL, DEBUG
logging.basicConfig(level=logging.INFO)
logging.disable(level=10)

-data1 = dataloader.Data("data/train_extract.jsonl")
-words = data1.get_words()
-vectoriser = dataloader.Vectoriser(words)
-

def train_network(
    model: torch.nn.Module,
@@ -47,7 +45,6 @@ def train_network(
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print("Device check. You are using:", model.device)
-    model.train()

    # with torch.no_grad():

@@ -81,10 +78,12 @@ def train_network(

            out = model(source).to(device)
            logging.debug(f"outputs = {out.shape}")
+
            target = torch.nn.functional.pad(
                target, (0, len(out) - len(target)), value=-100
            )
-
+
+            # logging.debug(f"prediction : {vectoriser.decode(output_predictions)}")
            loss = torch.nn.functional.nll_loss(out, target).to(device)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
@@ -131,3 +130,73 @@ def train_network(
        print(
            f"{epoch_n}\t{epoch_loss/epoch_length:.5}\t{abs(dev_correct/dev_total):.2%}\t\t{datetime.timedelta(seconds=epoch_compute_time)}"
        )
+
+
+def predict(model, tokens: Sequence[str]) -> Sequence[str]:
+    """Predict the POS for a tokenized sequence"""
+    words_idx = vectoriser.encode(tokens).to(device)
+    # Pas de calcul de gradient ici : c'est juste pour les prédictions
+    with torch.no_grad():
+        # equivalent to model(input) when called out of class
+        out = model(words_idx).to(device)
+    out_predictions = out.to(device)
+    print(out_predictions)
+    out_predictions = out_predictions.argmax(dim=-1)
+    return vectoriser.decode(out_predictions)
+
+
+if __name__ == "__main__":
+    train_dataset = dataloader.Data("data/train_extract.jsonl")
+    words = train_dataset.get_words()
+    vectoriser = dataloader.Vectoriser(words)
+
+    train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
+    dev_dataset = dataloader.Data("data/dev_extract.jsonl", transform=vectoriser)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=2, shuffle=True, collate_fn=dataloader.pad_collate
+    )
+
+    dev_dataloader = torch.utils.data.DataLoader(
+        dev_dataset, batch_size=4, shuffle=True, collate_fn=dataloader.pad_collate
+    )
+
+    for i_batch, batch in enumerate(train_dataloader):
+        print(i_batch, batch[0], batch[1])
+
+    ### NEURAL NETWORK ###
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Device check. You are using:", device)
+
+    ### RÉSEAU ENTRAÎNÉ ###
+    # Pour s'assurer que les résultats seront les mêmes à chaque run du notebook
+    torch.use_deterministic_algorithms(True)
+    torch.manual_seed(0)
+    random.seed(0)
+
+    # On peut également entraîner encoder séparemment
+    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+
+    trained_classifier = EncoderDecoderModel(encoder, decoder, vectoriser, device).to(
+        device
+    )
+
+    print(next(trained_classifier.parameters()).device)
+    # print(train_dataset.is_cuda)
+
+    train_network(
+        trained_classifier,
+        train_dataset,
+        dev_dataset,
+        2,
+    )
+
+    torch.save(trained_classifier.state_dict(), "model/model.pt")
+    vectoriser.save("model/vocab.pkl")
+    trained_classifier.push_to_hub("SummaryProject-LSTM")
+
+    print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
+    print(
+        f"test prediction : {predict(trained_classifier, vectoriser.decode(dev_dataset[6][0]))}"
+    )
templates/index.html.jinja
CHANGED
@@ -4,7 +4,7 @@
		<title>Text summarization API</title>
		<meta charset="utf-8" />
		<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=no" />
-		<link rel="stylesheet" href="{{ url_for('templates', path='site_style/css/main.css') }}" />
+		<link rel="stylesheet" href="{{ url_for('templates', path='templates/site_style/css/main.css') }}" />
		<script>
			function customReset()
			{
@@ -13,6 +13,23 @@
				document.getElementById("summary").value = "";
			}
		</script>
+	<script>
+		function submitBothForms()
+		{
+			document.getElementById("my_form").submit();
+			document.getElementById("choixModel").submit();
+		}
+	</script>
+	<script>
+		function getValue() {
+			var e = document.getElementById("choixModel");
+			var value = e.value;
+			var text = e.options[e.selectedIndex].text;
+		return text}
+	</script>
+	<script type="text/javascript">
+		document.getElementById('choixModel').value = "<?php echo $_GET['choixModel'];?>";
+	</script>
	</head>
	<body>
		<div id="header">
@@ -28,18 +45,21 @@
			</nav>

		<div class="choixModel">
-			<
-
-			<
-
-
+			<form id="choixModel" method="post" action="/model">
+				<label for="selectModel">Choose a model :</label>
+				<select name="choixModel" class="selectModel" id="choixModel">
+					<option value="lstm">LSTM</option>
+					<option value="fineTunedT5">Fine-tuned T5</option>
+				</select>
+			</form>
+			<button form ="choixModel" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Select model</button>
		</div>

		<div>
			<table>
				<tr>
					<td>
-						<form id = "my_form" action="/" method="post" class="formulaire">
+						<form id = "my_form" action="/predict" method="post" class="formulaire">
							<textarea id="text" name="text" placeholder="Enter your text here!" rows="15" cols="75">{{text}}</textarea>
							<input type="hidden" name="textarea_value" value="{{ text }}">
						</form>
@@ -51,8 +71,9 @@
			</table>
		</div>
		<div class="buttons">
-
-
+			<!-- <button id="submit" type="submit" onclick=submitBothForms()>SUBMIT</button> -->
+			<button form ="my_form" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Go !</button>
+			<button form ="my_form" type="button" value="Reset" onclick="customReset();">Reset</button>
		</div>

		<div class="copyright">