MohamedRashad commited on
Commit
7434833
·
1 Parent(s): 5c095cd

Add CUDA availability check for input IDs and attention mask in prompt encoding

Browse files
Files changed (1):
  1. app.py +2 -2
app.py CHANGED
@@ -51,8 +51,8 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
51
  print(f'prompt={prompt}')
52
  captions = [prompt]
53
  tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
54
- input_ids = tokens.input_ids.cuda(non_blocking=True)
55
- mask = tokens.attention_mask.cuda(non_blocking=True)
56
  text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
57
  lens: List[int] = mask.sum(dim=-1).tolist()
58
  cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
 
51
  print(f'prompt={prompt}')
52
  captions = [prompt]
53
  tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
54
+ input_ids = tokens.input_ids.cuda(non_blocking=True) if torch.cuda.is_available() else tokens.input_ids
55
+ mask = tokens.attention_mask.cuda(non_blocking=True) if torch.cuda.is_available() else tokens.attention_mask
56
  text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
57
  lens: List[int] = mask.sum(dim=-1).tolist()
58
  cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))