tlemagueresse committed on
Commit
45ee714
·
1 Parent(s): cac82bc

Replace pkl with joblib

Browse files
Files changed (2) hide show
  1. model.py +12 -14
  2. pipeline.pkl → pipeline.joblib +1 -1
model.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import struct
3
- import pickle
4
  from pathlib import Path
5
  from typing import Literal, Union
6
 
@@ -9,6 +8,7 @@ import torch
9
  import lightgbm as lgb
10
  import torchaudio
11
  from huggingface_hub import hf_hub_download
 
12
  from sklearn.exceptions import NotFittedError
13
  from torch import Tensor
14
  from torchaudio.transforms import Spectrogram
@@ -366,7 +366,7 @@ class FastModelHuggingFace:
366
  Methods
367
  -------
368
  from_pretrained(repo_id: str, revision: str = "main",
369
- pipeline_file_name: str = "pipeline.pkl",
370
  model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
371
  Loads the FastModel pipeline and model from the Hugging Face Hub.
372
  predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
@@ -392,7 +392,7 @@ class FastModelHuggingFace:
392
  cls,
393
  repo_id: str,
394
  revision: str = "main",
395
- pipeline_file_name: str = "pipeline.pkl",
396
  model_file_name: str = "model_lightgbm.txt",
397
  ) -> "FastModelHuggingFace":
398
  """
@@ -405,7 +405,7 @@ class FastModelHuggingFace:
405
  revision : str, optional
406
  The specific revision of the repository to use (default is "main").
407
  pipeline_file_name : str, optional
408
- The filename of the serialized pipeline (default is "pipeline.pkl").
409
  model_file_name : str, optional
410
  The filename of the LightGBM model (default is "model_lightgbm.txt").
411
 
@@ -424,8 +424,7 @@ class FastModelHuggingFace:
424
 
425
  if not os.path.exists(pipeline_path):
426
  raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
427
- with open(pipeline_path, "rb") as f:
428
- pipeline = pickle.load(f)
429
 
430
  if not os.path.exists(model_lgbm_path):
431
  raise FileNotFoundError(
@@ -512,10 +511,10 @@ def save_pipeline(
512
  lgbm_file_name : str, optional
513
  The filename for saving the LightGBM model (default is "model_fast_model.txt").
514
  pipeline_file_name : str, optional
515
- The filename for saving the pipeline (default is "pipeline.pkl").
516
  """
517
  lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
518
- pipeline_file_name = pipeline_file_name or "pipeline.pkl"
519
 
520
  lightgbm_path = Path(path) / lgbm_file_name
521
  if model_class_instance.model:
@@ -523,8 +522,7 @@ def save_pipeline(
523
  model_class_instance.model.save_model(model_class_instance.model_file_name)
524
 
525
  pipeline_path = Path(path) / pipeline_file_name
526
- with open(pipeline_path, "wb") as f:
527
- pickle.dump(model_class_instance, f)
528
 
529
 
530
  def load_pipeline(
@@ -540,7 +538,7 @@ def load_pipeline(
540
  lgbm_file_name : str, optional
541
  The filename for the LightGBM model (default is "model_fast_model.txt").
542
  pipeline_file_name : str, optional
543
- The filename for the pipeline (default is "pipeline.pkl").
544
 
545
  Returns
546
  -------
@@ -553,13 +551,13 @@ def load_pipeline(
553
  If either the LightGBM model or pipeline file is not found.
554
  """
555
  lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
556
- pipeline_file_name = pipeline_file_name or "pipeline.pkl"
557
 
558
  pipeline_path = Path(path) / pipeline_file_name
559
  if not pipeline_path.exists():
560
  raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
561
- with open(pipeline_path, "rb") as f:
562
- model_class_instance = pickle.load(f)
563
 
564
  lightgbm_path = Path(path) / lgbm_file_name
565
  if not lightgbm_path.exists():
 
1
  import os
2
  import struct
 
3
  from pathlib import Path
4
  from typing import Literal, Union
5
 
 
8
  import lightgbm as lgb
9
  import torchaudio
10
  from huggingface_hub import hf_hub_download
11
+ from joblib import dump, load
12
  from sklearn.exceptions import NotFittedError
13
  from torch import Tensor
14
  from torchaudio.transforms import Spectrogram
 
366
  Methods
367
  -------
368
  from_pretrained(repo_id: str, revision: str = "main",
369
+ pipeline_file_name: str = "pipeline.joblib",
370
  model_file_name: str = "model_lightgbm.txt") -> "FastModelHuggingFace":
371
  Loads the FastModel pipeline and model from the Hugging Face Hub.
372
  predict(input_data: Union[str, "HuggingFaceDataset"], get_proba: bool = False) -> np.ndarray:
 
392
  cls,
393
  repo_id: str,
394
  revision: str = "main",
395
+ pipeline_file_name: str = "pipeline.joblib",
396
  model_file_name: str = "model_lightgbm.txt",
397
  ) -> "FastModelHuggingFace":
398
  """
 
405
  revision : str, optional
406
  The specific revision of the repository to use (default is "main").
407
  pipeline_file_name : str, optional
408
+ The filename of the serialized pipeline (default is "pipeline.joblib").
409
  model_file_name : str, optional
410
  The filename of the LightGBM model (default is "model_lightgbm.txt").
411
 
 
424
 
425
  if not os.path.exists(pipeline_path):
426
  raise FileNotFoundError(f"Pipeline file {pipeline_path} is missing or corrupted.")
427
+ pipeline = load(pipeline_path)
 
428
 
429
  if not os.path.exists(model_lgbm_path):
430
  raise FileNotFoundError(
 
511
  lgbm_file_name : str, optional
512
  The filename for saving the LightGBM model (default is "model_fast_model.txt").
513
  pipeline_file_name : str, optional
514
+ The filename for saving the pipeline (default is "pipeline.joblib").
515
  """
516
  lgbm_file_name = lgbm_file_name or "model_lightgbm.txt"
517
+ pipeline_file_name = pipeline_file_name or "pipeline.joblib"
518
 
519
  lightgbm_path = Path(path) / lgbm_file_name
520
  if model_class_instance.model:
 
522
  model_class_instance.model.save_model(model_class_instance.model_file_name)
523
 
524
  pipeline_path = Path(path) / pipeline_file_name
525
+ dump(model_class_instance, pipeline_path)
 
526
 
527
 
528
  def load_pipeline(
 
538
  lgbm_file_name : str, optional
539
  The filename for the LightGBM model (default is "model_fast_model.txt").
540
  pipeline_file_name : str, optional
541
+ The filename for the pipeline (default is "pipeline.joblib").
542
 
543
  Returns
544
  -------
 
551
  If either the LightGBM model or pipeline file is not found.
552
  """
553
  lgbm_file_name = lgbm_file_name or "model_fast_model.txt"
554
+ pipeline_file_name = pipeline_file_name or "pipeline.joblib"
555
 
556
  pipeline_path = Path(path) / pipeline_file_name
557
  if not pipeline_path.exists():
558
  raise FileNotFoundError(f"Pipeline file {pipeline_path} not found.")
559
+
560
+ model_class_instance = load(pipeline_path)
561
 
562
  lightgbm_path = Path(path) / lgbm_file_name
563
  if not lightgbm_path.exists():
pipeline.pkl → pipeline.joblib RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:faff5f8ba72a4be0fe89fb5951c53fe70b5ccd53170e81c141a27691361b9155
3
  size 834053
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04a292b51ec618f28089ee0933b30e6623f3abff3e282aafaca15b13c402a847
3
  size 834053