Overview
TensorRef is a class template for objects that point to the start of a tensor of arbitrary rank and layout in memory. A TensorRef pairs a pointer with a Layout object to provide structured access to multi-dimensional data.
Header: cutlass/tensor_ref.h
Template Signature
template <
typename Element_,
typename Layout_
>
class TensorRef;
Template Parameters
Element_: Data type of the elements stored in the tensor (concept: NumericType). Examples: float, half_t, int8_t
Layout_: Defines the mapping from a logical coordinate to a linear offset in memory (concept: Layout). Examples: layout::RowMajor, layout::ColumnMajor, IdentityTensorLayout<2>
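To make the Layout concept more concrete, the following is a minimal, self-contained sketch of a row-major mapping from a (row, column) coordinate to a linear offset. SketchCoord and SketchRowMajor are hypothetical stand-ins written for illustration only; they are not the CUTLASS types, which carry additional members.

```cpp
#include <cstdint>

// Hypothetical stand-in for a rank-2 logical coordinate
// (not cutlass::MatrixCoord).
struct SketchCoord {
  int row;
  int column;
};

// Hypothetical stand-in for a row-major layout: it stores a single stride
// (the leading dimension) and maps a coordinate to a linear offset.
struct SketchRowMajor {
  static int const kRank = 2;
  using Index = int;
  using LongIndex = std::int64_t;
  using TensorCoord = SketchCoord;

  Index ldm;  // elements between the starts of two consecutive rows

  // Linear offset, in elements, of `coord` from the tensor's origin.
  LongIndex operator()(TensorCoord const &coord) const {
    return LongIndex(coord.row) * ldm + coord.column;
  }
};
```

With a leading dimension of 8, this sketch maps the coordinate (2, 3) to offset 2 * 8 + 3 = 19.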
Member Types
using Element = Element_;
using Layout = Layout_;
using Reference = /* Element& or SubbyteReference<Element> */;
static int const kRank = Layout::kRank;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = typename Layout::TensorCoord;
using Stride = typename Layout::Stride;
using ConstTensorRef = TensorRef<
typename platform::remove_const<Element>::type const,
Layout>;
using NonConstTensorRef = TensorRef<
typename platform::remove_const<Element>::type,
Layout>;
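The Reference type is either a plain Element& or a proxy object, because elements narrower than one byte cannot be addressed by a normal C++ reference. The sketch below illustrates that idea with a hypothetical NibbleProxy for 4-bit unsigned values packed two per byte. It is not cutlass::SubbyteReference, and the packing order (low nibble first) is an assumption of this sketch.

```cpp
#include <cstdint>

// Hypothetical proxy for a 4-bit unsigned element packed two per byte.
class NibbleProxy {
 public:
  NibbleProxy(std::uint8_t *byte, int index_in_byte)
      : byte_(byte), shift_(index_in_byte * 4) {}

  // Read: extract the 4-bit field.
  operator std::uint8_t() const {
    return std::uint8_t((*byte_ >> shift_) & 0xF);
  }

  // Write: replace only the 4-bit field, leaving its neighbour intact.
  NibbleProxy &operator=(std::uint8_t value) {
    *byte_ = std::uint8_t((*byte_ & ~(0xF << shift_)) |
                          ((value & 0xF) << shift_));
    return *this;
  }

 private:
  std::uint8_t *byte_;  // byte that holds the element
  int shift_;           // bit position of the element inside that byte
};

// Returns a proxy for the idx-th 4-bit element of a packed buffer
// (assumed order: element 0 in the low nibble of byte 0).
inline NibbleProxy nibble_at(std::uint8_t *storage, std::int64_t idx) {
  return NibbleProxy(storage + idx / 2, int(idx % 2));
}
```

Code that assigns through the proxy, such as nibble_at(storage, 1) = 0xA, reads like an assignment through an ordinary reference, which is why a single Reference alias can cover both cases.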
Constructors
Default Constructor
CUTLASS_HOST_DEVICE
TensorRef();
Constructs a null TensorRef.
Pointer and Layout Constructor
CUTLASS_HOST_DEVICE
TensorRef(
Element *ptr,
Layout const &layout
);
Parameters:
ptr: Pointer to start of tensor
layout: Layout object containing stride and mapping function
Converting Constructor
template<typename _Magic = int>
CUTLASS_HOST_DEVICE
TensorRef(
NonConstTensorRef const &ref,
_Magic magic = /* SFINAE trick */
);
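Converting constructor from a TensorRef that refers to non-constant data. It allows a TensorRef over const elements to be constructed from its non-const counterpart (NonConstTensorRef).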
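The general pattern behind a conditionally enabled converting constructor can be sketched as follows. This sketch uses std::enable_if from the standard library rather than the _Magic parameter shown in the signature, and SketchRef is a hypothetical, simplified pointer wrapper, so treat it as an illustration of the idea and not as the CUTLASS implementation.

```cpp
#include <type_traits>

// Hypothetical, simplified pointer wrapper used only for illustration.
template <typename Element>
class SketchRef {
 public:
  using NonConstRef = SketchRef<typename std::remove_const<Element>::type>;

  SketchRef() : ptr_(nullptr) {}
  explicit SketchRef(Element *ptr) : ptr_(ptr) {}

  // Enabled only when Element is const, so SketchRef<T const> can be built
  // from SketchRef<T>. For non-const Element this overload is removed by
  // SFINAE and cannot clash with the copy constructor.
  template <typename E = Element,
            typename = typename std::enable_if<std::is_const<E>::value>::type>
  SketchRef(NonConstRef const &ref) : ptr_(ref.data()) {}

  Element *data() const { return ptr_; }

 private:
  Element *ptr_;
};
```

In this sketch the conversion only adds const: SketchRef<float> converts implicitly to SketchRef<float const>, while the reverse direction does not compile.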
Member Functions
const_ref
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const;
Returns a TensorRef to the same data and layout whose elements are const.
non_const_ref
CUTLASS_HOST_DEVICE
NonConstTensorRef non_const_ref() const;
Returns a TensorRef to the same data and layout with the const qualifier removed from the element type.
reset
CUTLASS_HOST_DEVICE
void reset(Element* ptr = nullptr);
CUTLASS_HOST_DEVICE
void reset(Element* ptr, Layout const &layout);
Updates the pointer, optionally with a new layout.
good
CUTLASS_HOST_DEVICE
bool good() const;
Returns true if the TensorRef is non-null.
data
CUTLASS_HOST_DEVICE
Element * data() const;
CUTLASS_HOST_DEVICE
Reference data(LongIndex idx) const;
Returns the pointer to referenced data, or a reference to the element at a given linear index.
layout
CUTLASS_HOST_DEVICE
Layout & layout();
CUTLASS_HOST_DEVICE
Layout layout() const;
Returns the layout object.
stride
CUTLASS_HOST_DEVICE
Stride stride() const;
CUTLASS_HOST_DEVICE
Stride & stride();
CUTLASS_HOST_DEVICE
typename Layout::Stride::Index stride(int dim) const;
CUTLASS_HOST_DEVICE
typename Layout::Stride::Index & stride(int dim);
Returns the layout object’s stride vector or a specific stride dimension.
offset
CUTLASS_HOST_DEVICE
LongIndex offset(TensorCoord const& coord) const;
Computes the linear offset of a logical coordinate from the origin of the tensor.
at
CUTLASS_HOST_DEVICE
Reference at(TensorCoord const& coord) const;
Returns a reference to the element at a given coordinate.
operator[]
CUTLASS_HOST_DEVICE
Reference operator[](TensorCoord const& coord) const;
Returns a reference to the element at a given coordinate.
add_pointer_offset
CUTLASS_HOST_DEVICE
TensorRef & add_pointer_offset(LongIndex offset_);
Adds a linear offset, in units of elements, to the pointer and returns *this.
add_coord_offset
CUTLASS_HOST_DEVICE
TensorRef & add_coord_offset(TensorCoord const &coord);
Advances the pointer by the linear offset that the layout computes for coord and returns *this.
Arithmetic Operators
CUTLASS_HOST_DEVICE
TensorRef operator+(TensorCoord const& b) const;
CUTLASS_HOST_DEVICE
TensorRef & operator+=(TensorCoord const& b);
CUTLASS_HOST_DEVICE
TensorRef operator-(TensorCoord const& b) const;
CUTLASS_HOST_DEVICE
TensorRef & operator-=(TensorCoord const& b);
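operator+ and operator- return a new TensorRef whose pointer is offset by the coordinate b; operator+= and operator-= apply the same offset to this TensorRef in place.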
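The difference between the in-place and the value-returning forms can be sketched with a small stand-alone type. DemoCoord and DemoRowMajorRef are hypothetical names invented for this illustration; the sketch assumes a row-major layout with a single leading dimension and does not reproduce the CUTLASS class.

```cpp
#include <cstdint>

// Hypothetical rank-2 coordinate.
struct DemoCoord {
  int row;
  int column;
};

// Hypothetical row-major reference: a pointer plus a leading dimension.
struct DemoRowMajorRef {
  float *ptr;
  int ldm;

  // Linear offset of `c`, in elements, from the current origin.
  std::int64_t offset(DemoCoord const &c) const {
    return std::int64_t(c.row) * ldm + c.column;
  }

  // In-place form: advance the pointer by the offset of `c`.
  DemoRowMajorRef &add_coord_offset(DemoCoord const &c) {
    ptr += offset(c);
    return *this;
  }

  // Value-returning form: leave *this untouched, return a shifted copy.
  DemoRowMajorRef operator+(DemoCoord const &c) const {
    DemoRowMajorRef result(*this);
    result.add_coord_offset(c);
    return result;
  }

  // Shift a copy back toward the origin by the offset of `c`.
  DemoRowMajorRef operator-(DemoCoord const &c) const {
    DemoRowMajorRef result(*this);
    result.ptr -= offset(c);
    return result;
  }

  float &at(DemoCoord const &c) const { return ptr[offset(c)]; }
};
```

A typical use of this pattern is to derive a reference to a sub-tile: ref + DemoCoord{1, 2} yields a reference whose (0, 0) element is element (1, 2) of the original, while ref itself is unchanged.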
Helper Functions
make_TensorRef
template <typename Element, typename Layout>
CUTLASS_HOST_DEVICE
TensorRef<Element, Layout> make_TensorRef(Element *ptr, Layout const &layout);
Constructs a TensorRef, deducing types from arguments.
TensorRef_aligned
template <typename Element, typename Layout>
CUTLASS_HOST_DEVICE
bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment);
Checks if a TensorRef’s pointer and strides are aligned to the specified alignment.
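A check of this kind can be sketched without any CUTLASS types. The helper below, sketch_aligned, is a hypothetical stand-in that takes a raw pointer and an array of strides; it assumes the pointer is tested as a byte address and each stride as a plain integer multiple of the alignment.

```cpp
#include <cstdint>

// Hypothetical helper illustrating an alignment check: both the base
// address and every stride must be multiples of `alignment`.
template <typename Element>
bool sketch_aligned(Element const *ptr, int const *strides, int stride_rank,
                    int alignment) {
  // The base address must be a multiple of the alignment.
  if (reinterpret_cast<std::uintptr_t>(ptr) % alignment != 0) {
    return false;
  }
  // Every stride must also be a multiple of the alignment.
  for (int i = 0; i < stride_rank; ++i) {
    if (strides[i] % alignment != 0) {
      return false;
    }
  }
  return true;
}
```

Checking the strides as well as the pointer matters because an aligned base address with a misaligned leading dimension still produces misaligned rows or columns further into the tensor.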
Usage Examples
From include/cutlass/tensor_ref.h:120-144:
Column-major matrix
TensorRef<float, layout::ColumnMajor> A(ptr_A, ldm);
Row-major matrix
TensorRef<float, layout::RowMajor> B(ptr_B, ldm);
Interleaved matrix
TensorRef<int8_t, layout::ColumnMajorInterleaved<32> > C;
Contiguous matrix with runtime layout
int ldm; // leading dimension
layout::Matrix kind; // layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor
TensorRef<int, layout::ContiguousMatrix> E(ptr_E, {ldm, kind});
Accessing elements
// Create a 2D tensor reference
TensorRef<float, layout::RowMajor> tensor(ptr, ldm);
// Access element at (row, col)
float value = tensor.at(MatrixCoord(row, col));
// Equivalent access using operator[]
float same_value = tensor[MatrixCoord(row, col)];
// Get the linear offset; LongIndex is a member type of TensorRef
TensorRef<float, layout::RowMajor>::LongIndex offset = tensor.offset(MatrixCoord(row, col));
See Also
- Layout Types - Layout functions for various matrix formats
- Gemm - GEMM operator that uses TensorRef for matrix operands
- Array - Statically-sized array container