RuntimeError: FlashAttention is not installed.

#47
by seregadgl - opened

Hi, can you tell me how to disable flash_attn?
model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    device=device,
    trust_remote_code=True,
    model_kwargs={"default_task": "text-matching"},
)
................
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)

trainer.train()
RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.

Sentence Transformers v3.2

Jina AI org

Hi @seregadgl, you need to have FlashAttention installed if you want to train the model; you can only disable it during inference.
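A minimal sketch of that rule, assuming you decide the flag yourself before loading the model. flash_attn_rotary_available and resolve_use_flash_attn are hypothetical helpers, not part of Sentence Transformers or the Jina model code; the availability check simply tries the same import (flash_attn.ops.triton.rotary) that the model's rotary.py depends on.

```python
import importlib
from typing import Optional


def flash_attn_rotary_available() -> bool:
    """Return True if the Triton rotary kernel the model relies on can be imported."""
    try:
        # Same import the model's rotary.py attempts; it needs both flash_attn and triton.
        importlib.import_module("flash_attn.ops.triton.rotary")
    except ImportError:
        return False
    return True


def resolve_use_flash_attn(training: bool, available: Optional[bool] = None) -> bool:
    """Hypothetical helper: choose a value for the model's use_flash_attn flag.

    Training has no fallback, so fail early with a clear message.
    Inference can run with use_flash_attn=False when the kernels are missing.
    """
    if available is None:
        available = flash_attn_rotary_available()
    if available:
        return True
    if training:
        raise RuntimeError(
            "Training needs FlashAttention (and triton); install them before calling trainer.train()."
        )
    return False


if __name__ == "__main__":
    print("use_flash_attn for inference:", resolve_use_flash_attn(training=False))
```

For inference, the returned value is what the error message asks you to pass as use_flash_attn when loading the model. This thread does not show how that flag is forwarded through SentenceTransformer to the custom model code, so treat any particular keyword argument as something to verify against the model card.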

Thanks for the answer. Could you tell me which version of FlashAttention to install so that I can fine-tune the model in Google Colab on a T4 GPU? Thanks!

It seems you also need to install another dependency, namely triton.
If you look at the rotary.py file, you can see that the RuntimeError: FlashAttention is not installed exception is raised whenever from flash_attn.ops.triton.rotary import apply_rotary fails.
That import requires both FlashAttention and triton.
So I think you should also install triton by running pip install triton.
Screenshot 2024-10-20 at 5.34.59 PM.png
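Because a failed import is replaced by a stub apply_rotary that raises this generic RuntimeError, the original ImportError (for example "No module named 'triton'") is hidden. A small standard-library sketch that surfaces it; probe_imports is a hypothetical helper, and the module list is an assumption based on the import named above.

```python
import importlib
from typing import Dict, Optional, Tuple

# Modules the model's rotary path needs, per the discussion above (assumed list).
DEFAULT_MODULES: Tuple[str, ...] = (
    "triton",
    "flash_attn",
    "flash_attn.ops.triton.rotary",
)


def probe_imports(modules: Tuple[str, ...] = DEFAULT_MODULES) -> Dict[str, Optional[str]]:
    """Import each module; map its name to None on success or to the ImportError text."""
    results: Dict[str, Optional[str]] = {}
    for name in modules:
        try:
            importlib.import_module(name)
        except ImportError as exc:
            # Keep the real reason, e.g. "No module named 'triton'".
            results[name] = str(exc)
        else:
            results[name] = None
    return results


if __name__ == "__main__":
    for name, error in probe_imports().items():
        print(f"{name}: {'OK' if error is None else error}")
```

Checking the three names separately tells you whether flash_attn itself, triton, or only the combined kernel import is the problem.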

Jina AI org

@seregadgl you can install any recent version; the latest one (2.6.3) should work fine.

@BlackBeenie you're right, it requires triton as well. However, triton should be installed automatically when you install torch with CUDA enabled.

@jupyterjazz It seems that triton is not installed automatically in Google Colab: I ran into a similar error, and running pip install triton fixed it.

Jina AI org

@BlackBeenie, makes sense. This happens because Colab comes with torch preinstalled. If you uninstall it and reinstall it while connected to a GPU runtime, triton should be installed as well.
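One way to confirm that a reinstall actually pulled in triton is to read the installed package metadata, which avoids importing torch or touching CUDA. A standard-library sketch; installed_versions is a hypothetical helper, and torch, triton, and flash-attn are the PyPI distribution names assumed here.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("torch", "triton", "flash-attn"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

A None next to triton after the reinstall would point back to the missing-dependency case discussed above.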

@jupyterjazz Just tested it, and it works. Thanks :)

bwang0911 changed discussion status to closed

I am using Windows 11 and have successfully installed flash-attn, as shown in the picture below, but I still get the RuntimeError: FlashAttention is not installed error. Does this mean FlashAttention is not supported on Windows?

image.png

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 2
      1 print(len(chunks))
----> 2 chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
      3 # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
      4 top_k = min(3, len(chunks))

File d:\conda\win\envs\prod\Lib\site-packages\sentence_transformers\SentenceTransformer.py:623, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
    620 features.update(extra_features)
    622 with torch.no_grad():
--> 623     out_features = self.forward(features, **kwargs)
    624     if self.device.type == "hpu":
    625         out_features = copy.deepcopy(out_features)

File d:\conda\win\envs\prod\Lib\site-packages\sentence_transformers\SentenceTransformer.py:690, in SentenceTransformer.forward(self, input, **kwargs)
    688     module_kwarg_keys = self.module_kwargs.get(module_name, [])
    689     module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 690     input = module(input, **module_kwargs)
    691 return input

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\jina-embeddings-v3\30996fea06f69ecd8382ee4f11e29acaf6b5405e\custom_st.py:143, in Transformer.forward(self, features, task)
    139 lora_arguments = (
    140     {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
    141 )
    142 features.pop('prompt_length', None)
--> 143 output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
    144 output_tokens = output_states[0]
    145     features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_lora.py:370, in XLMRobertaLoRA.forward(self, *args, **kwargs)
    369 def forward(self, *args, **kwargs):
--> 370     return self.roberta(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_xlm_roberta.py:709, in XLMRobertaModel.forward(self, input_ids, position_ids, token_type_ids, attention_mask, masked_tokens_mask, return_dict, **kwargs)
    706 else:
    707     subset_mask = None
--> 709 sequence_output = self.encoder(
    710     hidden_states,
    711     key_padding_mask=attention_mask,
    712     subset_mask=subset_mask,
    713     adapter_mask=adapter_mask,
    714 )
    716 if masked_tokens_mask is None:
    717     pooled_output = (
    718         self.pooler(sequence_output, adapter_mask=adapter_mask)
    719         if self.pooler is not None
    720         else None
    721     )

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_xlm_roberta.py:241, in XLMRobertaEncoder.forward(self, hidden_states, key_padding_mask, subset_mask, adapter_mask)
    234             hidden_states = torch.utils.checkpoint.checkpoint(
    235                 layer,
    236                 hidden_states,
    237                 use_reentrant=self.use_reentrant,
    238                 mixer_kwargs=mixer_kwargs,
    239             )
    240         else:
--> 241             hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
    242     hidden_states = pad_input(hidden_states, indices, batch, seqlen)
    243 else:

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\block.py:201, in Block.forward(self, hidden_states, residual, mixer_subset, mixer_kwargs)
    199 else:
    200     assert residual is None
--> 201     mixer_out = self.mixer(
    202         hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
    203     )
    204     if self.return_residual:  # mixer out is actually a pair here
    205         mixer_out, hidden_states = mixer_out

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\mha.py:732, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, adapter_mask, **kwargs)
    725 if (
    726     inference_params is None
    727     or inference_params.seqlen_offset == 0
    728     or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
    729     or not self.use_flash_attn
    730 ):
    731     if self.rotary_emb_dim > 0:
--> 732         qkv = self.rotary_emb(
    733             qkv,
    734             seqlen_offset=seqlen_offset,
    735             cu_seqlens=cu_seqlens,
    736             max_seqlen=rotary_max_seqlen,
    737         )
    738     if inference_params is None:
    739         if not self.checkpointing:

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:604, in RotaryEmbedding.forward(self, qkv, kv, seqlen_offset, cu_seqlens, max_seqlen)
    602 if kv is None:
    603     if self.scale is None:
--> 604         return apply_rotary_emb_qkv_(
    605             qkv,
    606             self._cos_cached,
    607             self._sin_cached,
    608             interleaved=self.interleaved,
    609             seqlen_offsets=seqlen_offset,
    610             cu_seqlens=cu_seqlens,
    611             max_seqlen=max_seqlen,
    612             use_flash_attn=self.use_flash_attn,
    613         )
    614     else:
    615         return apply_rotary_emb_qkv_(
    616             qkv,
    617             self._cos_cached,
   (...)
    625             use_flash_attn=self.use_flash_attn,
    626         )

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:327, in apply_rotary_emb_qkv_(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
    297 def apply_rotary_emb_qkv_(
    298     qkv,
    299     cos,
   (...)
    307     use_flash_attn=True,
    308 ):
    309     """
    310     Arguments:
    311         qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
   (...)
    325     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    326     """
--> 327     return ApplyRotaryEmbQKV_.apply(
    328         qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
    329     )

File d:\conda\win\envs\prod\Lib\site-packages\torch\autograd\function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:186, in ApplyRotaryEmbQKV_.forward(ctx, qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
    184     qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
    185     # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
--> 186     apply_rotary(
    187         qk,
    188         cos,
    189         sin,
    190         seqlen_offsets=seqlen_offsets,
    191         interleaved=interleaved,
    192         inplace=True,
    193         cu_seqlens=cu_seqlens,
    194         max_seqlen=max_seqlen,
    195     )
    196 else:
    197     q_rot = apply_rotary_emb_torch(
    198         qkv[:, :, 0],
    199         cos,
    200         sin,
    201         interleaved=interleaved,
    202     )

File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:18, in apply_rotary(*args, **kwargs)
     17 def apply_rotary(*args, **kwargs):
---> 18     raise RuntimeError(
     19         "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
     20         "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
     21     )

RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.

I also had this problem. Thanks to @BlackBeenie, I found out that I was missing the triton package, which doesn't support Python 3.13 yet.
https://pypi.org/project/triton/
https://huggingface.co/jinaai/jina-embeddings-v3/discussions/47#6714c101d0aceb08357afc2a

So I switched to Python 3.12, but ended up getting another error:

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
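The "(cpu tensor?)" hint in that message suggests that a tensor living on the CPU reached a Triton kernel, which expects GPU tensors; one way this can happen is loading the model without a CUDA device while the FlashAttention path is still enabled. Assuming that is the cause here, a sketch of choosing the device and the use_flash_attn value together; plan_device is a hypothetical helper.

```python
from typing import Tuple


def plan_device(requested: str, cuda_available: bool) -> Tuple[str, bool]:
    """Hypothetical helper: return (device, use_flash_attn) for loading the model.

    The Triton kernels behind use_flash_attn=True expect GPU tensors, so any CPU
    device (requested, or forced because CUDA is unavailable) disables them.
    """
    kind = requested.split(":", 1)[0]
    if kind == "cuda" and not cuda_available:
        # Fall back to CPU; keeping the Triton path here would hit "cpu tensor?".
        return "cpu", False
    return requested, kind == "cuda"


if __name__ == "__main__":
    print(plan_device("cuda", cuda_available=False))
```

In practice cuda_available would come from torch.cuda.is_available(), and the returned device is what you would pass as device= when loading the model. Note that the CPU branch only helps for inference; training still needs FlashAttention on a GPU.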
