jhj0517
commited on
Commit
·
c311090
1
Parent(s):
d810696
Update default model
Browse files
modules/image_restoration/real_esrgan_inferencer.py
CHANGED
|
@@ -19,13 +19,14 @@ class RealESRGANInferencer:
|
|
| 19 |
self.device = self.get_device()
|
| 20 |
self.model = None
|
| 21 |
self.available_models = list(MODELS_REALESRGAN_URL.keys())
|
|
|
|
| 22 |
|
| 23 |
def load_model(self,
|
| 24 |
model_name: Optional[str] = None,
|
| 25 |
scale: int = 1,
|
| 26 |
progress: gr.Progress = gr.Progress()):
|
| 27 |
if model_name is None:
|
| 28 |
-
model_name =
|
| 29 |
if not model_name.endswith(".pth"):
|
| 30 |
model_name += ".pth"
|
| 31 |
model_path = os.path.join(self.model_dir, model_name)
|
|
@@ -53,6 +54,7 @@ class RealESRGANInferencer:
|
|
| 53 |
else:
|
| 54 |
output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
|
| 55 |
sr_img.save(output_path)
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
raise
|
| 58 |
|
|
|
|
| 19 |
self.device = self.get_device()
|
| 20 |
self.model = None
|
| 21 |
self.available_models = list(MODELS_REALESRGAN_URL.keys())
|
| 22 |
+
self.default_model = self.available_models[0]
|
| 23 |
|
| 24 |
def load_model(self,
|
| 25 |
model_name: Optional[str] = None,
|
| 26 |
scale: int = 1,
|
| 27 |
progress: gr.Progress = gr.Progress()):
|
| 28 |
if model_name is None:
|
| 29 |
+
model_name = self.default_model
|
| 30 |
if not model_name.endswith(".pth"):
|
| 31 |
model_name += ".pth"
|
| 32 |
model_path = os.path.join(self.model_dir, model_name)
|
|
|
|
| 54 |
else:
|
| 55 |
output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
|
| 56 |
sr_img.save(output_path)
|
| 57 |
+
return output_path
|
| 58 |
except Exception as e:
|
| 59 |
raise
|
| 60 |
|