EveSa committed on
Commit 4874293 · unverified · 2 Parent(s): 5925e5f dfd43d0

Merge branch 'main' into Ling

.gitignore CHANGED
@@ -1,6 +1,6 @@
 .venv/**
 data/**
 src/__pycache__
-model/**
+model/model.pt
 html5up-helios/**
 **/__pycache__/**

Dockerfile CHANGED
@@ -8,4 +8,6 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
 COPY . .
 
-CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
+
+# CMD python3 -m uvicorn --app-dir ./src api:app --host 0.0.0.0 --port 3001

README.md CHANGED
@@ -1,7 +1,7 @@
 ---
 title: SummaryProject
 sdk: docker
-app_file: app.py
+app_file: src/app.py
 pinned: false
 ---
 # Initialisation

api.py ADDED
@@ -0,0 +1,51 @@
+import uvicorn
+from fastapi import FastAPI, Form, Request
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+
+from inference import inferenceAPI
+
+
+# ------ MODEL ---------------------------------------------------------------
+# call the inference function, adapted for a text input
+def summarize(text: str):
+    return " ".join(inferenceAPI(text))
+
+
+# ----------------------------------------------------------------------------------
+
+
+# -------- API ---------------------------------------------------------------------
+app = FastAPI()
+
+# static files so the CSS is served to the browser
+templates = Jinja2Templates(directory="templates")
+app.mount("/", StaticFiles(directory="templates", html=True), name="templates")
+
+
+@app.get("/")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+
+# return the text and the predictions, or an error message if the form was sent empty
+@app.post("/")
+async def prediction(request: Request, text: str = Form(None)):
+    if not text:
+        error = "Merci de saisir votre texte."
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": error}
+        )
+    else:
+        summary = summarize(text)
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": text, "summary": summary}
+        )
+
+
+# ------------------------------------------------------------------------------------
+
+
+# start the server and reload it on every saved change
+# if __name__ == "__main__":
+#     uvicorn.run("api:app", port=8000, reload=True)
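As a quick sanity check, this form endpoint can be exercised with a plain POST once the app is serving. A minimal sketch, assuming the container is reachable on localhost:7860 (the port the Dockerfile passes to uvicorn) and that the `requests` library is available on the client side:

```python
import requests

API_URL = "http://localhost:7860/"  # assumed local deployment

# the route reads the "text" field of the submitted form (see @app.post("/"))
response = requests.post(API_URL, data={"text": "Texte à résumer ..."})
print(response.status_code)
print(response.text)  # the rendered index.html.jinja, with the summary filled in
```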
model/vocab.pkl ADDED
Binary file (63.4 kB).

requirements.txt CHANGED
@@ -6,7 +6,6 @@ anyascii==0.3.1
 anyio==3.6.2
 async-timeout==4.0.2
 attrs==22.2.0
-autopep8==2.0.2
 banal==1.0.6
 blis==0.7.9
 catalogue==2.0.8
@@ -33,6 +32,11 @@ fsspec==2023.3.0
 greenlet==2.0.2
 h11==0.14.0
 huggingface-hub==0.12.1
+certifi==2022.12.7
+charset-normalizer==3.1.0
+click==8.1.3
+fastapi==0.92.0
+filelock==3.9.0
 idna==3.4
 importlib-metadata==6.0.0
 importlib-resources==5.12.0
@@ -47,7 +51,6 @@ mccabe==0.7.0
 multidict==6.0.4
 multiprocess==0.70.14
 murmurhash==1.0.9
-nltk==3.8.1
 numpy==1.24.2
 nvidia-cublas-cu11==11.10.3.66
 nvidia-cuda-nvrtc-cu11==11.7.99
@@ -77,7 +80,6 @@ scikit-learn==1.2.1
 scipy==1.10.0
 sentencepiece==0.1.97
 six==1.16.0
-sklearn==0.0.post1
 smart-open==6.3.0
 sniffio==1.3.0
 spacy==3.5.0
@@ -98,6 +100,12 @@ transformers==4.26.1
 typer==0.7.0
 typing-extensions==4.4.0
 urllib3==1.26.14
+starlette==0.25.0
+tokenizers==0.13.2
+torch==1.13.1
+tqdm==4.65.0
+typing_extensions==4.5.0
+urllib3==1.26.15
 uvicorn==0.20.0
 wasabi==1.1.1
 xxhash==3.2.0

src/api.py CHANGED
@@ -2,17 +2,33 @@ import uvicorn
 from fastapi import FastAPI, Form, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
+import re
 
-from inference import inferenceAPI
+from src.inference import inferenceAPI
+from src.inference_t5 import inferenceAPI_t5
 
 
-# ------ MODEL ---------------------------------------------------------------
+# ------ INFERENCE MODEL --------------------------------------------------------------
 # call the inference function, adapted for a text input
 def summarize(text: str):
-    return " ".join(inferenceAPI(text))
+    if choisir_modele.var == "lstm":
+        return " ".join(inferenceAPI(text))
+    elif choisir_modele.var == "fineTunedT5":
+        return inferenceAPI_t5(text)
+
 # ----------------------------------------------------------------------------------
 
 
+def choisir_modele(choixModele):
+    print("ON A RECUP LE CHOIX MODELE")
+    if choixModele == "lstm":
+        choisir_modele.var = "lstm"
+    elif choixModele == "fineTunedT5":
+        choisir_modele.var = "fineTunedT5"
+    else:
+        print("le modele n'est pas defini")
+
+
 # -------- API ---------------------------------------------------------------------
 app = FastAPI()
 
@@ -20,26 +36,54 @@ app = FastAPI()
 templates = Jinja2Templates(directory="templates")
 app.mount("/templates", StaticFiles(directory="templates"), name="templates")
 
+
 @app.get("/")
 async def index(request: Request):
     return templates.TemplateResponse("index.html.jinja", {"request": request})
 
+
+@app.get("/model")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+
+@app.get("/predict")
+async def index(request: Request):
+    return templates.TemplateResponse("index.html.jinja", {"request": request})
+
+
+@app.post("/model")
+async def choix_model(request: Request, choixModel: str = Form(None)):
+    print(choixModel)
+    if not choixModel:
+        erreur_modele = "Merci de saisir un modèle."
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request, "text": erreur_modele}
+        )
+    else:
+        choisir_modele(choixModel)
+        print("C'est bon on utilise le modèle demandé")
+        return templates.TemplateResponse(
+            "index.html.jinja", {"request": request}
+        )
+
+
 # return the text and the predictions, or an error message if the form was sent empty
-@app.post("/")
+@app.post("/predict")
 async def prediction(request: Request, text: str = Form(None)):
-    if not text :
+    if not text:
         error = "Merci de saisir votre texte."
         return templates.TemplateResponse(
-            "index.html.jinja", {"request": request, "text": error}
-        )
-    else :
+            "index.html.jinja", {"request": request, "text": error}
+        )
+    else:
        summary = summarize(text)
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": text, "summary": summary}
        )
+
+
 # ------------------------------------------------------------------------------------
 
 
 # start the server and reload it on every saved change
-if __name__ == "__main__":
-    uvicorn.run("api:app", port=8000, reload=True)
+# if __name__ == "__main__":
+#     uvicorn.run("api:app", port=8000, reload=True)
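Note that the selected model is stored as an attribute on `choisir_modele` itself, i.e. process-wide state shared by every client of the app, and `/predict` called before `/model` would find `choisir_modele.var` unset and raise an AttributeError. A minimal sketch of the intended two-step flow, assuming a local deployment on localhost:7860 (hypothetical host and port):

```python
import requests

BASE = "http://localhost:7860"  # assumed local deployment

# 1. pick a model; the value must match an <option> of the form: "lstm" or "fineTunedT5"
requests.post(f"{BASE}/model", data={"choixModel": "lstm"})

# 2. ask for a summary; summarize() dispatches on the stored choice
r = requests.post(f"{BASE}/predict", data={"text": "Some long text to summarize."})
print(r.text)
```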
src/dataloader.py CHANGED
@@ -11,17 +11,15 @@
 Creating a Vectoriser from the vocabulary:
 
 """
+import pickle
 import string
 from collections import Counter
 
 import pandas as pd
 import torch
-from nltk import word_tokenize
 
-# nltk.download('punkt')
 
-
-class Data:
+class Data(torch.utils.data.Dataset):
     """
     A class used to get data from file
     ...
@@ -44,8 +42,27 @@ class Data:
         create a dataset with cleaned data
     """
 
-    def __init__(self, path: str) -> None:
+    def __init__(self, path: str, transform=None) -> None:
         self.path = path
+        self.data = pd.read_json(path_or_buf=self.path, lines=True)
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        row = self.data.iloc[idx]
+        text = row["text"].translate(str.maketrans("", "", string.punctuation)).split()
+        summary = (
+            row["summary"].translate(str.maketrans("", "", string.punctuation)).split()
+        )
+        summary = ["<start>", *summary, "<end>"]
+        sample = {"text": text, "summary": summary}
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        return sample
 
     def open(self) -> pd.DataFrame:
         """
@@ -85,26 +102,13 @@ class Data:
         # - handle proper nouns (consecutive words starting with a capital letter)
         for text in texts:
             text = text.translate(str.maketrans("", "", string.punctuation))
-            text = word_tokenize(text)
+            text = text.split()
             tokenized_texts.append(text)
 
         if text_type == "summary":
             return [["<start>", *summary, "<end>"] for summary in tokenized_texts]
         return tokenized_texts
 
-    def pad_sequence(self):
-        """
-        pad summary with empty token
-        """
-        texts = self.clean_data("text")
-        summaries = self.clean_data("summary")
-        padded_summary = []
-        for text, summary in zip(texts, summaries):
-            if len(summary) != len(text):
-                summary += ["<empty>"] * (len(text) - len(summary))
-            padded_summary.append(summary)
-        return texts, padded_summary
-
     def get_words(self) -> list:
         """
         Create a dictionary of the data vocabulary
@@ -114,15 +118,20 @@ class Data:
         summary_words = [word for text in summaries for word in text]
         return text_words + summary_words
 
-    def make_dataset(self) -> pd.DataFrame:
-        """
-        Create a Pandas Dataframe with cleaned data
-        --------------------
-        param: self: Data
-        return: pd.DataFrame
-        """
-        texts, summaries = self.clean_data("text"), self.clean_data("summary")
-        return pd.DataFrame(list(zip(texts, summaries)), columns=["text", "summary"])
+
+def pad_collate(data):
+    text_batch = [element[0] for element in data]
+    summary_batch = [element[1] for element in data]
+    max_len = max([len(element) for element in summary_batch + text_batch])
+    text_batch = [
+        torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
+        for element in text_batch
+    ]
+    summary_batch = [
+        torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
+        for element in summary_batch
+    ]
+    return text_batch, summary_batch
 
 
 class Vectoriser:
@@ -146,12 +155,25 @@ class Vectoriser:
         encode an entire row from the dataset
     """
 
-    def __init__(self, vocab) -> None:
+    def __init__(self, vocab=None) -> None:
         self.vocab = vocab
         self.word_count = Counter(word.lower().strip(",.\\-") for word in self.vocab)
         self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
         self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
 
+    def load(self, path):
+        with open(path, "rb") as file:
+            self.vocab = pickle.load(file)
+            self.word_count = Counter(
+                word.lower().strip(",.\\-") for word in self.vocab
+            )
+        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
+        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
+
+    def save(self, path):
+        with open(path, "wb") as file:
+            pickle.dump(self.vocab, file)
+
     def encode(self, tokens) -> torch.tensor:
         """
         Encode a sentence according to the words it contains
@@ -165,7 +187,7 @@ class Vectoriser:
         :return: words_idx : tensor
             A tensor containing the indexes of the words in the sentence
         """
-        if type(tokens) == list:
+        if isinstance(tokens, list):
             words_idx = torch.tensor(
                 [
                     self.token_to_idx.get(t.lower(), len(self.token_to_idx))
@@ -175,7 +197,7 @@ class Vectoriser:
             )
 
         # Allows encoding word by word
-        elif type(tokens) == str:
+        elif isinstance(tokens, str):
             words_idx = torch.tensor(self.token_to_idx.get(tokens.lower()))
 
         return words_idx
@@ -184,9 +206,9 @@ class Vectoriser:
         """
        Decode a sentence using the inverse process of the encode function
        """
-        words_idx_tensor = words_idx_tensor.argmax(dim=-1)
+
         idxs = words_idx_tensor.tolist()
-        if type(idxs) == int:
+        if isinstance(idxs, int):
             words = [self.idx_to_token[idxs]]
         else:
             words = []
@@ -195,10 +217,7 @@ class Vectoriser:
                 words.append(self.idx_to_token[idx])
             return words
 
-    def beam_search(self, words_idx_tensor) -> list:
-        pass
-
-    def vectorize(self, row) -> torch.tensor:
+    def __call__(self, row) -> torch.tensor:
         """
         Encode the data of one row of the dataframe
         ----------
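For orientation, a minimal sketch of how the reworked pieces are meant to compose (paths and batch size here are illustrative; src/train.py below does the same thing in full):

```python
import torch
from src import dataloader

# build the vocabulary once, then reuse the Vectoriser as the Dataset transform
raw = dataloader.Data("data/train_extract.jsonl")
vectoriser = dataloader.Vectoriser(raw.get_words())

train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=2, shuffle=True, collate_fn=dataloader.pad_collate
)

for texts, summaries in loader:
    # each batch is a pair of lists of index tensors, padded to a common length
    break
```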
src/fine_tune_t5.py ADDED
@@ -0,0 +1,204 @@
+import torch
+import datasets
+from datasets import Dataset, DatasetDict
+import pandas as pd
+from tqdm import tqdm
+import re
+import os
+import nltk
+import string
+import contractions
+from transformers import pipeline
+import evaluate
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
+from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
+from transformers import DataCollatorForSeq2Seq
+
+# avoid CUDA out-of-memory errors
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"
+
+nltk.download('stopwords')
+nltk.download('punkt')
+
+
+def clean_data(texts):
+    texts = texts.lower()
+    texts = contractions.fix(texts)
+    texts = texts.translate(str.maketrans("", "", string.punctuation))
+    texts = re.sub(r'\n', ' ', texts)
+    return texts
+
+
+def datasetmaker(path: str):
+    data = pd.read_json(path, lines=True)
+    df = data.drop(['url', 'archive', 'title', 'date', 'compression', 'coverage', 'density', 'compression_bin', 'coverage_bin', 'density_bin'], axis=1)
+    tqdm.pandas()
+    df['text'] = df.text.apply(lambda texts: clean_data(texts))
+    df['summary'] = df.summary.apply(lambda summary: clean_data(summary))
+    # df['text'] = df['text'].map(str)
+    # df['summary'] = df['summary'].map(str)
+    dataset = Dataset.from_dict(df)
+    return dataset
+
+
+# check whether the pretrained model already happens to do well
+# test_text = dataset['text'][0]
+# pipe = pipeline('summarization', model=model_ckpt)
+# pipe_out = pipe(test_text)
+# print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
+# print(dataset['summary'][0])
+
+
+def generate_batch_sized_chunks(list_elements, batch_size):
+    """split the dataset into smaller batches that we can process simultaneously
+    Yield successive batch-sized chunks from list_of_elements."""
+    for i in range(0, len(list_elements), batch_size):
+        yield list_elements[i : i + batch_size]
+
+
+def calculate_metric(dataset, metric, model, tokenizer,
+                     batch_size, device,
+                     column_text='text',
+                     column_summary='summary'):
+    article_batches = list(generate_batch_sized_chunks(dataset[column_text], batch_size))
+    target_batches = list(generate_batch_sized_chunks(dataset[column_summary], batch_size))
+
+    for article_batch, target_batch in tqdm(
+            zip(article_batches, target_batches), total=len(article_batches)):
+
+        inputs = tokenizer(article_batch, max_length=1024, truncation=True,
+                           padding="max_length", return_tensors="pt")
+
+        # the length penalty ensures that the model does not generate sequences that are too long
+        summaries = model.generate(input_ids=inputs["input_ids"].to(device),
+                                   attention_mask=inputs["attention_mask"].to(device),
+                                   length_penalty=0.8, num_beams=8, max_length=128)
+
+        # decode the texts: replace the tokens, then add the decoded texts
+        # together with the references to the metric
+        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
+                                              clean_up_tokenization_spaces=True)
+                             for s in summaries]
+
+        decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]
+
+        metric.add_batch(predictions=decoded_summaries, references=target_batch)
+
+    # compute and return the ROUGE scores
+    results = metric.compute()
+    rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
+    rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
+    return pd.DataFrame(rouge_dict, index=['T5'])
+
+
+def convert_ex_to_features(example_batch):
+    input_encodings = tokenizer(example_batch['text'], max_length=1024, truncation=True)
+
+    labels = tokenizer(example_batch['summary'], max_length=128, truncation=True)
+
+    return {
+        'input_ids': input_encodings['input_ids'],
+        'attention_mask': input_encodings['attention_mask'],
+        'labels': labels['input_ids'],
+    }
+
+
+if __name__ == '__main__':
+
+    train_dataset = datasetmaker('data/train_extract_100.jsonl')
+
+    dev_dataset = datasetmaker('data/dev_extract_100.jsonl')
+
+    test_dataset = datasetmaker('data/test_extract_100.jsonl')
+
+    dataset = datasets.DatasetDict({'train': train_dataset, 'dev': dev_dataset, 'test': test_dataset})
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
+    mt5_config = AutoConfig.from_pretrained(
+        "google/mt5-small",
+        max_length=128,
+        length_penalty=0.6,
+        no_repeat_ngram_size=2,
+        num_beams=15,
+    )
+    model = (AutoModelForSeq2SeqLM
+             .from_pretrained("google/mt5-small", config=mt5_config)
+             .to(device))
+
+    dataset_pt = dataset.map(convert_ex_to_features, remove_columns=["summary", "text"], batched=True, batch_size=128)
+
+    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
+
+    training_args = Seq2SeqTrainingArguments(
+        output_dir="mt5_sum",
+        log_level="error",
+        num_train_epochs=10,
+        learning_rate=5e-4,
+        # lr_scheduler_type="linear",
+        warmup_steps=0,
+        optim="adafactor",
+        weight_decay=0.01,
+        per_device_train_batch_size=2,
+        per_device_eval_batch_size=1,
+        gradient_accumulation_steps=16,
+        evaluation_strategy="steps",
+        eval_steps=100,
+        predict_with_generate=True,
+        generation_max_length=128,
+        save_steps=500,
+        logging_steps=10,
+        # push_to_hub=True
+    )
+
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=training_args,
+        data_collator=data_collator,
+        # compute_metrics=calculate_metric,
+        train_dataset=dataset_pt['train'],
+        eval_dataset=dataset_pt['dev'].select(range(10)),
+        tokenizer=tokenizer,
+    )
+
+    trainer.train()
+    rouge_metric = evaluate.load("rouge")
+
+    score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
+                             batch_size=2, device=device,
+                             column_text='text',
+                             column_summary='summary')
+    print(score)
+
+    # fine-tuning done; save the model
+
+    # save the fine-tuned model locally
+    os.makedirs("./summarization_t5", exist_ok=True)
+    if hasattr(trainer.model, "module"):
+        trainer.model.module.save_pretrained("./summarization_t5")
+    else:
+        trainer.model.save_pretrained("./summarization_t5")
+    tokenizer.save_pretrained("./summarization_t5")
+
+    # load the local model back
+    model = (AutoModelForSeq2SeqLM
+             .from_pretrained("./summarization_t5")
+             .to(device))
+
+    # put it to use: TEST
+    # gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
+    # sample_text = dataset["test"][0]["text"]
+    # reference = dataset["test"][0]["summary"]
+    # pipe = pipeline("summarization", model='./summarization_t5')
+
+    # print("Text:")
+    # print(sample_text)
+    # print("\nReference Summary:")
+    # print(reference)
+    # print("\nModel Summary:")
+    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
src/inference.py CHANGED
@@ -1,22 +1,16 @@
 """
 Allows to predict the summary for a given entry text
 """
-import torch
-from nltk import word_tokenize
+import pickle
 
-import dataloader
-from model import Decoder, Encoder, EncoderDecoderModel
+import torch
 
-# We have to load the data to get the Vectoriser > save "words" to a file and load it later??
-### TO CHANGE SO THAT ONLY THE VECTORISER NEEDS LOADING
-data1 = dataloader.Data("data/train_extract.jsonl")
-data2 = dataloader.Data("data/dev_extract.jsonl")
-train_dataset = data1.make_dataset()
-dev_dataset = data2.make_dataset()
-words = data1.get_words()
+from src import dataloader
+from src.model import Decoder, Encoder, EncoderDecoderModel
 
+with open("model/vocab.pkl", "rb") as vocab:
+    words = pickle.load(vocab)
 vectoriser = dataloader.Vectoriser(words)
-word_counts = vectoriser.word_count
 
 
 def inferenceAPI(text: str) -> str:
@@ -30,22 +24,20 @@ def inferenceAPI(text: str) -> str:
     str
         The summary for the input text
     """
-    text = word_tokenize(text)
+    text = text.split()
     # define the input parameters for the model
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device).to(
-        device
-    )
-    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device).to(
-        device
-    )
+    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    encoder.to(device)
+    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    decoder.to(device)
 
     # instantiate the model
-    model = EncoderDecoderModel(encoder, decoder, device)
+    model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
 
-    model.load_state_dict(torch.load("model/model.pt", map_location=device))
-    model.eval()
-    model.to(device)
+    # model.load_state_dict(torch.load("model/model.pt", map_location=device))
+    # model.eval()
+    # model.to(device)
 
     # vectorize the text
     source = vectoriser.encode(text)
@@ -55,6 +47,7 @@ def inferenceAPI(text: str) -> str:
     with torch.no_grad():
         output = model(source).to(device)
         output.to(device)
+        output = output.argmax(dim=-1)
     return vectoriser.decode(output)
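A minimal usage sketch, assuming model/vocab.pkl is present (note that with the load_state_dict call commented out, the summary comes from untrained weights):

```python
from src.inference import inferenceAPI

# importing the module already loads model/vocab.pkl and builds the Vectoriser
print(" ".join(inferenceAPI("Le texte à résumer ...")))
```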
src/inference_t5.py CHANGED
@@ -27,6 +27,7 @@ def inferenceAPI(text: str) -> str:
     str
         The summary for the input text
     """
+
     # define the input parameters for the model
     text = clean_text(text)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,6 +36,7 @@
     model = (AutoModelForSeq2SeqLM
              .from_pretrained("Linggg/t5_summary")
              .to(device))
+
     text_encoding = tokenizer(
         text,
         max_length=1024,
@@ -60,8 +62,6 @@
     return "".join(preds)
 
 
-if __name__ == "__main__":
-    '''
-    '''
-    text = input('Entrez votre phrase à résumer : ')
-    print('summary:', inferenceAPI(text))
+# if __name__ == "__main__":
+#     text = input('Entrez votre phrase à résumer : ')
+#     print('summary:', inferenceAPI(text))
src/model.py CHANGED
@@ -6,14 +6,8 @@ import logging
 
 import torch
 
-import dataloader
-
 logging.basicConfig(level=logging.DEBUG)
 
-data1 = dataloader.Data("data/train_extract.jsonl")
-words = data1.get_words()
-vectoriser = dataloader.Vectoriser(words)
-
 
 class Encoder(torch.nn.Module):
     def __init__(
@@ -86,51 +80,59 @@ class Decoder(torch.nn.Module):
 
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, encoder, decoder, device):
+    def __init__(self, encoder, decoder, vectoriser, device):
         # A torch idiosyncrasy, so it can work its magic
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
+        self.vectoriser = vectoriser
         self.device = device
 
-    def forward(self, source, num_beams=3):
-        # CHANGE THE TARGET LEN TO SOMETHING ADJUSTABLE
-        target_len = int(1 * source.shape[0])  # Desired length of the output text
-        target_vocab_size = self.decoder.vocab_size  # Word embedding length
-
-        # tensor to store decoder outputs
-        outputs = torch.zeros(target_len, target_vocab_size).to(
-            self.device
-        )  # Instantiation of a zero matrix of size (text length, word length)
-        outputs.to(
-            self.device
-        )  # A torch idiosyncrasy to put the tensor on the GPU
+    def forward(self, source, num_beams=3, summary_len=0.2):
+        """
+        :param source: tensor
+            the input text
+        :param num_beams: int
+            the number of outputs to iterate on for beam_search
+        :param summary_len: int
+            length ratio of the summary compared to the text
+        """
+        # The ratio must be inferior to 1 to allow text compression
+        assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
+
+        target_len = int(
+            summary_len * source.shape[0]
+        )  # Expected summary length (in words)
+        target_vocab_size = self.decoder.vocab_size  # Word embedding length
+
+        # Output of the right format (expected summary length x word embedding length)
+        # filled with zeros. On each iteration, we will replace one of the rows of this
+        # matrix with the chosen word embedding
+        outputs = torch.zeros(target_len, target_vocab_size)
+
+        # put the tensors on the device (useless on CPU but very useful in case of GPU)
+        outputs.to(self.device)
+        source.to(self.device)
 
         # last hidden state of the encoder is used as the initial hidden state of the decoder
-        source.to(
-            self.device
-        )  # A torch idiosyncrasy to put the tensor on the GPU
-        hidden, cell = self.encoder(source)  # Encode the text as a vector
-        hidden.to(
-            self.device
-        )  # A torch idiosyncrasy to put the tensor on the GPU
-        cell.to(
-            self.device
-        )  # A torch idiosyncrasy to put the tensor on the GPU
-
-        # first input to the decoder is the <start> token.
-        input = vectoriser.encode("<start>")  # Starting word of the model
-        input.to(self.device)  # torch idiosyncrasy to put it on the GPU
-
-        ### START OF THE TEST INSTANTIATION ###
+        hidden, cell = self.encoder(source)  # Encode the input text
+        input = self.vectoriser.encode(
+            "<start>"
+        )  # Encode the first word of the summary
+
+        # put the tensors on the device
+        hidden.to(self.device)
+        cell.to(self.device)
+        input.to(self.device)
+
+        ### BEAM SEARCH ###
         # If you wonder, b stands for better
         values = None
         b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
         b_outputs.to(self.device)
 
-        for i in range(
-            1, target_len
-        ):  # We will generate as many words as the desired text length
+        for i in range(1, target_len):
+            # We will generate as many words as the desired text length
             # insert input token embedding, previous hidden and previous cell states
             # receive output tensor (predictions) and new hidden and cell states.
src/script.py DELETED
@@ -1,90 +0,0 @@
-"""
-DONE :
-    - Separate the vectoriser part from the Classifier
-    - Add an LSTM to the Classifier
-    - Train the Classifier
-TO DO :
-    - Improve the model's results
-"""
-import logging
-import random
-from typing import Sequence
-
-import torch
-
-import dataloader
-from model import Decoder, Encoder, EncoderDecoderModel
-from train import train_network
-
-# logging INFO, WARNING, ERROR, CRITICAL, DEBUG
-logging.basicConfig(level=logging.INFO)
-logging.disable(level=10)
-
-import os
-
-os.environ[
-    "CUBLAS_WORKSPACE_CONFIG"
-] = ":16:8"  # so that it runs deterministically on my work machine
-# environment variable in git bash: export CUBLAS_WORKSPACE_CONFIG=:16:8
-# from datasets import load_dataset
-
-### OPEN DATASET ###
-# dataset = load_dataset("newsroom", data_dir=DATA_PATH, data_files="data/train.jsonl")
-
-data1 = dataloader.Data("data/train_extract.jsonl")
-data2 = dataloader.Data("data/dev_extract.jsonl")
-train_dataset = data1.make_dataset()
-dev_dataset = data2.make_dataset()
-words = data1.get_words()
-
-vectoriser = dataloader.Vectoriser(words)
-word_counts = vectoriser.word_count
-
-
-def predict(model, tokens: Sequence[str]) -> Sequence[str]:
-    """Predict the POS for a tokenized sequence"""
-    words_idx = vectoriser.encode(tokens).to(device)
-    # No gradient computation here: it is just for predictions
-    with torch.no_grad():
-        # equivalent to model(input) when called out of class
-        out = model(words_idx).to(device)
-    out_predictions = out.to(device)
-    return vectoriser.decode(out_predictions)
-
-
-if __name__ == "__main__":
-    ### NEURAL NETWORK ###
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print("Device check. You are using:", device)
-
-    ### TRAINED NETWORK ###
-    # To make sure the results are the same on every run of the notebook
-    torch.use_deterministic_algorithms(True)
-    torch.manual_seed(0)
-    random.seed(0)
-
-    # The encoder can also be trained separately
-    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
-    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
-    # Once trained, they can be saved
-    torch.save(encoder.state_dict(), "model/encoder.pt")
-    torch.save(encoder.state_dict(), "model/encoder.pt")
-
-    trained_classifier = EncoderDecoderModel(encoder, decoder, device).to(device)
-
-    print(next(trained_classifier.parameters()).device)
-    # print(train_dataset.is_cuda)
-
-    train_network(
-        trained_classifier,
-        [vectoriser.vectorize(row) for index, row in train_dataset.iterrows()],
-        [vectoriser.vectorize(row) for index, row in dev_dataset.iterrows()],
-        5,
-    )
-
-    torch.save(trained_classifier.state_dict(), "model/model.pt")
-
-    print(f'test text : {dev_dataset.iloc[6]["summary"]}')
-    print(
-        f'test prediction : {predict(trained_classifier, dev_dataset.iloc[6]["text"])}'
-    )
src/train.py CHANGED
@@ -3,21 +3,19 @@ Training the network
 """
 import datetime
 import logging
+import random
 import time
 from typing import Sequence, Tuple
 
 import torch
 
 import dataloader
+from model import Decoder, Encoder, EncoderDecoderModel
 
 # logging INFO, WARNING, ERROR, CRITICAL, DEBUG
 logging.basicConfig(level=logging.INFO)
 logging.disable(level=10)
 
-data1 = dataloader.Data("data/train_extract.jsonl")
-words = data1.get_words()
-vectoriser = dataloader.Vectoriser(words)
-
 
 def train_network(
     model: torch.nn.Module,
@@ -47,7 +45,6 @@ def train_network(
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = model.to(device)
     print("Device check. You are using:", model.device)
-    model.train()
 
     # with torch.no_grad():
 
@@ -81,10 +78,12 @@ def train_network(
 
             out = model(source).to(device)
             logging.debug(f"outputs = {out.shape}")
+
             target = torch.nn.functional.pad(
                 target, (0, len(out) - len(target)), value=-100
             )
-            # logging.debug(f"predition : {vectoriser.decode(output_predictions)}")
+
+            # logging.debug(f"prediction : {vectoriser.decode(output_predictions)}")
             loss = torch.nn.functional.nll_loss(out, target).to(device)
             loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
@@ -131,3 +130,73 @@ def train_network(
         print(
             f"{epoch_n}\t{epoch_loss/epoch_length:.5}\t{abs(dev_correct/dev_total):.2%}\t\t{datetime.timedelta(seconds=epoch_compute_time)}"
         )
+
+
+def predict(model, tokens: Sequence[str]) -> Sequence[str]:
+    """Predict the POS for a tokenized sequence"""
+    words_idx = vectoriser.encode(tokens).to(device)
+    # No gradient computation here: it is just for predictions
+    with torch.no_grad():
+        # equivalent to model(input) when called out of class
+        out = model(words_idx).to(device)
+    out_predictions = out.to(device)
+    print(out_predictions)
+    out_predictions = out_predictions.argmax(dim=-1)
+    return vectoriser.decode(out_predictions)
+
+
+if __name__ == "__main__":
+    train_dataset = dataloader.Data("data/train_extract.jsonl")
+    words = train_dataset.get_words()
+    vectoriser = dataloader.Vectoriser(words)
+
+    train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
+    dev_dataset = dataloader.Data("data/dev_extract.jsonl", transform=vectoriser)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=2, shuffle=True, collate_fn=dataloader.pad_collate
+    )
+
+    dev_dataloader = torch.utils.data.DataLoader(
+        dev_dataset, batch_size=4, shuffle=True, collate_fn=dataloader.pad_collate
+    )
+
+    for i_batch, batch in enumerate(train_dataloader):
+        print(i_batch, batch[0], batch[1])
+
+    ### NEURAL NETWORK ###
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Device check. You are using:", device)
+
+    ### TRAINED NETWORK ###
+    # To make sure the results are the same on every run of the notebook
+    torch.use_deterministic_algorithms(True)
+    torch.manual_seed(0)
+    random.seed(0)
+
+    # The encoder can also be trained separately
+    encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+    decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
+
+    trained_classifier = EncoderDecoderModel(encoder, decoder, vectoriser, device).to(
+        device
+    )
+
+    print(next(trained_classifier.parameters()).device)
+    # print(train_dataset.is_cuda)
+
+    train_network(
+        trained_classifier,
+        train_dataset,
+        dev_dataset,
+        2,
+    )
+
+    torch.save(trained_classifier.state_dict(), "model/model.pt")
+    vectoriser.save("model/vocab.pkl")
+    trained_classifier.push_to_hub("SummaryProject-LSTM")
+
+    print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
+    print(
+        f"test prediction : {predict(trained_classifier, vectoriser.decode(dev_dataset[6][0]))}"
+    )
templates/index.html.jinja CHANGED
@@ -4,7 +4,7 @@
   <title>Text summarization API</title>
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=no" />
-  <link rel="stylesheet" href="{{ url_for('templates', path='site_style/css/main.css') }}" />
+  <link rel="stylesheet" href="{{ url_for('templates', path='templates/site_style/css/main.css') }}" />
   <script>
     function customReset()
     {
@@ -13,6 +13,23 @@
       document.getElementById("summary").value = "";
     }
   </script>
+  <script>
+    function submitBothForms()
+    {
+      document.getElementById("my_form").submit();
+      document.getElementById("choixModel").submit();
+    }
+  </script>
+  <script>
+    function getValue() {
+      var e = document.getElementById("choixModel");
+      var value = e.value;
+      var text = e.options[e.selectedIndex].text;
+      return text;
+    }
+  </script>
+  <script type="text/javascript">
+    document.getElementById('choixModel').value = "<?php echo $_GET['choixModel'];?>";
+  </script>
 </head>
 <body>
   <div id="header">
@@ -28,18 +45,21 @@
   </nav>
 
   <div class="choixModel">
-    <label for="model-select">Choose a model :</label>
-    <select name="model" id="model-select">
-      <option value="lstm">LSTM</option>
-      <option value="autre">Autre</option>
-    </select>
+    <form id="choixModel" method="post" action="/model">
+      <label for="selectModel">Choose a model :</label>
+      <select name="choixModel" class="selectModel" id="choixModel">
+        <option value="lstm">LSTM</option>
+        <option value="fineTunedT5">Fine-tuned T5</option>
+      </select>
+    </form>
+    <button form="choixModel" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Select model</button>
   </div>
 
   <div>
     <table>
      <tr>
        <td>
-          <form id = "my_form" action="/" method="post" class="formulaire">
+          <form id = "my_form" action="/predict" method="post" class="formulaire">
            <textarea id="text" name="text" placeholder="Enter your text here!" rows="15" cols="75">{{text}}</textarea>
            <input type="hidden" name="textarea_value" value="{{ text }}">
          </form>
@@ -51,8 +71,9 @@
    </table>
  </div>
  <div class="buttons">
-    <button form ="my_form" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Go !</button>
-    <button form ="my_form" type="button" value="Reset" onclick="customReset();">Reset</button>
+    <!-- <button id="submit" type="submit" onclick=submitBothForms()>SUBMIT</button> -->
+    <button form ="my_form" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Go !</button>
+    <button form ="my_form" type="button" value="Reset" onclick="customReset();">Reset</button>
  </div>
 
  <div class="copyright">