Update app.py
Browse files
app.py
CHANGED
|
@@ -102,6 +102,7 @@ class Predictor:
|
|
| 102 |
def __init__(self):
|
| 103 |
self.model_target_size = None
|
| 104 |
self.last_loaded_repo = None
|
|
|
|
| 105 |
|
| 106 |
def download_model(self, model_repo):
|
| 107 |
csv_path = huggingface_hub.hf_hub_download(
|
|
@@ -117,11 +118,11 @@ class Predictor:
|
|
| 117 |
return csv_path, model_path
|
| 118 |
|
| 119 |
def load_model(self, model_repo):
|
| 120 |
-
|
|
|
|
| 121 |
return
|
| 122 |
-
|
| 123 |
csv_path, model_path = self.download_model(model_repo)
|
| 124 |
-
|
| 125 |
tags_df = pd.read_csv(csv_path)
|
| 126 |
sep_tags = load_labels(tags_df)
|
| 127 |
|
|
@@ -130,12 +131,23 @@ class Predictor:
|
|
| 130 |
self.general_indexes = sep_tags[2]
|
| 131 |
self.character_indexes = sep_tags[3]
|
| 132 |
|
| 133 |
-
model = rt.InferenceSession(model_path)
|
| 134 |
-
_, height, width, _ = model.get_inputs()[0].shape
|
| 135 |
-
self.model_target_size = height
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
def prepare_image(self, image):
|
| 141 |
target_size = self.model_target_size
|
|
@@ -179,6 +191,9 @@ class Predictor:
|
|
| 179 |
):
|
| 180 |
self.load_model(model_repo)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
image = self.prepare_image(image)
|
| 183 |
|
| 184 |
input_name = self.model.get_inputs()[0].name
|
|
@@ -347,4 +362,4 @@ def main():
|
|
| 347 |
|
| 348 |
|
| 349 |
if __name__ == "__main__":
|
| 350 |
-
main()
|
|
|
|
| 102 |
def __init__(self):
|
| 103 |
self.model_target_size = None
|
| 104 |
self.last_loaded_repo = None
|
| 105 |
+
self.model = None # Inisialisasi model di sini
|
| 106 |
|
| 107 |
def download_model(self, model_repo):
|
| 108 |
csv_path = huggingface_hub.hf_hub_download(
|
|
|
|
| 118 |
return csv_path, model_path
|
| 119 |
|
| 120 |
def load_model(self, model_repo):
|
| 121 |
+
# Cek apakah model sudah dimuat
|
| 122 |
+
if model_repo == self.last_loaded_repo and self.model is not None:
|
| 123 |
return
|
| 124 |
+
|
| 125 |
csv_path, model_path = self.download_model(model_repo)
|
|
|
|
| 126 |
tags_df = pd.read_csv(csv_path)
|
| 127 |
sep_tags = load_labels(tags_df)
|
| 128 |
|
|
|
|
| 131 |
self.general_indexes = sep_tags[2]
|
| 132 |
self.character_indexes = sep_tags[3]
|
| 133 |
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
# Gunakan CPU execution provider jika GPU tidak tersedia
|
| 136 |
+
providers = ["CPUExecutionProvider"]
|
| 137 |
+
if rt.get_device() == "GPU":
|
| 138 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 139 |
+
try:
|
| 140 |
+
model = rt.InferenceSession(model_path, providers=providers)
|
| 141 |
+
_, height, width, _ = model.get_inputs()[0].shape
|
| 142 |
+
self.model_target_size = height
|
| 143 |
+
self.last_loaded_repo = model_repo
|
| 144 |
+
self.model = model
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error loading model with given providers: {e}")
|
| 148 |
+
self.model = None
|
| 149 |
+
self.last_loaded_repo = None
|
| 150 |
+
|
| 151 |
|
| 152 |
def prepare_image(self, image):
|
| 153 |
target_size = self.model_target_size
|
|
|
|
| 191 |
):
|
| 192 |
self.load_model(model_repo)
|
| 193 |
|
| 194 |
+
if self.model is None:
|
| 195 |
+
return "", {}, {}, {}
|
| 196 |
+
|
| 197 |
image = self.prepare_image(image)
|
| 198 |
|
| 199 |
input_name = self.model.get_inputs()[0].name
|
|
|
|
| 362 |
|
| 363 |
|
| 364 |
if __name__ == "__main__":
|
| 365 |
+
main()
|