
Overview

TensorRef is a class template for objects that point to the start of a tensor of arbitrary rank and layout in memory. A TensorRef combines a pointer with an object that models the Layout concept to provide structured access to multi-dimensional data.

Header: cutlass/tensor_ref.h
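The "pointer plus layout" idea can be illustrated without CUTLASS itself. The following is a minimal sketch, not the library's implementation: MiniCoord, MiniRowMajor, MiniTensorRef, and read_example are hypothetical stand-ins invented for this illustration, and the row-major mapping is assumed to use a single leading-dimension stride.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a rank-2 coordinate (not cutlass::MatrixCoord).
struct MiniCoord {
  int row;
  int column;
};

// Hypothetical stand-in for a row-major Layout: maps (row, column) to a
// linear offset using a single leading-dimension stride.
struct MiniRowMajor {
  std::int64_t ldm;  // elements between the starts of consecutive rows

  std::int64_t operator()(MiniCoord c) const {
    return std::int64_t(c.row) * ldm + c.column;
  }
};

// Hypothetical stand-in for TensorRef: a pointer plus a layout object.
template <typename Element, typename Layout>
class MiniTensorRef {
 public:
  MiniTensorRef(Element *ptr, Layout layout) : ptr_(ptr), layout_(layout) {}

  // True when the reference points at storage.
  bool good() const { return ptr_ != nullptr; }

  // Linear offset of a logical coordinate from the start of the tensor.
  std::int64_t offset(MiniCoord c) const { return layout_(c); }

  // Reference to the element stored at a logical coordinate.
  Element &at(MiniCoord c) const {
    assert(good());
    return ptr_[offset(c)];
  }

 private:
  Element *ptr_;
  Layout layout_;
};

// Reads element (1, 2) of a 2x3 row-major matrix stored in a flat array.
inline float read_example() {
  static float storage[6] = {0.f, 1.f, 2.f, 10.f, 11.f, 12.f};
  MiniTensorRef<float, MiniRowMajor> ref(storage, MiniRowMajor{3});
  return ref.at(MiniCoord{1, 2});
}
```

The view owns no memory: copying a MiniTensorRef copies only the pointer and the layout, which is why such an object is cheap to pass by value.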

Template Signature

template <
  typename Element_,
  typename Layout_
>
class TensorRef;

Template Parameters

  • Element_ (typename): Data type of the elements stored in the tensor (concept: NumericType). Examples: float, half_t, int8_t
  • Layout_ (typename): Defines the mapping from a logical coordinate to a linear offset in memory (concept: Layout). Examples: layout::RowMajor, layout::ColumnMajor, IdentityTensorLayout<2>

Member Types

using Element = Element_;
using Layout = Layout_;
using Reference = /* Element& or SubbyteReference<Element> */;

static int const kRank = Layout::kRank;

using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = typename Layout::TensorCoord;
using Stride = typename Layout::Stride;

using ConstTensorRef = TensorRef<
  typename platform::remove_const<Element>::type const,
  Layout>;

using NonConstTensorRef = TensorRef<
  typename platform::remove_const<Element>::type,
  Layout>;

Constructors

Default Constructor

CUTLASS_HOST_DEVICE
TensorRef();
Constructs a null TensorRef.

Pointer and Layout Constructor

CUTLASS_HOST_DEVICE
TensorRef(
  Element *ptr,
  Layout const &layout
);
Parameters:
  • ptr: Pointer to start of tensor
  • layout: Layout object containing stride and mapping function

Converting Constructor

template<typename _Magic = int>
CUTLASS_HOST_DEVICE
TensorRef(
  NonConstTensorRef const &ref,
  _Magic magic = /* SFINAE trick */
);
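Converting constructor from a TensorRef to non-constant data. It lets a TensorRef<Element const, Layout> be constructed from a TensorRef<Element, Layout>. The defaulted _Magic parameter is a SFINAE device that keeps this overload from acting as a second copy constructor when Element is already non-const.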
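The SFINAE device behind such a converting constructor can be sketched in isolation. This is an illustration under stated assumptions, not the code in the header: MiniRef and first_element are hypothetical, the layout is omitted, std::enable_if stands in for the platform:: utilities, and the constraint is placed in the parameter type rather than wherever the library places it.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical pointer-only stand-in for TensorRef; the layout is omitted.
template <typename Element>
class MiniRef {
 public:
  using NonConstRef = MiniRef<typename std::remove_const<Element>::type>;

  explicit MiniRef(Element *ptr = nullptr) : ptr_(ptr) {}

  // Converting constructor from a reference to non-const data. The second
  // parameter type is well-formed only when NonConstRef differs from this
  // class, i.e. when Element is const. For non-const Element the overload
  // drops out of overload resolution, leaving only the copy constructor.
  template <typename Magic = int>
  MiniRef(NonConstRef const &ref,
          typename std::enable_if<
              !std::is_same<NonConstRef, MiniRef<Element>>::value,
              Magic>::type = 0)
      : ptr_(ref.data()) {}

  Element *data() const { return ptr_; }

 private:
  Element *ptr_;
};

// Accepts a read-only view; callers may pass a MiniRef<int> directly.
inline int first_element(MiniRef<int const> ref) {
  assert(ref.data() != nullptr);
  return *ref.data();
}
```

With this in place, a MiniRef<int> converts implicitly to a MiniRef<int const>, while the reverse conversion does not compile, mirroring how a pointer to T converts to a pointer to T const but not back.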

Member Functions

const_ref

CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const;
Returns a TensorRef to the same tensor with const-qualified elements.

non_const_ref

CUTLASS_HOST_DEVICE
NonConstTensorRef non_const_ref() const;
Returns a TensorRef to the same tensor with the const qualifier removed from the element type.

reset

CUTLASS_HOST_DEVICE
void reset(Element* ptr = nullptr);

CUTLASS_HOST_DEVICE
void reset(Element* ptr, Layout const &layout);
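Updates the pointer. The first overload keeps the current layout and resets the TensorRef to null when called without arguments; the second overload also replaces the layout.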

good

CUTLASS_HOST_DEVICE
bool good() const;
Returns true if the TensorRef is non-null.

data

CUTLASS_HOST_DEVICE
Element * data() const;

CUTLASS_HOST_DEVICE
Reference data(LongIndex idx) const;
The first overload returns the pointer to the referenced data; the second returns a reference to the element at a given linear index.

layout

CUTLASS_HOST_DEVICE
Layout & layout();

CUTLASS_HOST_DEVICE
Layout layout() const;
Returns the layout object. The non-const overload returns a mutable reference; the const overload returns a copy.

stride

CUTLASS_HOST_DEVICE
Stride stride() const;

CUTLASS_HOST_DEVICE
Stride & stride();

CUTLASS_HOST_DEVICE
typename Layout::Stride::Index stride(int dim) const;

CUTLASS_HOST_DEVICE
typename Layout::Stride::Index & stride(int dim);
Returns the layout object's stride vector, or the stride along a single dimension dim. The non-const overloads return mutable references.

offset

CUTLASS_HOST_DEVICE
LongIndex offset(TensorCoord const& coord) const;
Computes the linear offset of a logical coordinate from the origin of the tensor, as defined by the layout.
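As a worked illustration of what an offset computation amounts to for the two most common matrix layouts, here is a minimal sketch. The free functions row_major_offset and column_major_offset and the local LongIndex alias are hypothetical names for this example only; a single leading dimension ldm is assumed, and the non-negativity precondition is this sketch's own, not the library's.

```cpp
#include <cassert>
#include <cstdint>

// Wide integer type for linear offsets, so that row * ldm cannot overflow
// a 32-bit int for large tensors.
using LongIndex = std::int64_t;

// Row-major: consecutive elements of a row are adjacent in memory.
inline LongIndex row_major_offset(int row, int column, int ldm) {
  assert(row >= 0 && column >= 0 && ldm > 0);  // precondition of this sketch
  return LongIndex(row) * LongIndex(ldm) + column;
}

// Column-major: consecutive elements of a column are adjacent in memory.
inline LongIndex column_major_offset(int row, int column, int ldm) {
  assert(row >= 0 && column >= 0 && ldm > 0);  // precondition of this sketch
  return LongIndex(column) * LongIndex(ldm) + row;
}
```

For coordinate (2, 1) and ldm = 8, the row-major offset is 2 * 8 + 1 = 17 and the column-major offset is 1 * 8 + 2 = 10. The same logical coordinate therefore lands at different addresses depending on the layout, which is the reason the layout is part of the TensorRef type.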

at

CUTLASS_HOST_DEVICE
Reference at(TensorCoord const& coord) const;
Returns a reference to the element at a given coordinate.

operator[]

CUTLASS_HOST_DEVICE
Reference operator[](TensorCoord const& coord) const;
Returns a reference to the element at a given coordinate. Equivalent to at(coord).

add_pointer_offset

CUTLASS_HOST_DEVICE
TensorRef & add_pointer_offset(LongIndex offset_);
Adds an offset, measured in elements, to the pointer and returns *this.

add_coord_offset

CUTLASS_HOST_DEVICE
TensorRef & add_coord_offset(TensorCoord const &coord);
Advances the pointer by the linear offset that the layout computes for coord and returns *this.

Arithmetic Operators

CUTLASS_HOST_DEVICE
TensorRef operator+(TensorCoord const& b) const;

CUTLASS_HOST_DEVICE
TensorRef & operator+=(TensorCoord const& b);

CUTLASS_HOST_DEVICE
TensorRef operator-(TensorCoord const& b) const;

CUTLASS_HOST_DEVICE
TensorRef & operator-=(TensorCoord const& b);
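operator+ and operator- return a new TensorRef whose pointer is shifted by the linear offset of coordinate b; operator+= and operator-= apply the same shift in place and return *this.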
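Coordinate arithmetic on a reference is typically used to address a sub-tile of a larger tensor. The sketch below shows the idea with a hypothetical row-major view. MiniCoord, MiniRowMajorRef, and tile_example are illustrative names, not CUTLASS types, and the view is assumed to hold only a pointer and a leading dimension.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical rank-2 coordinate.
struct MiniCoord {
  int row;
  int column;
};

// Hypothetical row-major view: a pointer plus a leading dimension.
struct MiniRowMajorRef {
  float *ptr;
  std::int64_t ldm;

  std::int64_t offset(MiniCoord c) const {
    return std::int64_t(c.row) * ldm + c.column;
  }

  float &at(MiniCoord c) const {
    assert(ptr != nullptr);
    return ptr[offset(c)];
  }

  // In-place forms: move the origin of the view by coordinate b.
  MiniRowMajorRef &operator+=(MiniCoord b) {
    ptr += offset(b);
    return *this;
  }
  MiniRowMajorRef &operator-=(MiniCoord b) {
    ptr -= offset(b);
    return *this;
  }

  // Value-returning forms: leave *this untouched, return a shifted copy.
  MiniRowMajorRef operator+(MiniCoord b) const {
    MiniRowMajorRef result(*this);
    result += b;
    return result;
  }
  MiniRowMajorRef operator-(MiniCoord b) const {
    MiniRowMajorRef result(*this);
    result -= b;
    return result;
  }
};

// Reads element (1, 1) of the tile whose origin is (1, 2) in a 3x4 matrix.
inline float tile_example() {
  static float m[12] = {0.f,  1.f,  2.f,  3.f,
                        10.f, 11.f, 12.f, 13.f,
                        20.f, 21.f, 22.f, 23.f};
  MiniRowMajorRef whole{m, 4};
  MiniRowMajorRef tile = whole + MiniCoord{1, 2};
  return tile.at(MiniCoord{1, 1});  // same element as whole.at({2, 3})
}
```

Because the layout is linear in the coordinate, shifting the origin by b and then indexing with c reaches the same element as indexing the original view with b + c. The shifted view keeps the original leading dimension, so it still strides over full rows of the parent matrix.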

Helper Functions

make_TensorRef

template <typename Element, typename Layout>
CUTLASS_HOST_DEVICE
TensorRef<Element, Layout> make_TensorRef(Element *ptr, Layout const &layout);
Constructs a TensorRef, deducing types from arguments.

TensorRef_aligned

template <typename Element, typename Layout>
CUTLASS_HOST_DEVICE
bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment);
Checks if a TensorRef’s pointer and strides are aligned to the specified alignment.

Usage Examples

From include/cutlass/tensor_ref.h:120-144:

Column-major matrix

TensorRef<float, layout::ColumnMajor> A(ptr_A, ldm);

Row-major matrix

TensorRef<float, layout::RowMajor> B(ptr_B, ldm);

Interleaved matrix

TensorRef<int8_t, layout::ColumnMajorInterleaved<32> > C;

Contiguous matrix with runtime layout

int ldm;                     // leading dimension
layout::Matrix kind;         // layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor

TensorRef<int, layout::ContiguousMatrix> E(ptr_E, {ldm, kind});

Accessing elements

// Create a 2D tensor reference
TensorRef<float, layout::RowMajor> tensor(ptr, ldm);

// Access element at (row, col)
float value = tensor.at(MatrixCoord(row, col));

// Equivalent using operator[]
float same_value = tensor[MatrixCoord(row, col)];

// Get linear offset (LongIndex is a member type of the layout)
layout::RowMajor::LongIndex offset = tensor.offset(MatrixCoord(row, col));

See Also

  • Layout Types - Layout functions for various matrix formats
  • Gemm - GEMM operator that uses TensorRef for matrix operands
  • Array - Statically-sized array container
