Linggg committed
Commit 035b2cc · 2 Parent(s): 4b472fe 4e410f4

Merge branch 'Ling' of https://github.com/EveSa/SummaryProject into Ling

api.py DELETED
@@ -1,51 +0,0 @@
-import uvicorn
-from fastapi import FastAPI, Form, Request
-from fastapi.staticfiles import StaticFiles
-from fastapi.templating import Jinja2Templates
-
-from inference import inferenceAPI
-
-
-# ------ MODEL ---------------------------------------------------------------
-# call the inference function, adapted for a text input
-def summarize(text: str):
-    return " ".join(inferenceAPI(text))
-
-
-# ----------------------------------------------------------------------------------
-
-
-# -------- API ---------------------------------------------------------------------
-app = FastAPI()
-
-# static files so the CSS is served to the browser
-templates = Jinja2Templates(directory="templates")
-app.mount("/", StaticFiles(directory="templates", html=True), name="templates")
-
-
-@app.get("/")
-async def index(request: Request):
-    return templates.TemplateResponse("index.html.jinja", {"request": request})
-
-
-# return the text, the predictions, and an error message if the form is submitted empty
-@app.post("/")
-async def prediction(request: Request, text: str = Form(None)):
-    if not text:
-        error = "Merci de saisir votre texte."
-        return templates.TemplateResponse(
-            "index.html.jinja", {"request": request, "text": error}
-        )
-    else:
-        summary = summarize(text)
-        return templates.TemplateResponse(
-            "index.html.jinja", {"request": request, "text": text, "summary": summary}
-        )
-
-
-# ------------------------------------------------------------------------------------
-
-
-# launch the server and reload it on every saved change
-# if __name__ == "__main__":
-#     uvicorn.run("api:app", port=8000, reload=True)
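For context, the deleted module exposed a single summarization form at the root route. A minimal sketch of how that endpoint could have been exercised, assuming the app was served locally on port 8000 as in the commented-out uvicorn call above (illustrative only, not part of the commit):

import requests

# form-encoded POST, matching `text: str = Form(None)` in the deleted handler
response = requests.post("http://localhost:8000/", data={"text": "Un long texte à résumer."})
print(response.status_code)   # 200 if the template rendered
print(response.text[:200])    # start of the rendered index.html.jinja page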
requirements.txt CHANGED
@@ -1,56 +1,16 @@
-absl-py==1.4.0
-aiohttp==3.8.4
-aiosignal==1.3.1
-alembic==1.9.4
 anyascii==0.3.1
 anyio==3.6.2
-async-timeout==4.0.2
-attrs==22.2.0
-banal==1.0.6
-blis==0.7.9
-catalogue==2.0.8
 certifi==2022.12.7
-charset-normalizer==3.0.1
+charset-normalizer==3.1.0
 click==8.1.3
-confection==0.0.4
-contourpy==1.0.7
 contractions==0.1.73
-cycler==0.11.0
-cymem==2.0.7
-dataloader==2.0
-dataset==1.6.0
-datasets==2.10.1
-dill==0.3.6
-en-core-web-lg==3.5.0
-evaluate==0.4.0
-fastapi==0.91.0
+fastapi==0.94.0
 filelock==3.9.0
-flake8==6.0.0
-fonttools==4.38.0
-frozenlist==1.3.3
-fsspec==2023.3.0
-greenlet==2.0.2
 h11==0.14.0
-huggingface-hub==0.12.1
-certifi==2022.12.7
-charset-normalizer==3.1.0
-click==8.1.3
-fastapi==0.92.0
-filelock==3.9.0
+huggingface-hub==0.13.2
 idna==3.4
-importlib-metadata==6.0.0
-importlib-resources==5.12.0
 Jinja2==3.1.2
-joblib==1.2.0
-kiwisolver==1.4.4
-langcodes==3.3.0
-Mako==1.2.4
 MarkupSafe==2.1.2
-matplotlib==3.7.0
-mccabe==0.7.0
-multidict==6.0.4
-multiprocess==0.70.14
-murmurhash==1.0.9
 numpy==1.24.2
 nvidia-cublas-cu11==11.10.3.66
 nvidia-cuda-nvrtc-cu11==11.7.99
@@ -58,56 +18,22 @@ nvidia-cuda-runtime-cu11==11.7.99
 nvidia-cudnn-cu11==8.5.0.96
 packaging==23.0
 pandas==1.5.3
-pathy==0.10.1
-Pillow==9.4.0
-preshed==3.0.8
-protobuf==3.20.0
 pyahocorasick==2.0.0
-pyarrow==11.0.0
-pycodestyle==2.10.0
-pydantic==1.10.4
-pyflakes==3.0.1
-pyparsing==3.0.9
+pydantic==1.10.6
 python-dateutil==2.8.2
-python-multipart==0.0.5
+python-multipart==0.0.6
 pytz==2022.7.1
 PyYAML==6.0
 regex==2022.10.31
 requests==2.28.2
-responses==0.18.0
-rouge-score==0.1.2
-scikit-learn==1.2.1
-scipy==1.10.0
-sentencepiece==0.1.97
 six==1.16.0
-smart-open==6.3.0
 sniffio==1.3.0
-spacy==3.5.0
-spacy-legacy==3.0.12
-spacy-loggers==1.0.4
-SQLAlchemy==1.4.46
-srsly==2.4.5
-starlette==0.24.0
-summarizer==0.0.7
+starlette==0.26.1
 textsearch==0.0.24
-thinc==8.1.7
-threadpoolctl==3.1.0
-tokenizers==0.13.2
-tomli==2.0.1
-torch==1.13.1
-tqdm==4.64.1
-transformers==4.26.1
-typer==0.7.0
-typing-extensions==4.4.0
-urllib3==1.26.14
-starlette==0.25.0
 tokenizers==0.13.2
 torch==1.13.1
 tqdm==4.65.0
+transformers==4.26.1
 typing_extensions==4.5.0
 urllib3==1.26.15
-uvicorn==0.20.0
-wasabi==1.1.1
-xxhash==3.2.0
-yarl==1.8.2
-zipp==3.14.0
+uvicorn==0.21.0
 
 
 
src/api.py CHANGED
@@ -1,31 +1,30 @@
-import uvicorn
 from fastapi import FastAPI, Form, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
-import re
 
-from src.inference import inferenceAPI
-from src.inference_t5 import inferenceAPI_t5
+from src.inference_lstm import inference_lstm
+from src.inference_t5 import inference_t5
 
 
 # ------ INFERENCE MODEL --------------------------------------------------------------
 # call the inference function, adapted for a text input
 def summarize(text: str):
-    if choisir_modele.var == 'lstm' :
-        return " ".join(inferenceAPI(text))
+    if choisir_modele.var == "lstm":
+        return " ".join(inference_lstm(text))
     elif choisir_modele.var == "fineTunedT5":
-        text = inferenceAPI_t5(text)
+        text = inference_t5(text)
+
 
 # ----------------------------------------------------------------------------------
 
 
 def choisir_modele(choixModele):
     print("ON A RECUP LE CHOIX MODELE")
-    if choixModele == "lstm" :
-        choisir_modele.var ='lstm'
+    if choixModele == "lstm":
+        choisir_modele.var = "lstm"
     elif choixModele == "fineTunedT5":
         choisir_modele.var = "fineTunedT5"
-    else :
+    else:
         "le modele n'est pas defini"
 
 
@@ -41,29 +40,29 @@ app.mount("/templates", StaticFiles(directory="templates"), name="templates")
 async def index(request: Request):
     return templates.TemplateResponse("index.html.jinja", {"request": request})
 
+
 @app.get("/model")
 async def index(request: Request):
     return templates.TemplateResponse("index.html.jinja", {"request": request})
 
+
 @app.get("/predict")
 async def index(request: Request):
     return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 @app.post("/model")
-async def choix_model(request: Request, choixModel:str = Form(None)):
+async def choix_model(request: Request, choixModel: str = Form(None)):
     print(choixModel)
     if not choixModel:
         erreur_modele = "Merci de saisir un modèle."
         return templates.TemplateResponse(
-            "index.html.jinja", {"request": request, "text": erreur_modele}
+            "index.html.jinja", {"request": request, "text": erreur_modele}
         )
-    else :
+    else:
         choisir_modele(choixModel)
         print("C'est bon on utilise le modèle demandé")
-        return templates.TemplateResponse(
-            "index.html.jinja", {"request": request}
-        )
+        return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 # return the text, the predictions, and an error message if the form is submitted empty
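The reworked app now selects the inference backend through a separate form. A minimal sketch of the model-selection call, assuming the app is served locally with uvicorn on port 8000 (the choixModel field name comes from the route above; the summary form handler itself is not shown in this hunk):

import requests

BASE_URL = "http://localhost:8000"  # assumed local dev server

# choose the fine-tuned T5 backend; "lstm" is the other value handled by choisir_modele
requests.post(f"{BASE_URL}/model", data={"choixModel": "fineTunedT5"})

# the GET routes ("/", "/model", "/predict") all render index.html.jinja
print(requests.get(f"{BASE_URL}/predict").status_code)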
src/dataloader.py CHANGED
@@ -52,10 +52,15 @@ class Data(torch.utils.data.Dataset):
 
     def __getitem__(self, idx):
         row = self.data.iloc[idx]
-        text = row["text"].translate(str.maketrans("", "", string.punctuation)).split()
+        text = row["text"].translate(
+            str.maketrans(
+                "", "", string.punctuation)).split()
         summary = (
-            row["summary"].translate(str.maketrans("", "", string.punctuation)).split()
-        )
+            row["summary"].translate(
+                str.maketrans(
+                    "",
+                    "",
+                    string.punctuation)).split())
         summary = ["<start>", *summary, "<end>"]
         sample = {"text": text, "summary": summary}
 
@@ -106,7 +111,8 @@ class Data(torch.utils.data.Dataset):
             tokenized_texts.append(text)
 
         if text_type == "summary":
-            return [["<start>", *summary, "<end>"] for summary in tokenized_texts]
+            return [["<start>", *summary, "<end>"]
+                    for summary in tokenized_texts]
         return tokenized_texts
 
     def get_words(self) -> list:
@@ -157,8 +163,10 @@ class Vectoriser:
 
     def __init__(self, vocab=None) -> None:
        self.vocab = vocab
-        self.word_count = Counter(word.lower().strip(",.\\-") for word in self.vocab)
-        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
+        self.word_count = Counter(word.lower().strip(",.\\-")
+                                  for word in self.vocab)
+        self.idx_to_token = sorted(
+            [t for t, c in self.word_count.items() if c > 1])
         self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
 
     def load(self, path):
@@ -167,7 +175,8 @@ class Vectoriser:
         self.word_count = Counter(
             word.lower().strip(",.\\-") for word in self.vocab
         )
-        self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
+        self.idx_to_token = sorted(
+            [t for t, c in self.word_count.items() if c > 1])
         self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
 
     def save(self, path):
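A small standalone illustration of the vocabulary-building expressions reflowed in the Vectoriser hunks above (same logic, toy input; not the project's full class):

from collections import Counter

vocab = ["The", "cat,", "the", "dog", "the", "dog."]

word_count = Counter(word.lower().strip(",.\\-") for word in vocab)
idx_to_token = sorted([t for t, c in word_count.items() if c > 1])
token_to_idx = {t: i for i, t in enumerate(idx_to_token)}

print(word_count)    # Counter({'the': 3, 'dog': 2, 'cat': 1})
print(idx_to_token)  # ['dog', 'the'] -- only tokens seen more than once are kept
print(token_to_idx)  # {'dog': 0, 'the': 1}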
src/fine_tune_T5.py CHANGED
@@ -146,7 +146,11 @@ if __name__ == '__main__':
     train_dataset = datasetmaker('data/train_extract.jsonl')
 
 
+<<<<<<< HEAD
     dev_dataset = datasetmaker("data/dev_extract.jsonl")
+=======
+    test_dataset = datasetmaker("data/test_extract.jsonl")
+>>>>>>> 4e410f4bdcd6de645d9e73bb207d8a9170dfc3e1
 
     test_dataset = datasetmaker('data/test_extract.jsonl')
 
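Note that the hunk above commits git conflict markers into the script. A minimal sketch of what the resolved block could look like, assuming the intent is to keep the dev and test splits alongside the train split (this resolution is not part of the commit):

    train_dataset = datasetmaker("data/train_extract.jsonl")

    dev_dataset = datasetmaker("data/dev_extract.jsonl")

    test_dataset = datasetmaker("data/test_extract.jsonl")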
src/{inference.py → inference_lstm.py} RENAMED
@@ -1,5 +1,6 @@
 """
 Allows to predict the summary for a given entry text
+using LSTM model
 """
 import pickle
 
@@ -7,13 +8,14 @@ import torch
 
 from src import dataloader
 from src.model import Decoder, Encoder, EncoderDecoderModel
+# from transformers import AutoModel
 
 with open("model/vocab.pkl", "rb") as vocab:
     words = pickle.load(vocab)
     vectoriser = dataloader.Vectoriser(words)
 
 
-def inferenceAPI(text: str) -> str:
+def inference_lstm(text: str) -> str:
     """
     Predict the summary for an input text
     --------
@@ -34,6 +36,7 @@ def inferenceAPI(text: str) -> str:
 
     # Instantiate the model
     model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
+    # model = AutoModel.from_pretrained("EveSa/SummaryProject-LSTM")
 
     # model.load_state_dict(torch.load("model/model.pt", map_location=device))
     # model.eval()
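A minimal usage sketch of the renamed entry point, assuming the pickled vocabulary (model/vocab.pkl) and the LSTM weights referenced above are available locally (illustrative only):

from src.inference_lstm import inference_lstm

article = "Texte long à résumer ..."
# src/api.py joins the returned tokens into one string
print(" ".join(inference_lstm(article)))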
src/inference_t5.py CHANGED
@@ -1,20 +1,23 @@
 """
 Allows to predict the summary for a given entry text
 """
-import torch
 import re
 import string
+
+import contractions
+import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 def clean_text(texts: str) -> str:
     texts = texts.lower()
     texts = texts.translate(str.maketrans("", "", string.punctuation))
-    texts = re.sub(r'\n', ' ', texts)
+    texts = re.sub(r"\n", " ", texts)
     return texts
 
 
-def inferenceAPI_T5(text: str) -> str:
+
+def inference_t5(text: str) -> str:
     """
     Predict the summary for an input text
     --------
@@ -29,32 +32,36 @@ def inferenceAPI_T5(text: str) -> str:
     # set the input parameters for the model
     text = clean_text(text)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary",use_auth_token=True))
     # load local model
     model = (AutoModelForSeq2SeqLM
              .from_pretrained("Linggg/t5_summary",use_auth_token=True)
             .to(device))
 
+
     text_encoding = tokenizer(
         text,
         max_length=1024,
-        padding='max_length',
+        padding="max_length",
         truncation=True,
         return_attention_mask=True,
         add_special_tokens=True,
-        return_tensors='pt'
+        return_tensors="pt",
     )
     generated_ids = model.generate(
-        input_ids=text_encoding['input_ids'],
-        attention_mask=text_encoding['attention_mask'],
+        input_ids=text_encoding["input_ids"],
+        attention_mask=text_encoding["attention_mask"],
         max_length=128,
         num_beams=8,
         length_penalty=0.8,
-        early_stopping=True
+        early_stopping=True,
     )
 
     preds = [
-        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        tokenizer.decode(
+            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
        for gen_id in generated_ids
     ]
     return "".join(preds)
src/model.py CHANGED
@@ -25,7 +25,8 @@ class Encoder(torch.nn.Module):
         # it will be used for unknown words
         self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
         self.embeddings.to(device)
-        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
+        self.hidden = torch.nn.LSTM(
+            embeddings_dim, hidden_size, dropout=dropout)
         # Since we will compute the log-likelihood,
         # it is the log-softmax that we are interested in
         self.dropout = torch.nn.Dropout(dropout)
@@ -61,7 +62,8 @@ class Decoder(torch.nn.Module):
         # it will be used for unknown words
         self.vocab_size = vocab_size
         self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
-        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
+        self.hidden = torch.nn.LSTM(
+            embeddings_dim, hidden_size, dropout=dropout)
         self.output = torch.nn.Linear(hidden_size, vocab_size)
         # Since we will compute the log-likelihood,
         # it is the log-softmax that we are interested in
@@ -100,32 +102,36 @@ class EncoderDecoderModel(torch.nn.Module):
         # The ratio must be inferior to 1 to allow text compression
         assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
 
-        target_len = int(
-            summary_len * source.shape[0]
-        )  # Expected summary length (in words)
-        target_vocab_size = self.decoder.vocab_size  # Word Embedding length
+        # Expected summary length (in words)
+        target_len = int(summary_len * source.shape[0])
+        # Word Embedding length
+        target_vocab_size = self.decoder.vocab_size
 
-        # Output of the right format (expected summary length x word embedding length)
-        # filled with zeros. On each iteration, we will replace one of the rows of this
-        # matrix with the chosen word embedding
+        # Output of the right format (expected summary length x word
+        # embedding length) filled with zeros. On each iteration, we
+        # will replace one of the rows of this matrix with the chosen
+        # word embedding
         outputs = torch.zeros(target_len, target_vocab_size)
 
-        # put the tensors on the device (useless if CPU but very useful in case of GPU)
+        # put the tensors on the device (useless if CPU but very useful in
+        # case of GPU)
         outputs.to(self.device)
         source.to(self.device)
 
-        # last hidden state of the encoder is used as the initial hidden state of the decoder
-        hidden, cell = self.encoder(source)  # Encode the input text
-        input = self.vectoriser.encode(
-            "<start>"
-        )  # Encode the first word of the summary
+        # last hidden state of the encoder is used
+        # as the initial hidden state of the decoder
+
+        # Encode the input text
+        hidden, cell = self.encoder(source)
+        # Encode the first word of the summary
+        input = self.vectoriser.encode("<start>")
 
         # put the tensors on the device
         hidden.to(self.device)
         cell.to(self.device)
         input.to(self.device)
 
-        ### BEAM SEARCH ###
+        # BEAM SEARCH #
         # If you wonder, b stands for better
         values = None
         b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
@@ -134,14 +140,16 @@ class EncoderDecoderModel(torch.nn.Module):
         for i in range(1, target_len):
             # We determine as many words as the desired text length
             # insert input token embedding, previous hidden and previous cell states
-            # receive output tensor (predictions) and new hidden and cell states.
+            # receive output tensor (predictions) and new hidden and cell
+            # states.
 
             # replace predictions in a tensor holding predictions for each token
             # logging.debug(f"output : {output}")
 
             ####### START OF BEAM SEARCH ##########
             if values is None:
-                # Compute the first word probabilities after <start> a first time
+                # Compute the first word probabilities after <start> a first
+                # time
                 output, hidden, cell = self.decoder(input, hidden, cell)
                 output.to(self.device)
                 b_hidden = hidden
@@ -152,7 +160,8 @@ class EncoderDecoderModel(torch.nn.Module):
                 values, indices = output.topk(num_beams, sorted=True)
 
             else:
-                # Instantiate the dictionary that will hold the scores for each possibility
+                # Instantiate the dictionary that will hold the scores for
+                # each possibility
                 scores = {}
 
                 # For each of the best values, we compute the output
@@ -160,7 +169,8 @@ class EncoderDecoderModel(torch.nn.Module):
                     indice.to(self.device)
 
                     # Compute the output
-                    b_output, b_hidden, b_cell = self.decoder(indice, b_hidden, b_cell)
+                    b_output, b_hidden, b_cell = self.decoder(
+                        indice, b_hidden, b_cell)
 
                     # Prevent the model from repeating the same word from one step
                     # to the next by forcing the probability of the previous word to 0
@@ -179,7 +189,8 @@ class EncoderDecoderModel(torch.nn.Module):
                     # And so we fill position i-1 rather than position i
                    b_outputs[i - 1] = b_output.to(self.device)
 
-                    # Set our new values for the next iteration
+                    # Set our new values for the next
+                    # iteration
                    values, indices = b_output.topk(num_beams, sorted=True)
 
            ##################################
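For readers following the beam-search hunks above, a small self-contained sketch of the torch.topk pattern the loop relies on (toy tensor, not the project's decoder output):

import torch

num_beams = 3
vocab_size = 10

# toy scores over a 10-token vocabulary for one decoding step
output = torch.randn(vocab_size)

# keep the num_beams best scores and their token indices, best first
values, indices = output.topk(num_beams, sorted=True)

print(values)   # the 3 highest scores, in decreasing order
print(indices)  # the candidate token ids to expand at the next step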
src/train.py CHANGED
@@ -150,16 +150,24 @@ if __name__ == "__main__":
     words = train_dataset.get_words()
     vectoriser = dataloader.Vectoriser(words)
 
-    train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
-    dev_dataset = dataloader.Data("data/dev_extract.jsonl", transform=vectoriser)
+    train_dataset = dataloader.Data(
+        "data/train_extract.jsonl",
+        transform=vectoriser)
+    dev_dataset = dataloader.Data(
+        "data/dev_extract.jsonl",
+        transform=vectoriser)
 
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=2, shuffle=True, collate_fn=dataloader.pad_collate
-    )
+        train_dataset,
+        batch_size=2,
+        shuffle=True,
+        collate_fn=dataloader.pad_collate)
 
     dev_dataloader = torch.utils.data.DataLoader(
-        dev_dataset, batch_size=4, shuffle=True, collate_fn=dataloader.pad_collate
-    )
+        dev_dataset,
+        batch_size=4,
+        shuffle=True,
+        collate_fn=dataloader.pad_collate)
 
     for i_batch, batch in enumerate(train_dataloader):
         print(i_batch, batch[0], batch[1])
@@ -169,7 +177,8 @@ if __name__ == "__main__":
     print("Device check. You are using:", device)
 
     ### TRAINED NETWORK ###
-    # To make sure the results are the same on every run of the notebook
+    # To make sure the results are the same on every run of the
+    # notebook
     torch.use_deterministic_algorithms(True)
     torch.manual_seed(0)
     random.seed(0)
@@ -178,9 +187,8 @@ if __name__ == "__main__":
     encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
     decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
 
-    trained_classifier = EncoderDecoderModel(encoder, decoder, vectoriser, device).to(
-        device
-    )
+    trained_classifier = EncoderDecoderModel(
+        encoder, decoder, vectoriser, device).to(device)
 
     print(next(trained_classifier.parameters()).device)
     # print(train_dataset.is_cuda)
@@ -194,7 +202,6 @@ if __name__ == "__main__":
 
     torch.save(trained_classifier.state_dict(), "model/model.pt")
     vectoriser.save("model/vocab.pkl")
-    trained_classifier.push_to_hub("SummaryProject-LSTM")
 
     print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
     print(
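dataloader.pad_collate itself is not shown in this diff; a generic sketch of such a padding collate function, under the assumption that it batches variable-length token id sequences (hypothetical helper, for illustration only):

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate_sketch(batch):
    # batch is a list of (text_ids, summary_ids) pairs of varying length
    texts, summaries = zip(*batch)
    texts = pad_sequence(texts, batch_first=True, padding_value=0)
    summaries = pad_sequence(summaries, batch_first=True, padding_value=0)
    return texts, summaries


data = [
    (torch.tensor([1, 2, 3]), torch.tensor([4, 5])),
    (torch.tensor([6]), torch.tensor([7, 8, 9])),
]
loader = torch.utils.data.DataLoader(
    data, batch_size=2, shuffle=True, collate_fn=pad_collate_sketch)
for texts, summaries in loader:
    print(texts.shape, summaries.shape)  # padded to the longest item in the batch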