Matrix Operations
linsdex provides specialized matrix types that track structural properties for automatic optimization. Instead of always using dense matrices, the library uses diagonal, block, and tagged matrices to avoid unnecessary computation.
When to Use
- Working with diagonal covariance matrices
- Building block-structured systems (e.g., position + velocity states)
- Optimizing linear algebra with zero or identity matrices
- Understanding the matrix type system in linsdex
Matrix Types
DiagonalMatrix
Stores only the n diagonal elements, so storage, matrix-vector products, inversion, and log-determinants cost O(n) instead of the O(n²) or O(n³) of a dense matrix.
```python
import jax.numpy as jnp
from linsdex import DiagonalMatrix

# Create from diagonal elements
diag_elements = jnp.array([1.0, 2.0, 3.0])
D = DiagonalMatrix(diag_elements)

# Create identity matrix
I = DiagonalMatrix.eye(3)

# Efficient operations
D_inv = D.get_inverse()      # O(n) instead of O(n³)
log_det = D.get_log_det()    # O(n) instead of O(n³)
elements = D.get_elements()  # Get diagonal as array
```
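The O(n) costs come from the fact that each operation reduces to elementwise arithmetic on the 1D array of diagonal entries. As an illustration of that math (a plain NumPy sketch, not linsdex's internal code; jax.numpy provides the same functions), the results can be cross-checked against their dense equivalents:

```python
import numpy as np

# Diagonal entries of a 3x3 diagonal matrix (same values as the example above)
d = np.array([1.0, 2.0, 3.0])

# Inverse: reciprocal of each diagonal entry, O(n)
d_inv = 1.0 / d

# Log-determinant: sum of the logs of the diagonal entries, O(n)
# (assumes all entries are positive, as for a covariance matrix)
log_det = np.sum(np.log(d))

# Matrix-vector product: elementwise multiply, O(n) instead of O(n^2)
v = np.array([1.0, 1.0, 1.0])
Dv = d * v

# Cross-check against the dense O(n^3) / O(n^2) computations
D_dense = np.diag(d)
assert np.allclose(np.diag(d_inv), np.linalg.inv(D_dense))
sign, dense_log_det = np.linalg.slogdet(D_dense)
assert sign > 0 and np.isclose(log_det, dense_log_det)
assert np.allclose(Dv, D_dense @ v)
```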
DenseMatrix
General dense matrices when structure cannot be exploited.
```python
import jax.numpy as jnp
from linsdex import DenseMatrix, TAGS

# Create a dense matrix (symmetric positive definite, so Cholesky is defined)
elements = jnp.array([[1.0, 0.5], [0.5, 2.0]])
M = DenseMatrix(elements, tags=TAGS.no_tags)

# Operations
M_inv = M.get_inverse()
chol = M.get_cholesky()
log_det = M.get_log_det()
```
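For a symmetric positive-definite matrix such as M, these three operations are linked: a single O(n³) Cholesky factorization M = L @ L.T yields both the log-determinant and the inverse. The plain NumPy sketch below shows that standard relationship; whether DenseMatrix reuses one factorization internally is an assumption this page does not confirm.

```python
import numpy as np

M = np.array([[1.0, 0.5], [0.5, 2.0]])

# Cholesky factor: M = L @ L.T with L lower triangular
# (requires M to be symmetric positive definite)
L = np.linalg.cholesky(M)

# Log-determinant from the factor: log|M| = 2 * sum(log(diag(L)))
log_det = 2.0 * np.sum(np.log(np.diag(L)))

# Inverse from the factor: M^{-1} = L^{-T} @ L^{-1}
L_inv = np.linalg.solve(L, np.eye(2))
M_inv = L_inv.T @ L_inv

assert np.allclose(L @ L.T, M)
assert np.allclose(M_inv, np.linalg.inv(M))
```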
Block Matrices
For higher-order systems with natural block structure (e.g., position + velocity in tracking).
```python
import jax.numpy as jnp
from linsdex.matrix.block import Block2x2Matrix
from linsdex import DiagonalMatrix, DenseMatrix, TAGS

# Create a 2x2 block matrix
# [[A, B],
#  [C, D]]
A = DiagonalMatrix.eye(2)
B = DenseMatrix(jnp.zeros((2, 2)), tags=TAGS.zero_tags)
C = DenseMatrix(jnp.zeros((2, 2)), tags=TAGS.zero_tags)
D = DiagonalMatrix.eye(2)

block_matrix = Block2x2Matrix(A, B, C, D)

# Operations work on the block structure
inv = block_matrix.get_inverse()
```
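Block inverses are usually computed with the Schur complement, which only inverts matrices the size of the blocks. The sketch below uses plain NumPy and a hypothetical helper, block_inverse, to show the textbook formula, assuming A and S = D - C A^{-1} B are invertible; it is not necessarily how Block2x2Matrix.get_inverse works. With B = C = 0, as in the example above, S = D and the inverse is block-diagonal.

```python
import numpy as np


def block_inverse(A, B, C, D):
    """Hypothetical helper: invert [[A, B], [C, D]] via the Schur complement of A.

    Assumes A and S = D - C A^{-1} B are invertible.
    """
    A_inv = np.linalg.inv(A)
    S = D - C @ A_inv @ B  # Schur complement of A
    S_inv = np.linalg.inv(S)
    top_left = A_inv + A_inv @ B @ S_inv @ C @ A_inv
    top_right = -A_inv @ B @ S_inv
    bottom_left = -S_inv @ C @ A_inv
    return top_left, top_right, bottom_left, S_inv


A = np.array([[4.0, 1.0], [1.0, 3.0]])
B = np.array([[0.5, 0.0], [0.2, 0.5]])
C = B.T
D = np.array([[2.0, 0.3], [0.3, 2.0]])

# Reassemble the blocks and compare with a dense inverse
tl, tr, bl, br = block_inverse(A, B, C, D)
full_inv = np.block([[tl, tr], [bl, br]])
full = np.block([[A, B], [C, D]])
assert np.allclose(full_inv, np.linalg.inv(full))

# Zero off-diagonal blocks (as in the Block2x2Matrix example): S = D
Z = np.zeros((2, 2))
z_tl, z_tr, z_bl, z_br = block_inverse(A, Z, Z, D)
```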
Symbolic Tags
Tags track properties like zero and infinite values, enabling symbolic simplification before numerical computation.
```python
from linsdex import TAGS

# Available tag configurations
TAGS.no_tags    # Regular matrix (non-zero, non-infinite)
TAGS.zero_tags  # Matrix is zero
TAGS.inf_tags   # Matrix has infinite elements (represents total uncertainty)
```
Tags propagate through operations automatically:
```python
import jax.numpy as jnp
from linsdex import DenseMatrix, TAGS

# Create a zero matrix and a regular (identity) matrix
zero = DenseMatrix(jnp.zeros((3, 3)), tags=TAGS.zero_tags)
nonzero = DenseMatrix(jnp.eye(3), tags=TAGS.no_tags)

# Operations are resolved symbolically from the tags
product = zero @ nonzero  # Known to be zero without computing the product
total = nonzero + zero    # Equals nonzero; the addition is skipped
```
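To make symbolic simplification concrete, here is a minimal sketch of how a zero tag can short-circuit arithmetic. TaggedMatrix and its is_zero flag are hypothetical names for illustration, not linsdex classes, and the Python-level branching ignores the complications of tracing under jax.jit.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TaggedMatrix:
    """Hypothetical stand-in for a tagged matrix (not linsdex's class)."""

    elements: np.ndarray
    is_zero: bool = False

    def __matmul__(self, other: "TaggedMatrix") -> "TaggedMatrix":
        if self.is_zero or other.is_zero:
            # Zero times anything is zero: skip the O(n^3) product entirely
            shape = (self.elements.shape[0], other.elements.shape[1])
            return TaggedMatrix(np.zeros(shape), is_zero=True)
        return TaggedMatrix(self.elements @ other.elements)

    def __add__(self, other: "TaggedMatrix") -> "TaggedMatrix":
        # Adding a zero matrix returns the other operand unchanged
        if other.is_zero:
            return self
        if self.is_zero:
            return other
        return TaggedMatrix(self.elements + other.elements)


zero = TaggedMatrix(np.zeros((3, 3)), is_zero=True)
nonzero = TaggedMatrix(np.eye(3))

product = zero @ nonzero  # short-circuits to a zero-tagged result
total = nonzero + zero    # returns `nonzero` itself; no addition performed
```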
Infinite matrices represent total uncertainty (precision = 0):
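```python
import jax.numpy as jnp
from linsdex import DenseMatrix, TAGS

# Used in potentials to represent uninformative priors.
# inf_tags marks the matrix as infinite: an infinite covariance, i.e. zero precision
inf_covariance = DenseMatrix(jnp.zeros((3, 3)), tags=TAGS.inf_tags)

# This indicates "no information" about a variable
```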
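Zero precision means "no information" because, in natural (information) form, multiplying two Gaussian potentials adds their precisions J and their precision-weighted means h = J @ mu. A potential with J = 0 and h = 0 therefore leaves the other factor unchanged. The NumPy sketch below checks this standard identity; to_natural, J, and h are illustrative names, not linsdex API.

```python
import numpy as np


def to_natural(mu, Sigma):
    """Hypothetical helper: natural parameters (J, h) of N(mu, Sigma)."""
    J = np.linalg.inv(Sigma)
    return J, J @ mu


# An informative potential with mean mu and covariance Sigma
mu = np.array([1.0, -2.0, 0.5])
Sigma = np.diag([0.5, 1.0, 2.0])
J1, h1 = to_natural(mu, Sigma)

# "No information": infinite covariance corresponds to zero precision and h = 0
J0 = np.zeros((3, 3))
h0 = np.zeros(3)

# Multiplying potentials adds their natural parameters
J, h = J1 + J0, h1 + h0

# Recover the mean: the uninformative potential changed nothing
mu_post = np.linalg.solve(J, h)
```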
Code Examples
Efficient Diagonal Operations
```python
import jax.numpy as jnp
from linsdex import DiagonalMatrix, StandardGaussian

dim = 100

# Independent dimensions with diagonal covariance
variances = jnp.ones(dim)
Sigma = DiagonalMatrix(variances)

# All operations are O(n) instead of O(n³)
precision = Sigma.get_inverse()
log_det = Sigma.get_log_det()
chol = Sigma.get_cholesky()

# Use in Gaussian distributions
mu = jnp.zeros(dim)
dist = StandardGaussian(mu, Sigma)
```
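The savings matter most when evaluating a Gaussian log-density, which needs both a log-determinant and a solve against the covariance. The sketch below writes the diagonal case out by hand in plain NumPy and checks it against a dense computation; the two logpdf helpers are hypothetical and do not use StandardGaussian.

```python
import numpy as np


def diag_gaussian_logpdf(x, mu, variances):
    """Hypothetical helper: log N(x; mu, diag(variances)) in O(n)."""
    n = x.shape[0]
    r = x - mu
    log_det = np.sum(np.log(variances))      # O(n) log-determinant
    mahalanobis = np.sum(r * r / variances)  # O(n) instead of a dense solve
    return -0.5 * (n * np.log(2.0 * np.pi) + log_det + mahalanobis)


def dense_gaussian_logpdf(x, mu, Sigma):
    """Hypothetical helper: the same density with a dense covariance, O(n^3)."""
    n = x.shape[0]
    r = x - mu
    _, log_det = np.linalg.slogdet(Sigma)
    mahalanobis = r @ np.linalg.solve(Sigma, r)
    return -0.5 * (n * np.log(2.0 * np.pi) + log_det + mahalanobis)


dim = 100
rng = np.random.default_rng(0)
mu = np.zeros(dim)
variances = rng.uniform(0.5, 2.0, size=dim)
x = rng.normal(size=dim)

fast = diag_gaussian_logpdf(x, mu, variances)
slow = dense_gaussian_logpdf(x, mu, np.diag(variances))
```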
Block Matrix for State Space Models
```python
import jax.numpy as jnp
from linsdex.matrix.block import Block2x2Matrix
from linsdex import DiagonalMatrix, DenseMatrix, TAGS

# 2D state: [position, velocity]
# Continuous-time dynamics: d/dt [x, v] = [[0, 1], [0, 0]] [x, v]
# Discrete transition matrix (Euler step I + F*dt; exact here because F @ F = 0):

dt = 0.1
dim = 1

# Position row: x_new = 1 * x + dt * v
A11 = DiagonalMatrix.eye(dim)
A12 = DiagonalMatrix(jnp.ones(dim) * dt)

# Velocity row: v_new = 0 * x + 1 * v
A21 = DenseMatrix(jnp.zeros((dim, dim)), tags=TAGS.zero_tags)
A22 = DiagonalMatrix.eye(dim)

transition_matrix = Block2x2Matrix(A11, A12, A21, A22)
```
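For dim = 1 the four blocks above assemble into the dense matrix [[1, dt], [0, 1]]. Because F @ F = 0, the matrix-exponential series for exp(F dt) stops after the linear term, so the Euler step is exact for this model. The dense NumPy sketch below verifies that and applies one transition; it checks the arithmetic only and does not construct a Block2x2Matrix.

```python
import numpy as np

dt = 0.1

# Continuous-time drift for the state [position, velocity]
F = np.array([[0.0, 1.0],
              [0.0, 0.0]])

# F is nilpotent (F @ F == 0), so exp(F dt) = I + F dt + (F dt)^2 / 2 + ...
# reduces exactly to the first two terms
assert np.allclose(F @ F, 0.0)
A = np.eye(2) + F * dt

# One step from position 0 with velocity 2: position advances by dt * v
state = np.array([0.0, 2.0])
next_state = A @ state
```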
Working with Tags
```python
import jax.numpy as jnp
from linsdex import DenseMatrix, DiagonalMatrix, TAGS

# Regular (non-zero) matrix
M = DenseMatrix(jnp.eye(3), tags=TAGS.no_tags)

# Zero matrix (will be detected in operations)
Z = DenseMatrix(jnp.zeros((3, 3)), tags=TAGS.zero_tags)

# Diagonal matrix (no tags argument passed; tags are handled automatically)
D = DiagonalMatrix(jnp.array([1.0, 2.0, 3.0]))
```
Matrix Operations
```python
import jax.numpy as jnp
from linsdex import DiagonalMatrix, DenseMatrix, TAGS

D = DiagonalMatrix(jnp.array([2.0, 3.0]))
M = DenseMatrix(jnp.array([[1.0, 0.5], [0.5, 1.0]]), tags=TAGS.no_tags)

# Matrix-vector multiplication
v = jnp.array([1.0, 2.0])
Dv = D @ v  # Efficient diagonal multiplication

# Matrix-matrix operations
DM = D @ M  # Diagonal times dense

# Inverse
D_inv = D.get_inverse()

# Cholesky decomposition (M is symmetric positive definite)
chol = M.get_cholesky()

# Log determinant
log_det = D.get_log_det()
```
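Mixed products such as D @ M benefit from the diagonal structure because multiplying by a diagonal matrix only rescales rows (on the left) or columns (on the right). The NumPy broadcasting sketch below shows the equivalence; which code path linsdex actually takes for D @ M is not specified here.

```python
import numpy as np

d = np.array([2.0, 3.0])  # diagonal of D
M = np.array([[1.0, 0.5], [0.5, 1.0]])

# D @ M scales row i of M by d[i]: O(n^2) instead of an O(n^3) dense matmul
DM = d[:, None] * M

# M @ D scales column j of M by d[j]
MD = M * d[None, :]

assert np.allclose(DM, np.diag(d) @ M)
assert np.allclose(MD, M @ np.diag(d))
```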
Using with Gaussian Distributions
```python
import jax.numpy as jnp
from linsdex import StandardGaussian, NaturalGaussian, DiagonalMatrix

dim = 5

# Independent Gaussian with diagonal covariance
mu = jnp.zeros(dim)
Sigma = DiagonalMatrix.eye(dim) * 0.5  # Scalar multiplication

std_dist = StandardGaussian(mu, Sigma)

# Convert to natural (precision) form; the precision is also a DiagonalMatrix
nat_dist = std_dist.to_nat()
```
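The conversion to natural form stays cheap because inverting a diagonal covariance is elementwise. Assuming the usual information-form parameters (precision J = Sigma^{-1} and h = J @ mu; these names may not match linsdex's attributes), the arithmetic for a diagonal covariance of 0.5 * I looks like this in plain NumPy. A nonzero mean is used so that h is not trivially zero.

```python
import numpy as np

dim = 5
mu = np.arange(dim, dtype=float)  # illustrative nonzero mean
variances = np.full(dim, 0.5)     # diagonal of Sigma = 0.5 * I

# Natural parameters: precision J = Sigma^{-1} and h = J @ mu.
# With a diagonal Sigma both are elementwise, O(n):
J_diag = 1.0 / variances
h = J_diag * mu

# Dense cross-check
J_dense = np.linalg.inv(np.diag(variances))
assert np.allclose(np.diag(J_diag), J_dense)
assert np.allclose(h, J_dense @ mu)
```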
Key Classes
- `DiagonalMatrix(elements)` - Diagonal matrix from a 1D array
- `DenseMatrix(elements, tags)` - Dense matrix with symbolic tags
- `Block2x2Matrix(A, B, C, D)` - 2x2 block matrix
- `Block3x3Matrix(...)` - 3x3 block matrix
- `TAGS` - Symbolic tags for optimization
Common Methods
All matrix types support:
- `get_inverse()` - Matrix inverse
- `get_cholesky()` - Cholesky decomposition
- `get_log_det()` - Log determinant
- `get_elements()` - Raw array elements
- `@` operator - Matrix multiplication (matmul)
- Scalar multiplication and addition
Tips
- Use `DiagonalMatrix` whenever dimensions are independent to save computation
- Set the correct tags when creating a `DenseMatrix` to enable symbolic optimization
- Block matrices are useful for higher-order state space models
- Tags propagate automatically through operations
- The library chooses the most efficient representation for operation results
- Use `DiagonalMatrix.eye(n)` for identity matrices