Linear

Applies a linear transformation to the incoming data: y = xW^T + b. Also known as a fully connected or dense layer.

Constructor

class Linear extends Module

constructor(
  inFeatures: number,
  outFeatures: number,
  options?: {
    bias?: boolean;
    dtype?: "float32" | "float64";
    device?: "cpu" | "webgpu" | "wasm";
  }
)
Parameters:
  • inFeatures - Size of each input sample
  • outFeatures - Size of each output sample
  • options.bias - If true, adds learnable bias (default: true)
  • options.dtype - Data type for weights (default: 'float32')
  • options.device - Device to place tensors on (default: 'cpu')
Throws:
  • InvalidParameterError - If dimensions are invalid

Mathematical Formulation

y = x * W^T + b
Where:
  • x is the input tensor of shape (*, in_features)
  • W is the weight matrix of shape (out_features, in_features)
  • b is the bias vector of shape (out_features,)
  • y is the output tensor of shape (*, out_features)

Shape Conventions

Input: (*, in_features) where * means any number of leading dimensions
  • 1D: (in_features) → Output: (out_features)
  • 2D: (batch, in_features) → Output: (batch, out_features)
  • 3D: (batch, seq_len, in_features) → Output: (batch, seq_len, out_features)
Output: (*, out_features) - all leading dimensions are preserved
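As a concrete illustration of these shape rules, here is a minimal, framework-free sketch of the 2D case in plain TypeScript. The `linearForward` helper below is hypothetical (not part of deepbox); it only shows how the batch dimension is preserved while the feature dimension changes:

```typescript
// Sketch of y = x * W^T + b for a 2D input (batch, in_features).
// W has shape (out_features, in_features); b has shape (out_features,).
function linearForward(
  x: number[][], // (batch, inFeatures)
  W: number[][], // (outFeatures, inFeatures)
  b: number[]    // (outFeatures,)
): number[][] {
  return x.map(row =>
    W.map((wRow, o) =>
      wRow.reduce((sum, w, i) => sum + w * row[i], b[o])
    )
  );
}

// (batch=2, in=3) input with a (out=2, in=3) weight matrix:
const y = linearForward(
  [[1, 2, 3], [4, 5, 6]],
  [[1, 0, 0], [0, 1, 0]], // picks out features 0 and 1
  [10, 20]
);
// y is [[11, 22], [14, 25]] — shape (batch=2, out=2)
```

Note that the batch dimension (2) passes through untouched; only the trailing feature dimension changes from 3 to 2.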

Attributes

  • weight - Learnable weights of shape (out_features, in_features)
  • bias - Learnable bias of shape (out_features,) if bias=true

Initialization

Weights are initialized using Kaiming/He initialization:
  • weights ~ N(0, 2/in_features), i.e. drawn from a normal distribution with standard deviation sqrt(2/in_features)
  • Biases are initialized to zeros
This initialization is well suited to ReLU activations and helps mitigate vanishing/exploding gradients.
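The scheme can be sketched in a few lines of plain TypeScript. The `kaimingInit` helper below is hypothetical and not the deepbox internals; it just makes the std = sqrt(2/in_features) rule and the zero bias explicit:

```typescript
// He/Kaiming-normal initialization sketch: each weight is drawn from a
// normal distribution with mean 0 and std = sqrt(2 / inFeatures).
function kaimingInit(outFeatures: number, inFeatures: number) {
  const std = Math.sqrt(2 / inFeatures);
  // Box-Muller transform for standard normal samples.
  const randn = () =>
    Math.sqrt(-2 * Math.log(1 - Math.random())) *
    Math.cos(2 * Math.PI * Math.random());
  const weight = Array.from({ length: outFeatures }, () =>
    Array.from({ length: inFeatures }, () => randn() * std)
  );
  const bias = new Array<number>(outFeatures).fill(0); // zeros
  return { weight, bias };
}

const { weight, bias } = kaimingInit(30, 20);
// weight: 30 rows of 20 samples with std ≈ sqrt(2/20) ≈ 0.316
// bias: 30 zeros
```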

Properties

  • inputSize: number - Number of input features
  • outputSize: number - Number of output features

Methods

forward

forward(input: Tensor): Tensor
forward(input: GradTensor): GradTensor
Computes the linear transformation y = x * W^T + b.
Parameters:
  • input - Input tensor of shape (*, in_features)
Returns: Output tensor of shape (*, out_features)
Throws:
  • ShapeError - If input shape is invalid
  • DTypeError - If input dtype is unsupported

getWeight

getWeight(): Tensor
Gets the weight matrix.
Returns: Weight tensor of shape (out_features, in_features)

getBias

getBias(): Tensor | undefined
Gets the bias vector.
Returns: Bias tensor of shape (out_features,), or undefined if no bias

Examples

Basic Usage

import { Linear } from 'deepbox/nn';
import { tensor } from 'deepbox/ndarray';

// Create a linear layer
const layer = new Linear(20, 30);

// Forward pass with 2D input (batch)
const input = tensor([[1, 2, 3, /* ... */, 20]]); // shape: (1, 20)
const output = layer.forward(input);              // shape: (1, 30)

Without Bias

const layerNoBias = new Linear(10, 5, { bias: false });

With Sequence Data

// Process sequences (e.g., for transformers)
const layer = new Linear(512, 256);

// Input: (batch=2, seq_len=10, features=512)
const sequences = tensor(/* ... */);
const output = layer.forward(sequences); // (batch=2, seq_len=10, features=256)

Building a Multi-Layer Perceptron

import { Module, Linear, ReLU } from 'deepbox/nn';
import { tensor } from 'deepbox/ndarray';
import type { Tensor } from 'deepbox/ndarray';

class MLP extends Module {
  private fc1: Linear;
  private relu1: ReLU;
  private fc2: Linear;
  private relu2: ReLU;
  private fc3: Linear;

  constructor() {
    super();
    this.fc1 = new Linear(784, 256);
    this.relu1 = new ReLU();
    this.fc2 = new Linear(256, 128);
    this.relu2 = new ReLU();
    this.fc3 = new Linear(128, 10);

    this.registerModule('fc1', this.fc1);
    this.registerModule('relu1', this.relu1);
    this.registerModule('fc2', this.fc2);
    this.registerModule('relu2', this.relu2);
    this.registerModule('fc3', this.fc3);
  }

  forward(x: Tensor): Tensor {
    x = this.fc1.forward(x);
    x = this.relu1.forward(x);
    x = this.fc2.forward(x);
    x = this.relu2.forward(x);
    x = this.fc3.forward(x);
    return x;
  }
}

const model = new MLP();
const input = tensor(/* 28x28 flattened image */);
const predictions = model.forward(input);

Training Example

import { Linear } from 'deepbox/nn';
import { parameter, tensor } from 'deepbox/ndarray';
import { Adam } from 'deepbox/optim';
import { mseLoss } from 'deepbox/nn/losses';

const layer = new Linear(10, 5);
const optimizer = new Adam(layer.parameters());

for (let epoch = 0; epoch < 100; epoch++) {
  // Zero gradients
  layer.zeroGrad();

  // Forward pass
  const input = parameter(tensor(/* ... */));
  const output = layer.forward(input);
  const target = parameter(tensor(/* ... */));

  // Compute loss
  const loss = mseLoss(output, target);

  // Backward pass
  loss.backward();

  // Update weights
  optimizer.step();
}

Performance Considerations

  1. Batch Processing: Process multiple samples together for better performance:
    // Good: Batch processing
    const input = tensor([[...], [...], [...]]); // shape: (batch, features)
    const output = layer.forward(input);
    
    // Slower: Individual samples
    for (const sample of samples) {
      const output = layer.forward(sample);
    }
    
  2. Memory Layout: Inputs are reshaped internally for efficient matrix multiplication. Contiguous tensors perform better.
  3. Data Type: Use float32 (default) unless you need the precision of float64. Float32 is faster and uses less memory.
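The memory half of that trade-off is easy to see with plain typed arrays (no deepbox required):

```typescript
// Same element count, half the memory for float32.
const n = 1_000_000;
const f32 = new Float32Array(n); // 4 bytes per element
const f64 = new Float64Array(n); // 8 bytes per element
console.log(f32.byteLength); // 4000000
console.log(f64.byteLength); // 8000000
```

float32 also maps directly onto GPU-friendly formats, which matters when using the webgpu device.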

Common Patterns

Residual Connections

import { Module, Linear, ReLU } from 'deepbox/nn';
// Assumes an element-wise add is exported from 'deepbox/ndarray'
import { add } from 'deepbox/ndarray';
import type { Tensor } from 'deepbox/ndarray';

class ResidualBlock extends Module {
  private fc: Linear;
  private relu: ReLU;

  constructor(dim: number) {
    super();
    this.fc = new Linear(dim, dim);
    this.relu = new ReLU();
    this.registerModule('fc', this.fc);
    this.registerModule('relu', this.relu);
  }

  forward(x: Tensor): Tensor {
    const residual = x;
    let out = this.fc.forward(x);
    out = this.relu.forward(out);
    return add(out, residual); // Skip connection
  }
}

Projection Layers

// Dimension reduction/expansion
const projectionDown = new Linear(1024, 256); // Reduce dimensions
const projectionUp = new Linear(256, 1024);   // Expand dimensions
