MoritzMMuller commited on
Commit
36422f4
·
verified ·
1 Parent(s): 3567378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -20
app.py CHANGED
@@ -39,29 +39,27 @@ geolocator = Nominatim(user_agent="skin-dashboard", timeout = 10)
39
 
40
  @st.cache_resource
41
  def load_image_model(token: str):
42
- # 1) Download the repo into a writable path
43
- cache_dir = "/mnt/data/model_cache"
44
- # snapshot_download will create this folder if it doesn't exist
45
- local_dir = snapshot_download(
46
- repo_id=MODEL_NAME,
47
- use_auth_token=token,
48
- cache_dir=cache_dir
49
  )
50
 
51
- # 2) Inject a valid config.json there
52
- cfg = {
53
- "architectures": ["ConvNextForImageClassification"],
54
- "model_type": "convnext",
55
- "num_labels": 2,
56
- "id2label": { "0": "benign", "1": "malignant" },
57
- "label2id": { "benign": 0, "malignant": 1 }
58
- }
59
- with open(os.path.join(local_dir, "config.json"), "w") as f:
60
- json.dump(cfg, f)
 
 
 
61
 
62
- # 3) Load from the patched local snapshot
63
- extractor = AutoFeatureExtractor.from_pretrained(local_dir)
64
- model = AutoModelForImageClassification.from_pretrained(local_dir)
65
  return pipeline(
66
  "image-classification",
67
  model=model,
 
39
 
40
  @st.cache_resource
41
  def load_image_model(token: str):
42
+ # 1) load the feature extractor from the Hub as usual
43
+ extractor = AutoFeatureExtractor.from_pretrained(
44
+ MODEL_NAME,
45
+ use_auth_token=token
 
 
 
46
  )
47
 
48
+ # 2) manually create a ConvNextConfig with the right num_labels / id2label
49
+ config = ConvNextConfig(
50
+ num_labels=2,
51
+ id2label={0: "benign", 1: "malignant"},
52
+ label2id={"benign": 0, "malignant": 1}
53
+ )
54
+
55
+ # 3) load the weights with that config override
56
+ model = AutoModelForImageClassification.from_pretrained(
57
+ MODEL_NAME,
58
+ config=config,
59
+ use_auth_token=token
60
+ )
61
 
62
+ # 4) build your pipeline
 
 
63
  return pipeline(
64
  "image-classification",
65
  model=model,