Jackmin108 commited on
Commit
7af97e7
·
1 Parent(s): 2646361

feat: add lora instructions for retrieval

Browse files

Signed-off-by: Meow <[email protected]>

Files changed (2) hide show
  1. configuration_xlm_roberta.py +2 -2
  2. modeling_lora.py +8 -6
configuration_xlm_roberta.py CHANGED
@@ -27,7 +27,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
27
  use_cache: bool = True,
28
  classifier_dropout: Optional[float] = None,
29
  lora_adaptations: Optional[List[str]] = None,
30
- lora_prompts: Optional[Dict[str, str]] = None,
31
  lora_rank: int = 4,
32
  lora_dropout_p: float = 0.0,
33
  lora_alpha: int = 1,
@@ -103,7 +103,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
103
  self.classifier_dropout = classifier_dropout
104
  self.load_trained_adapters = load_trained_adapters
105
  self.lora_adaptations = lora_adaptations
106
- self.lora_prompts = lora_prompts
107
  self.lora_rank = lora_rank
108
  self.lora_dropout_p = lora_dropout_p
109
  self.lora_alpha = lora_alpha
 
27
  use_cache: bool = True,
28
  classifier_dropout: Optional[float] = None,
29
  lora_adaptations: Optional[List[str]] = None,
30
+ task_instructions: Optional[Dict[str, str]] = None,
31
  lora_rank: int = 4,
32
  lora_dropout_p: float = 0.0,
33
  lora_alpha: int = 1,
 
103
  self.classifier_dropout = classifier_dropout
104
  self.load_trained_adapters = load_trained_adapters
105
  self.lora_adaptations = lora_adaptations
106
+ self.task_instructions = task_instructions
107
  self.lora_rank = lora_rank
108
  self.lora_dropout_p = lora_dropout_p
109
  self.lora_alpha = lora_alpha
modeling_lora.py CHANGED
@@ -258,15 +258,15 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
258
  raise ValueError(
259
  f"`lora_adaptations` must be a list and contain at least one element"
260
  )
261
- self._lora_prompts = config.lora_prompts
262
  if (
263
- not isinstance(self._lora_prompts, dict)
264
- or len(self._lora_prompts) != len(self._lora_adaptations)
265
- or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
266
  ):
267
  raise ValueError(
268
- f"`lora_prompts` must be a dict and contain the same number of elements "
269
- f"as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`."
270
  )
271
  self._adaptation_map = {
272
  name: idx for idx, name in enumerate(self._lora_adaptations)
@@ -393,6 +393,8 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
393
  adapter_mask = torch.full(
394
  (num_examples,), task_id, dtype=torch.int32, device=self.device
395
  )
 
 
396
  return self.roberta.encode(
397
  sentences, *args, adapter_mask=adapter_mask, **kwargs
398
  )
 
258
  raise ValueError(
259
  f"`lora_adaptations` must be a list and contain at least one element"
260
  )
261
+ self._task_instructions = config.task_instructions
262
  if (
263
+ not isinstance(self._task_instructions, dict)
264
+ or len(self._task_instructions) != len(self._lora_adaptations)
265
+ or not all([v in self._lora_adaptations for v in self._task_instructions.keys()])
266
  ):
267
  raise ValueError(
268
+ f"`task_instructions` must be a dict and contain the same number of elements "
269
+ f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
270
  )
271
  self._adaptation_map = {
272
  name: idx for idx, name in enumerate(self._lora_adaptations)
 
393
  adapter_mask = torch.full(
394
  (num_examples,), task_id, dtype=torch.int32, device=self.device
395
  )
396
+ if task_type in ['query', 'passage']:
397
+ sentences = [self._task_instructions[task_type] + ' ' + sentence for sentence in sentences]
398
  return self.roberta.encode(
399
  sentences, *args, adapter_mask=adapter_mask, **kwargs
400
  )