Mansuba commited on
Commit
38b64f3
·
verified ·
1 Parent(s): 8a74117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -7
app.py CHANGED
@@ -31,13 +31,42 @@ class ModelCache:
31
  self.cache_dir = cache_dir
32
  self.cache_dir.mkdir(parents=True, exist_ok=True)
33
 
34
- def load_model(self, model_id: str, load_func: callable, cache_name: str) -> Any:
35
- try:
36
- logger.info(f"Loading {cache_name}")
37
- return load_func(model_id)
38
- except Exception as e:
39
- logger.error(f"Error loading model {cache_name}: {str(e)}")
40
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  class EnhancedBanglaSDGenerator:
43
  def __init__(
 
31
  self.cache_dir = cache_dir
32
  self.cache_dir.mkdir(parents=True, exist_ok=True)
33
 
34
+ import requests
35
+
36
+ def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
37
+ try:
38
+ # Check if the weights file exists
39
+ if not Path(weights_path).exists():
40
+ logger.info(f"BanglaCLIP weights not found locally, downloading from Hugging Face...")
41
+ url = "https://huggingface.co/Mansuba/Bangla_text_to_image_app/resolve/main/banglaclip_model_epoch_10.pth"
42
+ response = requests.get(url, stream=True)
43
+ response.raise_for_status()
44
+
45
+ # Save the downloaded file
46
+ with open(weights_path, "wb") as f:
47
+ for chunk in response.iter_content(chunk_size=8192):
48
+ f.write(chunk)
49
+ logger.info("BanglaCLIP weights downloaded successfully.")
50
+
51
+ # Load the model weights
52
+ clip_model = CLIPModel.from_pretrained(self.clip_model_name)
53
+ state_dict = torch.load(weights_path, map_location=self.device)
54
+
55
+ cleaned_state_dict = {
56
+ k.replace('module.', '').replace('clip.', ''): v
57
+ for k, v in state_dict.items()
58
+ if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
59
+ }
60
+
61
+ clip_model.load_state_dict(cleaned_state_dict, strict=False)
62
+ return clip_model.to(self.device)
63
+
64
+ except Exception as e:
65
+ logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
66
+ raise
67
+
68
+
69
+
70
 
71
  class EnhancedBanglaSDGenerator:
72
  def __init__(