EveSa commited on
Commit
cce255f
·
1 Parent(s): ca615e0

inference fonctionnelle sans load de fichier

Browse files
Files changed (1) hide show
  1. src/inference.py +6 -8
src/inference.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  from src import dataloader
9
  from src.model import Decoder, Encoder, EncoderDecoderModel
10
 
11
- with open ("model/vocab.pkl", 'rb') as vocab:
12
  words = pickle.load(vocab)
13
  vectoriser = dataloader.Vectoriser(words)
14
 
@@ -27,12 +27,10 @@ def inferenceAPI(text: str) -> str:
27
  text = text.split()
28
  # On défini les paramètres d'entrée pour le modèle
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device).to(
31
- device
32
- )
33
- decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device).to(
34
- device
35
- )
36
 
37
  # On instancie le modèle
38
  model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
@@ -49,7 +47,7 @@ def inferenceAPI(text: str) -> str:
49
  with torch.no_grad():
50
  output = model(source).to(device)
51
  output.to(device)
52
- output=output.argmax(dim=-1)
53
  return vectoriser.decode(output)
54
 
55
 
 
8
  from src import dataloader
9
  from src.model import Decoder, Encoder, EncoderDecoderModel
10
 
11
+ with open("model/vocab.pkl", "rb") as vocab:
12
  words = pickle.load(vocab)
13
  vectoriser = dataloader.Vectoriser(words)
14
 
 
27
  text = text.split()
28
  # On défini les paramètres d'entrée pour le modèle
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
31
+ encoder.to(device)
32
+ decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
33
+ decoder.to(device)
 
 
34
 
35
  # On instancie le modèle
36
  model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
 
47
  with torch.no_grad():
48
  output = model(source).to(device)
49
  output.to(device)
50
+ output = output.argmax(dim=-1)
51
  return vectoriser.decode(output)
52
 
53