gyrojeff commited on
Commit
c3f3bd0
·
1 Parent(s): ded80a4

feat: add huggingface hub model support

Browse files
Files changed (1) hide show
  1. demo.py +10 -1
demo.py CHANGED
@@ -6,6 +6,7 @@ from torchvision import transforms
6
  from detector.model import *
7
  from detector import config
8
  from font_dataset.font import load_fonts, load_font_with_exclusion
 
9
 
10
  parser = argparse.ArgumentParser()
11
  parser.add_argument(
@@ -20,7 +21,7 @@ parser.add_argument(
20
  "--checkpoint",
21
  type=str,
22
  default=None,
23
- help="Trainer checkpoint path (default: None)",
24
  )
25
  parser.add_argument(
26
  "-m",
@@ -76,6 +77,14 @@ else:
76
  if torch.__version__ >= "2.0" and os.name == "posix":
77
  model = torch.compile(model)
78
 
 
 
 
 
 
 
 
 
79
  detector = FontDetector(
80
  model=model,
81
  lambda_font=1,
 
6
  from detector.model import *
7
  from detector import config
8
  from font_dataset.font import load_fonts, load_font_with_exclusion
9
+ from huggingface_hub import hf_hub_download
10
 
11
  parser = argparse.ArgumentParser()
12
  parser.add_argument(
 
21
  "--checkpoint",
22
  type=str,
23
  default=None,
24
+ help="Trainer checkpoint path (default: None). Use link as huggingface://<user>/<repo>/<file> for huggingface.co models, currently only supports model file in the root directory.",
25
  )
26
  parser.add_argument(
27
  "-m",
 
77
  if torch.__version__ >= "2.0" and os.name == "posix":
78
  model = torch.compile(model)
79
 
80
+
81
+ if str(args.checkpoint).startswith("huggingface://"):
82
+ args.checkpoint = args.checkpoint[14:]
83
+ user, repo, file = args.checkpoint.split("/")
84
+ repo = f"{user}/{repo}"
85
+ args.checkpoint = hf_hub_download(repo, file)
86
+
87
+
88
  detector = FontDetector(
89
  model=model,
90
  lambda_font=1,