Fix use of undeclared identifier 'CHECK_NOSPARSE_CONTIGUOUS_CUDA' with USE='-flash'

The CHECK_* input-validation macros were defined only in
aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h, a header that
is not pulled in when flash attention is disabled, yet attention.cu and
attention_backward.cu rely on them unconditionally. Move the macros into a
shared header, flash_api_common.h, and include it from both translation units
(and from flash_api.h itself) so the build succeeds with USE='-flash'.

Bug: https://github.com/pytorch/pytorch/issues/160826
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -71,6 +71,7 @@
 #include <...>
 #include <...>
+#include <ATen/native/transformers/flash_api_common.h>
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
--- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -33,6 +33,7 @@
 #include <...>
 #endif
+#include <ATen/native/transformers/flash_api_common.h>
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
--- /dev/null
+++ b/aten/src/ATen/native/transformers/flash_api_common.h
@@ -0,0 +1,28 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+#include <c10/util/Exception.h>
+
+#include <cstdint>
+#include <limits>
+
+#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                         \
+  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
+  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
+  TORCH_CHECK(TENSOR.is_contiguous());
+
+#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                     \
+  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
+  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
+  TORCH_CHECK(                                                         \
+      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
+
+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
+  TORCH_CHECK(                            \
+      uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
+
+#define ASSIGN_CHECK_OVERFLOW(A, B)                                    \
+  {                                                                    \
+    A = B;                                                             \
+    TORCH_CHECK(                                                       \
+        B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
+  }
--- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -4,28 +4,7 @@
 #include <...>
 #include <...>
 #include <...>
-
-#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                         \
-  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
-  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
-  TORCH_CHECK(TENSOR.is_contiguous());
-
-#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                     \
-  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
-  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
-  TORCH_CHECK(                                                         \
-      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
-
-#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
-  TORCH_CHECK(                            \
-      uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
-
-#define ASSIGN_CHECK_OVERFLOW(A, B)                                    \
-  {                                                                    \
-    A = B;                                                             \
-    TORCH_CHECK(                                                       \
-        B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
-  }
+#include <ATen/native/transformers/flash_api_common.h>

 namespace pytorch_flash {
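
Note for reviewers: the relocated macros are thin TORCH_CHECK wrappers, so the
move changes only where they are visible, not what they do. A minimal sketch of
a call site once the common header is in scope (the function and variable names
below are hypothetical, not taken from this patch):

    #include <ATen/native/transformers/flash_api_common.h>

    // Validate an attention input before launching a kernel.
    void validate_input(const at::Tensor& query) {
      // Throws via TORCH_CHECK unless query is a dense, contiguous CUDA tensor.
      CHECK_NOSPARSE_CONTIGUOUS_CUDA(query);
      // Narrowing assignment that throws if the source value does not fit
      // in the destination type (int32_t here is illustrative).
      int32_t numel;
      ASSIGN_CHECK_OVERFLOW(numel, query.numel());
    }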