Mansuba commited on
Commit
e8c275d
·
verified ·
1 Parent(s): 71062d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -45
app.py CHANGED
@@ -31,42 +31,13 @@ class ModelCache:
31
  self.cache_dir = cache_dir
32
  self.cache_dir.mkdir(parents=True, exist_ok=True)
33
 
34
- import requests
35
-
36
- def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
37
- try:
38
- # Check if the weights file exists
39
- if not Path(weights_path).exists():
40
- logger.info(f"BanglaCLIP weights not found locally, downloading from Hugging Face...")
41
- url = "https://huggingface.co/Mansuba/Bangla_text_to_image_app/resolve/main/banglaclip_model_epoch_10.pth"
42
- response = requests.get(url, stream=True)
43
- response.raise_for_status()
44
-
45
- # Save the downloaded file
46
- with open(weights_path, "wb") as f:
47
- for chunk in response.iter_content(chunk_size=8192):
48
- f.write(chunk)
49
- logger.info("BanglaCLIP weights downloaded successfully.")
50
-
51
- # Load the model weights
52
- clip_model = CLIPModel.from_pretrained(self.clip_model_name)
53
- state_dict = torch.load(weights_path, map_location=self.device)
54
-
55
- cleaned_state_dict = {
56
- k.replace('module.', '').replace('clip.', ''): v
57
- for k, v in state_dict.items()
58
- if k.replace('module.', '').replace('clip.', '').startswith(('text_model.', 'vision_model.'))
59
- }
60
-
61
- clip_model.load_state_dict(cleaned_state_dict, strict=False)
62
- return clip_model.to(self.device)
63
-
64
- except Exception as e:
65
- logger.error(f"Failed to load BanglaCLIP model: {str(e)}")
66
- raise
67
-
68
-
69
-
70
 
71
  class EnhancedBanglaSDGenerator:
72
  def __init__(
@@ -133,16 +104,11 @@ class EnhancedBanglaSDGenerator:
133
 
134
  def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
135
  try:
136
- model_file_path = Path("model_cache/banglaclip_model_epoch_10.pth")
137
- if not model_file_path.exists():
138
- logger.info("Downloading BanglaCLIP weights...")
139
- torch.hub.download_url_to_file(
140
- 'https://huggingface.co/Mansuba/Bangla_text_to_image_app/resolve/main/banglaclip_model_epoch_10.pth',
141
- model_file_path
142
- )
143
 
144
  clip_model = CLIPModel.from_pretrained(self.clip_model_name)
145
- state_dict = torch.load(model_file_path, map_location=self.device)
146
 
147
  cleaned_state_dict = {
148
  k.replace('module.', '').replace('clip.', ''): v
@@ -273,7 +239,7 @@ def create_gradio_interface():
273
  nonlocal generator
274
  if generator is None:
275
  generator = EnhancedBanglaSDGenerator(
276
- banglaclip_weights_path="model_cache/banglaclip_model_epoch_10.pth",
277
  cache_dir=str(cache_dir)
278
  )
279
  return generator
@@ -353,4 +319,5 @@ def create_gradio_interface():
353
 
354
  if __name__ == "__main__":
355
  demo = create_gradio_interface()
 
356
  demo.queue().launch(share=True)
 
31
  self.cache_dir = cache_dir
32
  self.cache_dir.mkdir(parents=True, exist_ok=True)
33
 
34
+ def load_model(self, model_id: str, load_func: callable, cache_name: str) -> Any:
35
+ try:
36
+ logger.info(f"Loading {cache_name}")
37
+ return load_func(model_id)
38
+ except Exception as e:
39
+ logger.error(f"Error loading model {cache_name}: {str(e)}")
40
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  class EnhancedBanglaSDGenerator:
43
  def __init__(
 
104
 
105
  def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
106
  try:
107
+ if not Path(weights_path).exists():
108
+ raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
 
 
 
 
 
109
 
110
  clip_model = CLIPModel.from_pretrained(self.clip_model_name)
111
+ state_dict = torch.load(weights_path, map_location=self.device)
112
 
113
  cleaned_state_dict = {
114
  k.replace('module.', '').replace('clip.', ''): v
 
239
  nonlocal generator
240
  if generator is None:
241
  generator = EnhancedBanglaSDGenerator(
242
+ banglaclip_weights_path="banglaclip_model_epoch_10_quantized.pth",
243
  cache_dir=str(cache_dir)
244
  )
245
  return generator
 
319
 
320
  if __name__ == "__main__":
321
  demo = create_gradio_interface()
322
+ # Fixed queue configuration for newer Gradio versions
323
  demo.queue().launch(share=True)