Spaces — build status: Runtime error
Commit: Initial commit
Browse files — changed files:
- requirements.txt (+11 −0)
- src/inference.py (+6 −3)
- src/train.py (+1 −0)
requirements.txt
CHANGED
|
@@ -1,26 +1,37 @@
|
|
| 1 |
anyio==3.6.2
|
|
|
|
|
|
|
| 2 |
click==8.1.3
|
| 3 |
fastapi==0.92.0
|
|
|
|
| 4 |
h11==0.14.0
|
|
|
|
| 5 |
idna==3.4
|
| 6 |
Jinja2==3.1.2
|
| 7 |
joblib==1.2.0
|
| 8 |
MarkupSafe==2.1.2
|
|
|
|
| 9 |
numpy==1.24.2
|
| 10 |
nvidia-cublas-cu11==11.10.3.66
|
| 11 |
nvidia-cuda-nvrtc-cu11==11.7.99
|
| 12 |
nvidia-cuda-runtime-cu11==11.7.99
|
| 13 |
nvidia-cudnn-cu11==8.5.0.96
|
|
|
|
| 14 |
pandas==1.5.3
|
| 15 |
pydantic==1.10.5
|
| 16 |
python-dateutil==2.8.2
|
| 17 |
python-multipart==0.0.6
|
| 18 |
pytz==2022.7.1
|
|
|
|
| 19 |
regex==2022.10.31
|
|
|
|
| 20 |
six==1.16.0
|
| 21 |
sniffio==1.3.0
|
| 22 |
starlette==0.25.0
|
|
|
|
| 23 |
torch==1.13.1
|
| 24 |
tqdm==4.65.0
|
|
|
|
| 25 |
typing_extensions==4.5.0
|
|
|
|
| 26 |
uvicorn==0.20.0
|
|
|
|
| 1 |
anyio==3.6.2
|
| 2 |
+
certifi==2022.12.7
|
| 3 |
+
charset-normalizer==3.1.0
|
| 4 |
click==8.1.3
|
| 5 |
fastapi==0.92.0
|
| 6 |
+
filelock==3.9.0
|
| 7 |
h11==0.14.0
|
| 8 |
+
huggingface-hub==0.13.1
|
| 9 |
idna==3.4
|
| 10 |
Jinja2==3.1.2
|
| 11 |
joblib==1.2.0
|
| 12 |
MarkupSafe==2.1.2
|
| 13 |
+
nltk==3.8.1
|
| 14 |
numpy==1.24.2
|
| 15 |
nvidia-cublas-cu11==11.10.3.66
|
| 16 |
nvidia-cuda-nvrtc-cu11==11.7.99
|
| 17 |
nvidia-cuda-runtime-cu11==11.7.99
|
| 18 |
nvidia-cudnn-cu11==8.5.0.96
|
| 19 |
+
packaging==23.0
|
| 20 |
pandas==1.5.3
|
| 21 |
pydantic==1.10.5
|
| 22 |
python-dateutil==2.8.2
|
| 23 |
python-multipart==0.0.6
|
| 24 |
pytz==2022.7.1
|
| 25 |
+
PyYAML==6.0
|
| 26 |
regex==2022.10.31
|
| 27 |
+
requests==2.28.2
|
| 28 |
six==1.16.0
|
| 29 |
sniffio==1.3.0
|
| 30 |
starlette==0.25.0
|
| 31 |
+
tokenizers==0.13.2
|
| 32 |
torch==1.13.1
|
| 33 |
tqdm==4.65.0
|
| 34 |
+
transformers==4.26.1
|
| 35 |
typing_extensions==4.5.0
|
| 36 |
+
urllib3==1.26.15
|
| 37 |
uvicorn==0.20.0
|
src/inference.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch
|
|
| 7 |
|
| 8 |
import dataloader
|
| 9 |
from model import Decoder, Encoder, EncoderDecoderModel
|
|
|
|
| 10 |
|
| 11 |
with open("model/vocab.pkl", "rb") as vocab:
|
| 12 |
words = pickle.load(vocab)
|
|
@@ -33,6 +34,8 @@ def inferenceAPI(text: str) -> str:
|
|
| 33 |
decoder.to(device)
|
| 34 |
|
| 35 |
# On instancie le modèle
|
|
|
|
|
|
|
| 36 |
model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
|
| 37 |
|
| 38 |
model.load_state_dict(torch.load("model/model.pt", map_location=device))
|
|
@@ -51,6 +54,6 @@ def inferenceAPI(text: str) -> str:
|
|
| 51 |
return vectoriser.decode(output)
|
| 52 |
|
| 53 |
|
| 54 |
-
|
| 55 |
-
#
|
| 56 |
-
|
|
|
|
| 7 |
|
| 8 |
import dataloader
|
| 9 |
from model import Decoder, Encoder, EncoderDecoderModel
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
|
| 12 |
with open("model/vocab.pkl", "rb") as vocab:
|
| 13 |
words = pickle.load(vocab)
|
|
|
|
| 34 |
decoder.to(device)
|
| 35 |
|
| 36 |
# On instancie le modèle
|
| 37 |
+
model = AutoModel.from_pretrained("EveSa/SummaryProject-LSTM", revision="main")
|
| 38 |
+
model = AutoModel.PretrainedConfig()
|
| 39 |
model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
|
| 40 |
|
| 41 |
model.load_state_dict(torch.load("model/model.pt", map_location=device))
|
|
|
|
| 54 |
return vectoriser.decode(output)
|
| 55 |
|
| 56 |
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
# inference()
|
| 59 |
+
print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
|
src/train.py
CHANGED
|
@@ -194,6 +194,7 @@ if __name__ == "__main__":
|
|
| 194 |
|
| 195 |
torch.save(trained_classifier.state_dict(), "model/model.pt")
|
| 196 |
vectoriser.save("model/vocab.pkl")
|
|
|
|
| 197 |
|
| 198 |
print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
|
| 199 |
print(
|
|
|
|
| 194 |
|
| 195 |
torch.save(trained_classifier.state_dict(), "model/model.pt")
|
| 196 |
vectoriser.save("model/vocab.pkl")
|
| 197 |
+
trained_classifier.config.to_json_file("config.json")
|
| 198 |
|
| 199 |
print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
|
| 200 |
print(
|