jhj0517
commited on
Commit
·
c311090
1
Parent(s):
d810696
Update default model
Browse files
modules/image_restoration/real_esrgan_inferencer.py
CHANGED
@@ -19,13 +19,14 @@ class RealESRGANInferencer:
|
|
19 |
self.device = self.get_device()
|
20 |
self.model = None
|
21 |
self.available_models = list(MODELS_REALESRGAN_URL.keys())
|
|
|
22 |
|
23 |
def load_model(self,
|
24 |
model_name: Optional[str] = None,
|
25 |
scale: int = 1,
|
26 |
progress: gr.Progress = gr.Progress()):
|
27 |
if model_name is None:
|
28 |
-
model_name =
|
29 |
if not model_name.endswith(".pth"):
|
30 |
model_name += ".pth"
|
31 |
model_path = os.path.join(self.model_dir, model_name)
|
@@ -53,6 +54,7 @@ class RealESRGANInferencer:
|
|
53 |
else:
|
54 |
output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
|
55 |
sr_img.save(output_path)
|
|
|
56 |
except Exception as e:
|
57 |
raise
|
58 |
|
|
|
19 |
self.device = self.get_device()
|
20 |
self.model = None
|
21 |
self.available_models = list(MODELS_REALESRGAN_URL.keys())
|
22 |
+
self.default_model = self.available_models[0]
|
23 |
|
24 |
def load_model(self,
|
25 |
model_name: Optional[str] = None,
|
26 |
scale: int = 1,
|
27 |
progress: gr.Progress = gr.Progress()):
|
28 |
if model_name is None:
|
29 |
+
model_name = self.default_model
|
30 |
if not model_name.endswith(".pth"):
|
31 |
model_name += ".pth"
|
32 |
model_path = os.path.join(self.model_dir, model_name)
|
|
|
54 |
else:
|
55 |
output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
|
56 |
sr_img.save(output_path)
|
57 |
+
return output_path
|
58 |
except Exception as e:
|
59 |
raise
|
60 |
|