from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )

    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    if not isinstance(token_id, torch.Tensor):
        if isinstance(token_id, int):
            token_id = [token_id]
        token_id = torch.tensor(token_id, device=device)
    return token_id


class CoCa(nn.Module):
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keep the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures
        # that the beams in the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []
                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                    + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
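

# The block below is an optional usage sketch, not part of the upstream module:
# it shows, under stated assumptions, how CoCa.generate() might be driven end to
# end through the public open_clip entry points. Assumed names (adjust to your
# environment): the "coca_ViT-L-14" model config, the
# "mscoco_finetuned_laion2B-s13B-b90k" pretrained tag, and the placeholder
# image path "example.jpg".
if __name__ == "__main__":
    import open_clip  # assumes this module lives inside the open_clip package
    from PIL import Image

    # build a CoCa model and its preprocessing transform (names are assumptions)
    model, _, transform = open_clip.create_model_and_transforms(
        "coca_ViT-L-14",
        pretrained="mscoco_finetuned_laion2B-s13B-b90k",
    )
    model.eval()

    # "example.jpg" is a hypothetical image path used for illustration only
    im = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        # beam search is the default generation_type; requires `transformers`
        ids = model.generate(im, generation_type="beam_search", seq_len=30)

    # decode token ids back to text and strip the CLIP special tokens
    caption = open_clip.decode(ids[0])
    caption = caption.split("<end_of_text>")[0].replace("<start_of_text>", "").strip()
    print(caption)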