jhj0517 commited on
Commit
c311090
·
1 Parent(s): d810696

Update default model

Browse files
modules/image_restoration/real_esrgan_inferencer.py CHANGED
@@ -19,13 +19,14 @@ class RealESRGANInferencer:
19
  self.device = self.get_device()
20
  self.model = None
21
  self.available_models = list(MODELS_REALESRGAN_URL.keys())
 
22
 
23
  def load_model(self,
24
  model_name: Optional[str] = None,
25
  scale: int = 1,
26
  progress: gr.Progress = gr.Progress()):
27
  if model_name is None:
28
- model_name = "realesr-general-x4v3"
29
  if not model_name.endswith(".pth"):
30
  model_name += ".pth"
31
  model_path = os.path.join(self.model_dir, model_name)
@@ -53,6 +54,7 @@ class RealESRGANInferencer:
53
  else:
54
  output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
55
  sr_img.save(output_path)
 
56
  except Exception as e:
57
  raise
58
 
 
19
  self.device = self.get_device()
20
  self.model = None
21
  self.available_models = list(MODELS_REALESRGAN_URL.keys())
22
+ self.default_model = self.available_models[0]
23
 
24
  def load_model(self,
25
  model_name: Optional[str] = None,
26
  scale: int = 1,
27
  progress: gr.Progress = gr.Progress()):
28
  if model_name is None:
29
+ model_name = self.default_model
30
  if not model_name.endswith(".pth"):
31
  model_name += ".pth"
32
  model_path = os.path.join(self.model_dir, model_name)
 
54
  else:
55
  output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
56
  sr_img.save(output_path)
57
+ return output_path
58
  except Exception as e:
59
  raise
60