Spaces: Build error
import os

from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    BlipConfig,
    BlipTextConfig,
    BlipVisionConfig,
)
import torch

# ComfyUI modules; these imports resolve when the node runs inside ComfyUI.
import model_management
import folder_paths


class BLIPImg2Txt:
    def __init__(
        self,
        conditional_caption: str,
        min_words: int,
        max_words: int,
        temperature: float,
        repetition_penalty: float,
        search_beams: int,
        model_id: str = "Salesforce/blip-image-captioning-large",
        custom_model_path: str = None,
    ):
        self.conditional_caption = conditional_caption
        self.model_id = model_id
        self.custom_model_path = custom_model_path

        # Prefer an explicit local checkpoint; otherwise resolve the model id
        # through ComfyUI's folder_paths registry.
        if self.custom_model_path and os.path.exists(self.custom_model_path):
            self.model_path = self.custom_model_path
        else:
            self.model_path = folder_paths.get_full_path("blip", model_id)

        # Sample when the temperature deviates noticeably from 1.0;
        # otherwise use deterministic (optionally beam search) decoding.
        if temperature > 1.1 or temperature < 0.90:
            do_sample = True
            num_beams = 1  # beam search is not used while sampling
        else:
            do_sample = False
            num_beams = search_beams if search_beams > 1 else 1

        self.text_config_kwargs = {
            "do_sample": do_sample,
            "max_length": max_words,
            "min_length": min_words,
            "repetition_penalty": repetition_penalty,
            "padding": "max_length",
        }
        # temperature and num_beams are only forwarded on the deterministic path.
        if not do_sample:
            self.text_config_kwargs["temperature"] = temperature
            self.text_config_kwargs["num_beams"] = num_beams

    def generate_caption(self, image: Image.Image) -> str:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Load from disk when a local checkpoint exists; otherwise fall back
        # to downloading the model id from the Hugging Face Hub.
        if self.model_path and os.path.exists(self.model_path):
            model_path = self.model_path
            local_files_only = True
        else:
            model_path = self.model_id
            local_files_only = False

        processor = BlipProcessor.from_pretrained(model_path, local_files_only=local_files_only)

        # Merge the generation settings into the text config before building the model.
        config_text = BlipTextConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config_text.update(self.text_config_kwargs)
        config_vision = BlipVisionConfig.from_pretrained(model_path, local_files_only=local_files_only)
        config = BlipConfig.from_text_vision_configs(config_text, config_vision)

        model = BlipForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        ).to(model_management.get_torch_device())

        inputs = processor(
            image,
            self.conditional_caption,
            return_tensors="pt",
        ).to(model_management.get_torch_device(), torch.float16)

        with torch.no_grad():
            out = model.generate(**inputs)
        ret = processor.decode(out[0], skip_special_tokens=True)

        # Release the model and free GPU memory before returning the caption.
        del model
        torch.cuda.empty_cache()

        return ret
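Below is a minimal usage sketch, assuming the class runs inside ComfyUI (so that model_management and folder_paths resolve) and that a BLIP checkpoint is reachable locally or from the Hugging Face Hub; the image path and parameter values are illustrative.

# Hypothetical usage (parameter values and file path are examples only).
from PIL import Image

captioner = BLIPImg2Txt(
    conditional_caption="a photograph of",
    min_words=5,
    max_words=50,
    temperature=1.0,        # within [0.90, 1.1] -> deterministic decoding
    repetition_penalty=1.4,
    search_beams=3,
    # model_id defaults to "Salesforce/blip-image-captioning-large"
)

image = Image.open("example.jpg")  # illustrative input file
print(captioner.generate_caption(image))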