tlemagueresse
committed on
Commit
·
45ee714
1
Parent(s):
cac82bc
Replace pkl by joblib
Browse files
- model.py +12 -14
- pipeline.pkl → pipeline.joblib +1 -1
model.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import os
|
2 |
import struct
|
3 |
-
import pickle
|
4 |
from pathlib import Path
|
5 |
from typing import Literal, Union
|
6 |
|
@@ -9,6 +8,7 @@ import torch
|
|
9 |
import lightgbm as lgb
|
10 |
import torchaudio
|
11 |
from huggingface_hub import hf_hub_download
|
|
|
12 |
from sklearn.exceptions import NotFittedError
|
13 |
from torch import Tensor
|
14 |
from torchaudio.transforms import Spectrogram
|
@@ -366,7 +366,7 @@ class FastModelHuggingFace:
|
|
366 |
Methods
|
367 |
-------
|
368 |
from_pretrained(repo_id: str, revision: str = "main",
|
369 |
-
pipeline_file_name: str = "pipeline.pkl",
|
370 |
model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
|
371 |
Loads the FastModel pipeline and model from the Hugging Face Hub.
|
372 |
predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
|
@@ -392,7 +392,7 @@ class FastModelHuggingFace:
|
|
392 |
cls,
|
393 |
repo_id: str,
|
394 |
revision: str = "main",
|
395 |
-
pipeline_file_name: str = "pipeline.pkl",
|
396 |
model_file_name: str = "model_lightgbm.txt",
|
397 |
) -> "FastModelHuggingFace":
|
398 |
"""
|
@@ -405,7 +405,7 @@ class FastModelHuggingFace:
|
|
405 |
revision : str, optional
|
406 |
The specific revision of the repository to use (default is "main").
|
407 |
pipeline_file_name : str, optional
|
408 |
-
The filename of the serialized pipeline (default is "pipeline.pkl").
|
409 |
model_file_name : str, optional
|
410 |
The filename of the LightGBM model (default is "model_lightgbm.txt").
|
411 |
|
@@ -424,8 +424,7 @@ class FastModelHuggingFace:
|
|
424 |
|
425 |
if not os.path.exists(pipeline_path):
|
426 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
|
427 |
-
|
428 |
-
pipeline = pickle.load(f)
|
429 |
|
430 |
if not os.path.exists(model_lgbm_path):
|
431 |
raise FileNotFoundError(
|
@@ -512,10 +511,10 @@ def save_pipeline(
|
|
512 |
lgbm_file_name : str, optional
|
513 |
The filename for saving the LightGBM model (default is "model_fast_model.txt").
|
514 |
pipeline_file_name : str, optional
|
515 |
-
The filename for saving the pipeline (default is "pipeline.pkl").
|
516 |
"""
|
517 |
lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
|
518 |
-
pipeline_file_name = pipeline_file_name or "pipeline.pkl"
|
519 |
|
520 |
lightgbm_path = Path(path) / lgbm_file_name
|
521 |
if model_class_instance.model:
|
@@ -523,8 +522,7 @@ def save_pipeline(
|
|
523 |
model_class_instance.model.save_model(model_class_instance.model_file_name)
|
524 |
|
525 |
pipeline_path = Path(path) / pipeline_file_name
|
526 |
-
|
527 |
-
pickle.dump(model_class_instance, f)
|
528 |
|
529 |
|
530 |
def load_pipeline(
|
@@ -540,7 +538,7 @@ def load_pipeline(
|
|
540 |
lgbm_file_name : str, optional
|
541 |
The filename for the LightGBM model (default is "model_fast_model.txt").
|
542 |
pipeline_file_name : str, optional
|
543 |
-
The filename for the pipeline (default is "pipeline.pkl").
|
544 |
|
545 |
Returns
|
546 |
-------
|
@@ -553,13 +551,13 @@ def load_pipeline(
|
|
553 |
If either the LightGBM model or pipeline file is not found.
|
554 |
"""
|
555 |
lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
|
556 |
-
pipeline_file_name = pipeline_file_name or "pipeline.pkl"
|
557 |
|
558 |
pipeline_path = Path(path) / pipeline_file_name
|
559 |
if not pipeline_path.exists():
|
560 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
|
561 |
-
|
562 |
-
|
563 |
|
564 |
lightgbm_path = Path(path) / lgbm_file_name
|
565 |
if not lightgbm_path.exists():
|
|
|
1 |
import os
|
2 |
import struct
|
|
|
3 |
from pathlib import Path
|
4 |
from typing import Literal, Union
|
5 |
|
|
|
8 |
import lightgbm as lgb
|
9 |
import torchaudio
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
+
from joblib import dump, load
|
12 |
from sklearn.exceptions import NotFittedError
|
13 |
from torch import Tensor
|
14 |
from torchaudio.transforms import Spectrogram
|
|
|
366 |
Methods
|
367 |
-------
|
368 |
from_pretrained(repo_id: str, revision: str = "main",
|
369 |
+
pipeline_file_name: str = "pipeline.joblib",
|
370 |
model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
|
371 |
Loads the FastModel pipeline and model from the Hugging Face Hub.
|
372 |
predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
|
|
|
392 |
cls,
|
393 |
repo_id: str,
|
394 |
revision: str = "main",
|
395 |
+
pipeline_file_name: str = "pipeline.joblib",
|
396 |
model_file_name: str = "model_lightgbm.txt",
|
397 |
) -> "FastModelHuggingFace":
|
398 |
"""
|
|
|
405 |
revision : str, optional
|
406 |
The specific revision of the repository to use (default is "main").
|
407 |
pipeline_file_name : str, optional
|
408 |
+
The filename of the serialized pipeline (default is "pipeline.joblib").
|
409 |
model_file_name : str, optional
|
410 |
The filename of the LightGBM model (default is "model_lightgbm.txt").
|
411 |
|
|
|
424 |
|
425 |
if not os.path.exists(pipeline_path):
|
426 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
|
427 |
+
pipeline = load(pipeline_path)
|
|
|
428 |
|
429 |
if not os.path.exists(model_lgbm_path):
|
430 |
raise FileNotFoundError(
|
|
|
511 |
lgbm_file_name : str, optional
|
512 |
The filename for saving the LightGBM model (default is "model_fast_model.txt").
|
513 |
pipeline_file_name : str, optional
|
514 |
+
The filename for saving the pipeline (default is "pipeline.joblib").
|
515 |
"""
|
516 |
lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
|
517 |
+
pipeline_file_name = pipeline_file_name or "pipeline.joblib"
|
518 |
|
519 |
lightgbm_path = Path(path) / lgbm_file_name
|
520 |
if model_class_instance.model:
|
|
|
522 |
model_class_instance.model.save_model(model_class_instance.model_file_name)
|
523 |
|
524 |
pipeline_path = Path(path) / pipeline_file_name
|
525 |
+
dump(model_class_instance, pipeline_path)
|
|
|
526 |
|
527 |
|
528 |
def load_pipeline(
|
|
|
538 |
lgbm_file_name : str, optional
|
539 |
The filename for the LightGBM model (default is "model_fast_model.txt").
|
540 |
pipeline_file_name : str, optional
|
541 |
+
The filename for the pipeline (default is "pipeline.joblib").
|
542 |
|
543 |
Returns
|
544 |
-------
|
|
|
551 |
If either the LightGBM model or pipeline file is not found.
|
552 |
"""
|
553 |
lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
|
554 |
+
pipeline_file_name = pipeline_file_name or "pipeline.joblib"
|
555 |
|
556 |
pipeline_path = Path(path) / pipeline_file_name
|
557 |
if not pipeline_path.exists():
|
558 |
raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
|
559 |
+
|
560 |
+
model_class_instance = load(pipeline_path)
|
561 |
|
562 |
lightgbm_path = Path(path) / lgbm_file_name
|
563 |
if not lightgbm_path.exists():
|
pipeline.pkl → pipeline.joblib
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 834053
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04a292b51ec618f28089ee0933b30e6623f3abff3e282aafaca15b13c402a847
|
3 |
size 834053
|