Automatic Differentiation

MIND includes a built-in autodiff engine that generates optimized gradient code at the IR level using reverse-mode automatic differentiation.

Basic Usage

Mark functions as differentiable to enable gradient computation:

@differentiable
fn mse_loss(pred: Tensor<f32, N>, target: Tensor<f32, N>) -> f32 {
    mean((pred - target) ** 2)
}

fn main() {
    let pred = [1.0, 2.0, 3.0];
    let target = [1.5, 2.5, 3.5];

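    // Compute loss: mean of (-0.5)² over 3 elements = 0.25
    let loss = mse_loss(pred, target);
    print(loss);

    // Get gradient function
    let grad_fn = grad(mse_loss);
    let d_pred = grad_fn(pred, target);

    print(d_pred);  // Gradient w.r.t. pred: 2 * (pred - target) / 3, about -0.333 per element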
}

How It Works

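MIND uses source-transformation reverse-mode AD. Rather than recording operations on a tape at run time, the compiler generates the gradient code ahead of time:

  • Forward pass: compute the output and save the intermediate values the backward pass needs
  • Backward pass: propagate gradients from the output back to the inputs, applying each operation's gradient rule in reverse order
  • Optimization: apply standard compiler optimizations to the generated gradient code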
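To make the forward/backward split concrete, the Python sketch below hand-writes the kind of function pair a source-transforming AD pass produces for mse_loss from Basic Usage. It is not MIND's generated code: the names mse_loss_forward and mse_loss_backward are hypothetical, and Python lists stand in for 1-D tensors. The forward pass saves only diff, which is all the backward pass needs.

```python
# Hypothetical hand-written forward/backward pair for mse_loss; not MIND output.
# Python lists stand in for 1-D tensors.

def mse_loss_forward(pred, target):
    """Forward pass: compute the loss and save what the backward pass needs."""
    diff = [p - t for p, t in zip(pred, target)]
    loss = sum(d * d for d in diff) / len(diff)
    return loss, diff  # diff is the only saved intermediate


def mse_loss_backward(diff, upstream=1.0):
    """Backward pass: apply each operation's gradient rule in reverse order."""
    n = len(diff)
    d_sq = [upstream / n] * n                         # mean: upstream / n per element
    d_diff = [g * 2 * d for g, d in zip(d_sq, diff)]  # square: 2 * diff
    d_pred = d_diff                                   # pred - target: +1 w.r.t. pred
    d_target = [-g for g in d_diff]                   # pred - target: -1 w.r.t. target
    return d_pred, d_target


loss, saved = mse_loss_forward([1.0, 2.0, 3.0], [1.5, 2.5, 3.5])
d_pred, d_target = mse_loss_backward(saved)
print(loss)    # 0.25
print(d_pred)  # each element is 2 * (-0.5) / 3, i.e. about -0.333
```

Because the rules run in reverse order (mean, then square, then subtraction), a single backward pass yields the gradient for every input at once, which is the point of reverse mode for scalar losses.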

Supported Operations

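All Core v1 operations have defined gradients. The table below lists several of them; upstream is the gradient flowing back into the operation's output, and ∂a denotes the gradient with respect to input a:

Operation      Gradient
-------------  ----------------------------------------
add(a, b)      ∂a = upstream,  ∂b = upstream
mul(a, b)      ∂a = upstream * b,  ∂b = upstream * a
matmul(a, b)   ∂a = upstream @ bᵀ,  ∂b = aᵀ @ upstream
relu(x)        ∂x = upstream * (x > 0)
sum(x)         ∂x = broadcast(upstream, shape(x))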
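One way to sanity-check a row of this table is to compare it with finite differences. This pure-Python sketch does that for the matmul rule; matmul, transpose, loss and finite_diff are illustrative helpers, not MIND APIs. The scalar loss sum(upstream * (a @ b)) is chosen so that its gradient with respect to a @ b is exactly upstream.

```python
# Illustrative helpers (matmul, transpose, loss, finite_diff); not MIND APIs.

def matmul(a, b):
    """Plain matrix product of nested lists: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def loss(a, b, upstream):
    """sum(upstream * (a @ b)): its gradient w.r.t. a @ b is exactly upstream."""
    c = matmul(a, b)
    return sum(c[i][j] * upstream[i][j]
               for i in range(len(c)) for j in range(len(c[0])))


def finite_diff(m, i, j, f, h=1e-6):
    """Central difference of f() w.r.t. m[i][j]; m is restored afterwards."""
    orig = m[i][j]
    m[i][j] = orig + h
    hi = f()
    m[i][j] = orig - h
    lo = f()
    m[i][j] = orig
    return (hi - lo) / (2 * h)


a = [[1.0, 2.0, 0.5], [-1.0, 0.0, 3.0]]   # 2 x 3
b = [[0.5, -2.0], [1.0, 1.0], [2.0, 0.0]]  # 3 x 2
upstream = [[1.0, -1.0], [0.5, 2.0]]       # 2 x 2, same shape as a @ b

da = matmul(upstream, transpose(b))  # table rule: ∂a = upstream @ bᵀ
db = matmul(transpose(a), upstream)  # table rule: ∂b = aᵀ @ upstream


def f():
    return loss(a, b, upstream)


errors = [abs(da[i][j] - finite_diff(a, i, j, f)) for i in range(2) for j in range(3)]
errors += [abs(db[i][j] - finite_diff(b, i, j, f)) for i in range(3) for j in range(2)]
max_err = max(errors)
print(max_err < 1e-6)  # True
```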

Higher-Order Gradients

Apply grad to a gradient function to obtain higher-order derivatives:

@differentiable
fn f(x: f32) -> f32 {
    x ** 3
}

// First derivative: 3x²
let df = grad(f);

// Second derivative: 6x
let d2f = grad(df);

// Third derivative: 6
let d3f = grad(d2f);
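Why repeated grad works can be illustrated with nested dual numbers. Note the swap: this Python sketch uses forward-mode AD, not MIND's reverse-mode transform, and the names Dual and derivative are hypothetical. Each call to derivative adds one more level of nesting, so differentiating a derivative needs no extra machinery; at x = 2 the three derivatives of x ** 3 are 12, 12 and 6.

```python
# Forward-mode dual numbers, swapped in for MIND's reverse-mode transform.
# Dual and derivative are illustrative names, not a MIND API.

class Dual:
    """primal + tangent·ε with ε² = 0; primal/tangent may themselves be Duals."""

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    @staticmethod
    def lift(v):
        # A plain constant has zero tangent at every nesting level.
        return v if isinstance(v, Dual) else Dual(v, 0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (a + a'ε)(b + b'ε) = ab + (ab' + a'b)ε
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__

    def __pow__(self, n):
        # Non-negative integer exponents only, by repeated multiplication.
        result = Dual(1, 0)
        for _ in range(n):
            result = result * self
        return result


def derivative(f):
    """Return a function computing f'(x) by seeding the tangent with 1."""
    def df(x):
        out = f(Dual(x, 1))
        return out.tangent if isinstance(out, Dual) else 0  # constant -> 0
    return df


def f(x):
    return x ** 3


df = derivative(f)    # 3x²
d2f = derivative(df)  # 6x
d3f = derivative(d2f)  # 6
print(df(2), d2f(2), d3f(2))  # 12 12 6
```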

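Custom Gradients

To replace the automatically derived gradient with a hand-written one, name a gradient function in @custom_grad. Here my_relu_grad receives the forward input x and the upstream gradient, and returns the gradient with respect to x: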

@differentiable
@custom_grad(my_relu_grad)
fn my_relu(x: Tensor<f32, N>) -> Tensor<f32, N> {
    max(x, 0.0)
}

fn my_relu_grad(x: Tensor<f32, N>, upstream: Tensor<f32, N>) -> Tensor<f32, N> {
    upstream * cast<f32>(x > 0.0)
}
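Conceptually, @custom_grad tells the backward pass to call the supplied rule instead of deriving one. The Python sketch below mimics that with a plain dictionary; the registry CUSTOM_GRADS and the helper grad_sum_of are hypothetical stand-ins, not how MIND implements the attribute. It computes the gradient of sum(my_relu(x)), using the sum rule from the gradient table to form upstream.

```python
# Hypothetical Python stand-ins for my_relu / my_relu_grad; lists play 1-D tensors.

def my_relu(x):
    return [max(v, 0.0) for v in x]


def my_relu_grad(x, upstream):
    # Same rule as the MIND version: upstream * cast(x > 0)
    return [u * (1.0 if v > 0.0 else 0.0) for v, u in zip(x, upstream)]


# Hypothetical registry playing the role of @custom_grad(my_relu_grad).
CUSTOM_GRADS = {my_relu: my_relu_grad}


def grad_sum_of(fn, x):
    """Gradient of sum(fn(x)) w.r.t. x, via fn's registered custom rule."""
    upstream = [1.0] * len(x)  # sum rule: broadcast upstream (1.0) to x's shape
    rule = CUSTOM_GRADS[fn]    # use the supplied rule instead of deriving one
    return rule(x, upstream)


print(grad_sum_of(my_relu, [-1.0, 0.0, 2.0]))  # [0.0, 0.0, 1.0]
```

At x = 0, where max(x, 0) has no unique derivative, the custom rule returns 0 because (x > 0) is false there.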

Gradient Checkpointing

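For memory-constrained training, mark a function with @checkpoint. Its intermediate activations are not stored during the forward pass; they are recomputed during the backward pass, trading extra computation for lower memory use: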

@differentiable
@checkpoint  // Recompute forward during backward
fn transformer_block(x: Tensor<f32, B, S, D>) -> Tensor<f32, B, S, D> {
    // Large intermediate activations are not stored
    let attn = self_attention(x);
    let ffn = feed_forward(attn);
    ffn
}
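The trade-off behind @checkpoint can be sketched in Python with a toy block: a chain of scalar layers x → sin(w·x). The weights and function names are hypothetical. Without checkpointing, every activation is kept until the backward pass; with it, only the block input is kept and the activations are recomputed, at the cost of one extra forward pass through the block. Both paths produce the same gradient.

```python
import math

# Hypothetical toy block: a chain of scalar layers x -> sin(w * x).
WEIGHTS = [0.5, 1.5, -0.7, 2.0]


def block_forward(x):
    """Run the block; return its output and all activations (layer inputs + output)."""
    acts = [x]
    for w in WEIGHTS:
        x = math.sin(w * x)
        acts.append(x)
    return x, acts


def block_backward(acts, upstream):
    """Chain rule in reverse: d/da sin(w * a) = w * cos(w * a)."""
    g = upstream
    # acts[i] is the input to layer i, so pair weights and layer inputs in reverse.
    for w, a in zip(reversed(WEIGHTS), reversed(acts[:-1])):
        g = g * w * math.cos(w * a)
    return g


def grad_without_checkpoint(x):
    # Standard reverse mode: all activations from the forward pass are kept for backward.
    _, acts = block_forward(x)
    return block_backward(acts, 1.0)


def grad_with_checkpoint(x):
    saved_input = x                       # the only value kept from the forward pass
    _out, _ = block_forward(x)            # forward: activations are discarded immediately
    # ... later blocks would consume _out here ...
    _, acts = block_forward(saved_input)  # backward: recompute the activations
    return block_backward(acts, 1.0)


print(grad_with_checkpoint(0.3) == grad_without_checkpoint(0.3))  # True
```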

Learn More

See the full autodiff specification at mind-spec/autodiff.md.