Spaces:
Runtime error
Runtime error
Merge pull request #4 from EveSa/Eve
Browse files- .gitattributes +0 -1
- requirements.txt +11 -0
- src/inference.py +7 -6
- src/train.py +1 -0
.gitattributes
CHANGED
@@ -1,2 +1 @@
|
|
1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
requirements.txt
CHANGED
@@ -1,26 +1,37 @@
|
|
1 |
anyio==3.6.2
|
|
|
|
|
2 |
click==8.1.3
|
3 |
fastapi==0.92.0
|
|
|
4 |
h11==0.14.0
|
|
|
5 |
idna==3.4
|
6 |
Jinja2==3.1.2
|
7 |
joblib==1.2.0
|
8 |
MarkupSafe==2.1.2
|
|
|
9 |
numpy==1.24.2
|
10 |
nvidia-cublas-cu11==11.10.3.66
|
11 |
nvidia-cuda-nvrtc-cu11==11.7.99
|
12 |
nvidia-cuda-runtime-cu11==11.7.99
|
13 |
nvidia-cudnn-cu11==8.5.0.96
|
|
|
14 |
pandas==1.5.3
|
15 |
pydantic==1.10.5
|
16 |
python-dateutil==2.8.2
|
17 |
python-multipart==0.0.6
|
18 |
pytz==2022.7.1
|
|
|
19 |
regex==2022.10.31
|
|
|
20 |
six==1.16.0
|
21 |
sniffio==1.3.0
|
22 |
starlette==0.25.0
|
|
|
23 |
torch==1.13.1
|
24 |
tqdm==4.65.0
|
|
|
25 |
typing_extensions==4.5.0
|
|
|
26 |
uvicorn==0.20.0
|
|
|
1 |
anyio==3.6.2
|
2 |
+
certifi==2022.12.7
|
3 |
+
charset-normalizer==3.1.0
|
4 |
click==8.1.3
|
5 |
fastapi==0.92.0
|
6 |
+
filelock==3.9.0
|
7 |
h11==0.14.0
|
8 |
+
huggingface-hub==0.13.1
|
9 |
idna==3.4
|
10 |
Jinja2==3.1.2
|
11 |
joblib==1.2.0
|
12 |
MarkupSafe==2.1.2
|
13 |
+
nltk==3.8.1
|
14 |
numpy==1.24.2
|
15 |
nvidia-cublas-cu11==11.10.3.66
|
16 |
nvidia-cuda-nvrtc-cu11==11.7.99
|
17 |
nvidia-cuda-runtime-cu11==11.7.99
|
18 |
nvidia-cudnn-cu11==8.5.0.96
|
19 |
+
packaging==23.0
|
20 |
pandas==1.5.3
|
21 |
pydantic==1.10.5
|
22 |
python-dateutil==2.8.2
|
23 |
python-multipart==0.0.6
|
24 |
pytz==2022.7.1
|
25 |
+
PyYAML==6.0
|
26 |
regex==2022.10.31
|
27 |
+
requests==2.28.2
|
28 |
six==1.16.0
|
29 |
sniffio==1.3.0
|
30 |
starlette==0.25.0
|
31 |
+
tokenizers==0.13.2
|
32 |
torch==1.13.1
|
33 |
tqdm==4.65.0
|
34 |
+
transformers==4.26.1
|
35 |
typing_extensions==4.5.0
|
36 |
+
urllib3==1.26.15
|
37 |
uvicorn==0.20.0
|
src/inference.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
|
8 |
import dataloader
|
9 |
from model import Decoder, Encoder, EncoderDecoderModel
|
|
|
10 |
|
11 |
with open("model/vocab.pkl", "rb") as vocab:
|
12 |
words = pickle.load(vocab)
|
@@ -35,9 +36,9 @@ def inferenceAPI(text: str) -> str:
|
|
35 |
# On instancie le modèle
|
36 |
model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
|
37 |
|
38 |
-
model.load_state_dict(torch.load("model/model.pt", map_location=device))
|
39 |
-
model.eval()
|
40 |
-
model.to(device)
|
41 |
|
42 |
# On vectorise le texte
|
43 |
source = vectoriser.encode(text)
|
@@ -51,6 +52,6 @@ def inferenceAPI(text: str) -> str:
|
|
51 |
return vectoriser.decode(output)
|
52 |
|
53 |
|
54 |
-
|
55 |
-
#
|
56 |
-
|
|
|
7 |
|
8 |
import dataloader
|
9 |
from model import Decoder, Encoder, EncoderDecoderModel
|
10 |
+
from transformers import AutoModel
|
11 |
|
12 |
with open("model/vocab.pkl", "rb") as vocab:
|
13 |
words = pickle.load(vocab)
|
|
|
36 |
# On instancie le modèle
|
37 |
model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
|
38 |
|
39 |
+
# model.load_state_dict(torch.load("model/model.pt", map_location=device))
|
40 |
+
# model.eval()
|
41 |
+
# model.to(device)
|
42 |
|
43 |
# On vectorise le texte
|
44 |
source = vectoriser.encode(text)
|
|
|
52 |
return vectoriser.decode(output)
|
53 |
|
54 |
|
55 |
+
if __name__ == "__main__":
|
56 |
+
# inference()
|
57 |
+
print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
|
src/train.py
CHANGED
@@ -194,6 +194,7 @@ if __name__ == "__main__":
|
|
194 |
|
195 |
torch.save(trained_classifier.state_dict(), "model/model.pt")
|
196 |
vectoriser.save("model/vocab.pkl")
|
|
|
197 |
|
198 |
print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
|
199 |
print(
|
|
|
194 |
|
195 |
torch.save(trained_classifier.state_dict(), "model/model.pt")
|
196 |
vectoriser.save("model/vocab.pkl")
|
197 |
+
trained_classifier.push_to_hub("SummaryProject-LSTM")
|
198 |
|
199 |
print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
|
200 |
print(
|