Flash attention is a performance optimization for attention layers that makes attention computation faster and more memory-efficient, which is especially useful for large language models (LLMs).

Once enabled, supported layers such as layer_multi_head_attention() will attempt to use flash attention for faster computations. Flash attention is enabled by default.

Note that enabling flash attention does not guarantee it will always be used. Typically, the inputs must be in float16 or bfloat16 dtype, and input layout requirements may vary depending on the backend.

Usage

config_enable_flash_attention()
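
Examples

A minimal sketch, assuming the keras3 R package with a GPU-enabled backend; config_set_dtype_policy() and random_normal() are used here only for illustration, and flash attention may still be skipped if the backend or inputs do not meet its requirements:

library(keras3)

# Enable flash attention globally (it is on by default).
config_enable_flash_attention()

# Flash attention typically requires float16/bfloat16 inputs,
# so set a half-precision dtype policy first.
config_set_dtype_policy("mixed_bfloat16")

# Build an attention layer and run it on random inputs;
# supported backends may dispatch to flash attention here.
query <- random_normal(c(2L, 8L, 16L))
value <- random_normal(c(2L, 8L, 16L))

attn <- layer_multi_head_attention(num_heads = 2L, key_dim = 8L)
out <- attn(query = query, value = value)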