witchEverly commited on
Commit
39db582
·
verified ·
1 Parent(s): cad7e81

Update CaptionGenerator.py

Browse files
Files changed (1) hide show
  1. CaptionGenerator.py +4 -16
CaptionGenerator.py CHANGED
@@ -32,17 +32,16 @@ class CaptionGenerator:
32
  self.processor = None
33
  self.model = None
34
 
35
- def image_2_text(self, image_data):
36
  """
37
  Generate a caption for the provided image using the BLIP-2 model.
38
  :param image_data: PIL.Image - The image for which the caption is to be generated.
39
  :return: description - The description generated for the image.
40
  """
41
  try:
42
- self.processor, self.model, _ = utils.init_model(init_model_required=True)
43
- inputs = self.processor(images=image_data, return_tensors="pt")
44
- generated_ids = self.model.generate(**inputs, max_length=100)
45
- description = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
46
  return description
47
 
48
  except Exception as e:
@@ -117,17 +116,6 @@ class CaptionGenerator:
117
  st.error(f"Error occurred with Gemini API: {e}")
118
 
119
 
120
- # @st.cache_resource()
121
- # def load_model():
122
- # """
123
- # Loads the BLIP-2 model for image captioning. This function is cached to avoid
124
- # re-loading the model on every call.
125
- # :param CaptionGenerator: The class for generating captions for images.
126
- # :return: Instance of the CaptionGenerator class.
127
- # """
128
- # return CaptionGenerator()
129
-
130
-
131
  # Example usage of the CaptionGenerator class
132
  # caption_generator = load_model()
133
  # image = Image.open("example.jpg")
 
32
  self.processor = None
33
  self.model = None
34
 
35
+ def image_2_text(self, image_data, processor, model):
36
  """
37
  Generate a caption for the provided image using the BLIP-2 model.
38
  :param image_data: PIL.Image - The image for which the caption is to be generated.
39
  :return: description - The description generated for the image.
40
  """
41
  try:
42
+ inputs = processor(images=image_data, return_tensors="pt")
43
+ generated_ids = model.generate(**inputs, max_length=100)
44
+ description = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
45
  return description
46
 
47
  except Exception as e:
 
116
  st.error(f"Error occurred with Gemini API: {e}")
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
119
  # Example usage of the CaptionGenerator class
120
  # caption_generator = load_model()
121
  # image = Image.open("example.jpg")