Commit
·
aeb99cb
1
Parent(s):
8f83a35
refactor: prompts
Browse filesSigned-off-by: Meow <[email protected]>
- modeling_lora.py +4 -2
modeling_lora.py
CHANGED
|
@@ -165,7 +165,6 @@ class LoRAParametrization(nn.Module):
|
|
| 165 |
):
|
| 166 |
"""
|
| 167 |
Registering LoRA adapters to all embedding and linear layers.
|
| 168 |
-
|
| 169 |
Additionally, we implement a custom forward function for LoRA parametrization.
|
| 170 |
This function modifies the layer's forward pass to optionally use task-specific
|
| 171 |
parameters. When a `task_id` is provided, it employs a LoRA parametrization
|
|
@@ -373,7 +372,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 373 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 374 |
"""
|
| 375 |
Computes sentence embeddings.
|
| 376 |
-
|
| 377 |
sentences(`str` or `List[str]`):
|
| 378 |
Sentence or sentences to be encoded
|
| 379 |
task_type(`str`, *optional*, defaults to `None`):
|
|
@@ -394,6 +392,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
| 394 |
adapter_mask = torch.full(
|
| 395 |
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
| 396 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
return self.roberta.encode(
|
| 398 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 399 |
)
|
|
|
|
| 165 |
):
|
| 166 |
"""
|
| 167 |
Registering LoRA adapters to all embedding and linear layers.
|
|
|
|
| 168 |
Additionally, we implement a custom forward function for LoRA parametrization.
|
| 169 |
This function modifies the layer's forward pass to optionally use task-specific
|
| 170 |
parameters. When a `task_id` is provided, it employs a LoRA parametrization
|
|
|
|
| 372 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 373 |
"""
|
| 374 |
Computes sentence embeddings.
|
|
|
|
| 375 |
sentences(`str` or `List[str]`):
|
| 376 |
Sentence or sentences to be encoded
|
| 377 |
task_type(`str`, *optional*, defaults to `None`):
|
|
|
|
| 392 |
adapter_mask = torch.full(
|
| 393 |
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
| 394 |
)
|
| 395 |
+
if isinstance(sentences, str):
|
| 396 |
+
sentences = self._task_instructions[task_type] + sentences
|
| 397 |
+
else:
|
| 398 |
+
sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
|
| 399 |
return self.roberta.encode(
|
| 400 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
| 401 |
)
|