Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7434833
1
Parent(s):
5c095cd
Check CUDA availability before moving the input IDs and attention mask to the GPU in prompt encoding (leave both tensors as returned by the tokenizer when CUDA is unavailable)
Browse files
app.py
CHANGED
@@ -51,8 +51,8 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
|
|
51 |
print(f'prompt={prompt}')
|
52 |
captions = [prompt]
|
53 |
tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
|
54 |
-
input_ids = tokens.input_ids.cuda(non_blocking=True)
|
55 |
-
mask = tokens.attention_mask.cuda(non_blocking=True)
|
56 |
text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
|
57 |
lens: List[int] = mask.sum(dim=-1).tolist()
|
58 |
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
|
|
|
51 |
print(f'prompt={prompt}')
|
52 |
captions = [prompt]
|
53 |
tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
|
54 |
+
input_ids = tokens.input_ids.cuda(non_blocking=True) if torch.cuda.is_available() else tokens.input_ids
|
55 |
+
mask = tokens.attention_mask.cuda(non_blocking=True) if torch.cuda.is_available() else tokens.attention_mask
|
56 |
text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
|
57 |
lens: List[int] = mask.sum(dim=-1).tolist()
|
58 |
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
|