JAX has a subsystem called Pallas [1], which offers a Triton-like programming model and includes an example implementation of Flash Attention [2]. It is quite fast. On TPUs, I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention, so there's no need for a specialized kernel in that case.
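To make "flash-attention-like" a bit more concrete, here's a rough sketch in plain jax.numpy. It is not Pallas and not the kernel linked in [2]. It shows a regular attention function that materializes the full score matrix, next to a blocked version that walks over key/value chunks with an online softmax (running max plus running denominator), which is the core arithmetic trick in Flash Attention. The function names and the `block_k` parameter are made up for illustration, and none of this says anything about what XLA actually emits on TPU.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v using ordinary jax.numpy ops.

    This is the "regular" formulation: it builds the full
    [seq_q, seq_k] score matrix before the softmax.
    """
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)             # [seq_q, seq_k]
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v                         # [seq_q, d_v]


def blocked_attention(q, k, v, block_k=16):
    """Flash-attention-style computation (illustrative sketch only).

    Iterates over key/value blocks, keeping a running row max and a running
    softmax denominator so only one [seq_q, block_k] score tile exists per step.
    block_k is a hypothetical tuning parameter; seq_k must be divisible by it.
    """
    seq_q, d = q.shape
    seq_k, d_v = k.shape[0], v.shape[-1]
    assert seq_k % block_k == 0
    num_blocks = seq_k // block_k
    k_blocks = k.reshape(num_blocks, block_k, d)
    v_blocks = v.reshape(num_blocks, block_k, d_v)
    scale = 1.0 / jnp.sqrt(d)

    def step(carry, kv):
        m, l, acc = carry                      # running max, denominator, numerator
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale              # scores for this block only
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)        # rescale earlier partial sums
        p = jnp.exp(s - m_new[:, None])
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_blk
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((seq_q,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((seq_q,), dtype=q.dtype),
        jnp.zeros((seq_q, d_v), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / l[:, None]
```

Both functions should agree up to floating-point rounding, e.g. comparing `reference_attention(q, k, v)` with `jax.jit(blocked_attention)(q, k, v)` via `jnp.allclose`. The sketch only shows the math: a real kernel like the Pallas one also controls how each tile is placed in fast on-chip memory, which plain `jax.lax.scan` doesn't do.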
1. https://jax.readthedocs.io/en/latest/pallas/index.html
2. https://github.com/jax-ml/jax/blob/main/jax/experimental/pal...