jays009 commited on
Commit
e36f4e9
·
verified ·
1 Parent(s): 9487094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision import models, transforms
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
@@ -24,6 +25,9 @@ CONFIDENCE_THRESHOLD = 0.8 # 80%
24
  # Mahalanobis distance threshold for OOD detection
25
  MAHALANOBIS_THRESHOLD = 100.0 # Calibrate this using a validation set
26
 
 
 
 
27
  # Download model from Hugging Face
28
  def download_model():
29
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
@@ -58,7 +62,7 @@ def load_main_model(model_path):
58
  try:
59
  class_statistics = torch.load("class_statistics.pth", map_location=torch.device("cpu"))
60
  except FileNotFoundError:
61
- logger.error("class_statistics.pth not found. Please run the statistics computation script first.")
62
  raise
63
 
64
  # Path to your model
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import torch.serialization # Added for add_safe_globals
5
  from torchvision import models, transforms
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
 
25
  # Mahalanobis distance threshold for OOD detection
26
  MAHALANOBIS_THRESHOLD = 100.0 # Calibrate this using a validation set
27
 
28
+ # Allowlist the NumPy global for torch.load
29
+ torch.serialization.add_safe_globals([np._core.multiarray._reconstruct])
30
+
31
  # Download model from Hugging Face
32
  def download_model():
33
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
 
62
  try:
63
  class_statistics = torch.load("class_statistics.pth", map_location=torch.device("cpu"))
64
  except FileNotFoundError:
65
+ logger.error("class_statistics.pth not found. Please ensure the file is in the same directory as app.py.")
66
  raise
67
 
68
  # Path to your model