NeoPy commited on
Commit
c439704
·
verified ·
1 Parent(s): 4f1099f

Create utils.py

Browse files
Files changed (1) hide show
  1. rvcinfpy/utils.py +141 -0
rvcinfpy/utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, zipfile, shutil, subprocess, shlex, sys # noqa
2
+ from urllib.parse import urlparse
3
+ import re
4
+ import logging
5
+
6
+
7
+ def load_file_from_url(
8
+ url: str,
9
+ model_dir: str,
10
+ file_name: str | None = None,
11
+ overwrite: bool = False,
12
+ progress: bool = True,
13
+ ) -> str:
14
+ """Download a file from `url` into `model_dir`,
15
+ using the file present if possible.
16
+
17
+ Returns the path to the downloaded file.
18
+ """
19
+ os.makedirs(model_dir, exist_ok=True)
20
+ if not file_name:
21
+ parts = urlparse(url)
22
+ file_name = os.path.basename(parts.path)
23
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
24
+
25
+ # Overwrite
26
+ if os.path.exists(cached_file):
27
+ if overwrite or os.path.getsize(cached_file) == 0:
28
+ remove_files(cached_file)
29
+
30
+ # Download
31
+ if not os.path.exists(cached_file):
32
+ logger.info(f'Downloading: "{url}" to {cached_file}\n')
33
+ from torch.hub import download_url_to_file
34
+
35
+ download_url_to_file(url, cached_file, progress=progress)
36
+ else:
37
+ logger.debug(cached_file)
38
+
39
+ return cached_file
40
+
41
+
42
+ def friendly_name(file: str):
43
+ if file.startswith("http"):
44
+ file = urlparse(file).path
45
+
46
+ file = os.path.basename(file)
47
+ model_name, extension = os.path.splitext(file)
48
+ return model_name, extension
49
+
50
+
51
+ def download_manager(
52
+ url: str,
53
+ path: str,
54
+ extension: str = "",
55
+ overwrite: bool = False,
56
+ progress: bool = True,
57
+ ):
58
+ url = url.strip()
59
+
60
+ name, ext = friendly_name(url)
61
+ name += ext if not extension else f".{extension}"
62
+
63
+ if url.startswith("http"):
64
+ filename = load_file_from_url(
65
+ url=url,
66
+ model_dir=path,
67
+ file_name=name,
68
+ overwrite=overwrite,
69
+ progress=progress,
70
+ )
71
+ else:
72
+ filename = path
73
+
74
+ return filename
75
+
76
+
77
+ def remove_files(file_list):
78
+ if isinstance(file_list, str):
79
+ file_list = [file_list]
80
+
81
+ for file in file_list:
82
+ if os.path.exists(file):
83
+ os.remove(file)
84
+
85
+
86
+ def remove_directory_contents(directory_path):
87
+ """
88
+ Removes all files and subdirectories within a directory.
89
+
90
+ Parameters:
91
+ directory_path (str): Path to the directory whose
92
+ contents need to be removed.
93
+ """
94
+ if os.path.exists(directory_path):
95
+ for filename in os.listdir(directory_path):
96
+ file_path = os.path.join(directory_path, filename)
97
+ try:
98
+ if os.path.isfile(file_path):
99
+ os.remove(file_path)
100
+ elif os.path.isdir(file_path):
101
+ shutil.rmtree(file_path)
102
+ except Exception as e:
103
+ logger.error(f"Failed to delete {file_path}. Reason: {e}")
104
+ logger.info(f"Content in '{directory_path}' removed.")
105
+ else:
106
+ logger.error(f"Directory '{directory_path}' does not exist.")
107
+
108
+
109
+ # Create directory if not exists
110
+ def create_directories(directory_path):
111
+ if isinstance(directory_path, str):
112
+ directory_path = [directory_path]
113
+ for one_dir_path in directory_path:
114
+ if not os.path.exists(one_dir_path):
115
+ os.makedirs(one_dir_path)
116
+ logger.debug(f"Directory '{one_dir_path}' created.")
117
+
118
+
119
+ def setup_logger(name_log):
120
+ logger = logging.getLogger(name_log)
121
+ logger.setLevel(logging.INFO)
122
+
123
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
124
+ _default_handler.flush = sys.stderr.flush
125
+ logger.addHandler(_default_handler)
126
+
127
+ logger.propagate = False
128
+
129
+ handlers = logger.handlers
130
+
131
+ for handler in handlers:
132
+ formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
133
+ handler.setFormatter(formatter)
134
+
135
+ # logger.handlers
136
+
137
+ return logger
138
+
139
+
140
+ logger = setup_logger("ss")
141
+ logger.setLevel(logging.INFO)