Imagine you have two piles of sand with different shapes, and you want to reshape one pile to match the other. Optimal transport (OT) finds the cheapest way to move the sand — that is, the plan that minimizes the total cost of moving mass from one distribution to another.
In the simplest case, the two distributions are one-dimensional histograms. The “transport plan” is a matrix that tells you how much mass to move from each bin of the source to each bin of the target.
# Two simple 1D distributions
source_dist <- c(0.4, 0.1, 0.4, 0.1)
target_dist <- c(0.1, 0.3, 0.1, 0.3, 0.2)
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))
barplot(source_dist, col = "steelblue", main = "Source distribution",
        names.arg = seq_along(source_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
barplot(target_dist, col = "tomato", main = "Target distribution",
        names.arg = seq_along(target_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
# Restore the graphical parameters saved above
par(oldpar)

OT finds a transport plan (a matrix) that optimally maps the source to the target. Each cell of the matrix represents how much mass moves from a source bin to a target bin.
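As a rough illustration of what such a plan looks like, the sketch below builds one feasible plan for the two histograms above with the north-west corner rule, using base R only. The helper name `northwest_plan` is hypothetical, and this is not how otTensor computes its plans. It only shows the defining property of a plan: rows sum to the source masses and columns sum to the target masses. When the bins are ordered along a line and the cost is convex in the distance between bin positions, this monotone plan is also a cheapest one.

```r
# Same histograms as in the example above
source_dist <- c(0.4, 0.1, 0.4, 0.1)
target_dist <- c(0.1, 0.3, 0.1, 0.3, 0.2)

# Hypothetical helper: fill the plan greedily from the top-left corner,
# moving as much mass as possible before advancing to the next bin.
northwest_plan <- function(p, q, tol = 1e-12) {
  plan <- matrix(0, nrow = length(p), ncol = length(q))
  i <- 1
  j <- 1
  while (i <= length(p) && j <= length(q)) {
    moved <- min(p[i], q[j])
    plan[i, j] <- moved
    p[i] <- p[i] - moved
    q[j] <- q[j] - moved
    # Advance past source bins that are (numerically) empty
    # and target bins that are (numerically) full
    if (p[i] <= tol) i <- i + 1
    if (q[j] <= tol) j <- j + 1
  }
  plan
}

plan <- northwest_plan(source_dist, target_dist)
round(plan, 2)

# Row sums recover the source, column sums recover the target
all.equal(rowSums(plan), source_dist)
all.equal(colSums(plan), target_dist)
```

The tolerance `tol` is there because subtractions such as `0.4 - 0.1` leave tiny floating-point residues that would otherwise stall the loop on an almost-empty bin.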
A tensor is simply a generalization of familiar data structures:
oldpar <- par(mfrow = c(1, 3), mar = c(2, 2, 3, 1))
# Vector (order 1)
barplot(c(3, 1, 4, 1, 5), col = "steelblue", main = "Order 1: Vector")
# Matrix (order 2)
mat <- matrix(c(1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6), nrow = 3)
image(mat, col = gray((0:255) / 255), axes = FALSE, main = "Order 2: Matrix")
# 3D tensor (show one slice)
arr <- array(0, dim = c(3, 4, 2))
arr[,,1] <- matrix(c(1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6), nrow = 3)
arr[,,2] <- matrix(c(6, 5, 4, 3, 5, 4, 3, 2, 4, 3, 2, 1), nrow = 3)
image(arr[,,1], col = gray((0:255) / 255), axes = FALSE,
      main = "Order 3: Tensor\n(slice 1)")
# Restore the graphical parameters saved above
par(oldpar)

Standard OT works well for vectors and matrices, but what if your data is a higher-order tensor?
Optimal Tensor Transport (OTT) (Kerdoncuff 2022) extends OT to tensors of any order. Given two tensors \(X\) and \(Y\) of the same order, OTT finds transport plans — one or more matrices that describe how to map each dimension of \(X\) to the corresponding dimension of \(Y\).
f Parameter

The f parameter is the core idea that makes OTT flexible. It is a vector that assigns each dimension to a transport plan group, which controls how dimensions share transport plans.
| Setting | Meaning | Analogy |
|---|---|---|
| f = c(1, 2) | Each dimension gets its own transport plan | Co-Optimal Transport |
| f = c(1, 1) | Both dimensions share the same plan | Gromov-Wasserstein-like |
| f = c(1, 1, 2) | Dims 1 & 2 share a plan; dim 3 has its own | GW collections |
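To make the grouping concrete, the following base R sketch lists which transport plans a given f implies and what shape each plan would have. The helper `plan_shapes` is hypothetical and not part of otTensor. It also assumes that dimensions sharing a plan must have the same size, since one matrix has to map all of them.

```r
# Hypothetical helper (not part of otTensor): given the dimensions of X and Y
# and the grouping vector f, list the shape of each transport plan.
plan_shapes <- function(dim_x, dim_y, f) {
  stopifnot(length(dim_x) == length(f), length(dim_y) == length(f))
  groups <- sort(unique(f))
  shapes <- lapply(groups, function(g) {
    nx <- unique(dim_x[f == g])
    ny <- unique(dim_y[f == g])
    # Assumption: dimensions that share a plan must have the same size
    if (length(nx) != 1 || length(ny) != 1) {
      stop("dimensions in group ", g, " differ in size")
    }
    c(rows = nx, cols = ny)
  })
  names(shapes) <- paste0("plan_", groups)
  shapes
}

# Separate plans for the rows and columns of a 4 x 5 source and 6 x 7 target
plan_shapes(c(4, 5), c(6, 7), f = c(1, 2))

# One shared plan needs square inputs, e.g. 4 x 4 and 6 x 6
plan_shapes(c(4, 4), c(6, 6), f = c(1, 1))

# Dims 1 and 2 share a plan, dim 3 has its own
plan_shapes(c(4, 4, 3), c(6, 6, 2), f = c(1, 1, 2))
```

Under this reading, the number of plans equals the number of unique values in f, and each plan has one row per source index and one column per target index of its group.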
For example, with a 3D tensor (e.g., subjects x genes x time):
- f = c(1, 2, 3) learns separate transport plans for subjects, genes, and time
- f = c(1, 1, 2) forces subjects and genes to share a plan, while time has its own

Here we walk through a minimal example step by step.
library("otTensor")
library("rTensor")

We create two small matrices (order-2 tensors) as source and target.
# Source: a 4 x 5 matrix
arrX <- matrix(0, nrow = 4, ncol = 5)
for (i in 1:4) {
  for (j in 1:5) {
    arrX[i, j] <- i + j
  }
}
# Target: a 6 x 7 matrix (different size is OK)
arrY <- matrix(0, nrow = 6, ncol = 7)
for (i in 1:6) {
  for (j in 1:7) {
    arrY[i, j] <- i + j
  }
}
# Convert to Tensor objects
X <- as.tensor(arrX)
Y <- as.tensor(arrY)

f parameter

Since X and Y are order-2 tensors (two dimensions each), we set f = c(1, 2) so that each dimension gets its own transport plan.
f <- c(1, 2)
result <- OTT(X = X, Y = Y, f = f,
              num.sample = 500, num.iter = 100)

The result contains a list of transport plan matrices, Ts. Since f = c(1, 2), there are two plans:
- Ts[[1]]: maps rows of X (size 4) to rows of Y (size 6)
- Ts[[2]]: maps columns of X (size 5) to columns of Y (size 7)

# Transport plan dimensions
cat("Transport plan 1:", dim(result$Ts[[1]]), "\n")
## Transport plan 1: 4 6
cat("Transport plan 2:", dim(result$Ts[[2]]), "\n")
## Transport plan 2: 5 7
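# Draw a matrix as a grayscale image, oriented as it is printed
# (row 1 at the top, column 1 on the left)
.show_matrix <- function(mat, main = "") {
  mat_rev <- t(apply(mat, 2, rev))
  image(mat_rev, col = gray((0:255) / 255),
        xaxt = "n", yaxt = "n",
        xlab = "", ylab = "", axes = FALSE, main = main)
}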
oldpar <- par(mfrow = c(2, 2), mar = c(2, 2, 3, 1))
.show_matrix(arrX, main = "Source (X)")
.show_matrix(arrY, main = "Target (Y)")
.show_matrix(result$Ts[[1]], main = "Transport Plan 1\n(rows)")
.show_matrix(result$Ts[[2]], main = "Transport Plan 2\n(columns)")
# Restore the graphical parameters saved above
par(oldpar)

Each transport plan is a matrix where brighter cells indicate more mass being transported between the corresponding indices.
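One way to read a plan numerically rather than visually is to normalize each row, so that row i shows where the mass of source index i goes, and then pick the column that receives the largest share. The sketch below uses a small hand-made plan as a stand-in; with otTensor one would presumably pass `result$Ts[[1]]` instead. The helper name `summarize_plan` is hypothetical.

```r
# Hypothetical helper: summarize where each source index sends its mass
summarize_plan <- function(plan) {
  # Share of each source index's mass sent to each target index
  # (rows with zero total mass would produce NaN)
  shares <- plan / rowSums(plan)
  list(
    shares = shares,
    # Target index receiving the largest share from each source index
    best_target = apply(plan, 1, which.max)
  )
}

# Stand-in plan: 3 source indices, 4 target indices, total mass 1
toy_plan <- matrix(c(0.25, 0.05, 0.00, 0.00,
                     0.00, 0.20, 0.10, 0.05,
                     0.00, 0.00, 0.15, 0.20),
                   nrow = 3, byrow = TRUE)

out <- summarize_plan(toy_plan)
round(out$shares, 2)
out$best_target
```

A row whose shares are spread evenly means the source index has no clear counterpart, while a single dominant share suggests a near one-to-one match.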
| Parameter | Description | Default |
|---|---|---|
| X | Source tensor (rTensor::Tensor object) | (required) |
| Y | Target tensor (same order as X, sizes may differ) | (required) |
| f | Integer vector assigning each dimension to a transport plan group | (required) |
| ps | List of source marginal distributions (one per unique value in f) | Uniform |
| qs | List of target marginal distributions | Uniform |
| loss | Loss function for computing costs | Absolute error |
| num.sample | Number of Monte Carlo samples for gradient estimation | 1000 |
| num.iter | Number of optimization iterations | 200 |
| epsilon | Convergence threshold | 1e-10 |
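As a sketch of how the optional arguments in the table might be prepared for a 4 x 5 source and a 6 x 7 target with f = c(1, 2), the code below builds one marginal vector per plan and a squared-error loss in base R. The list layout of ps and qs follows the table's description and is otherwise an assumption, so it is worth checking `?OTT` before relying on it. The name `sq_loss` is only illustrative.

```r
f <- c(1, 2)
dim_x <- c(4, 5)  # source sizes
dim_y <- c(6, 7)  # target sizes

# One marginal per unique value in f. With f = c(1, 2) every dimension has
# its own plan, so there is one marginal per dimension. Uniform weights are
# what the table lists as the default.
ps <- lapply(dim_x, function(n) rep(1 / n, n))
qs <- lapply(dim_y, function(n) rep(1 / n, n))

# A non-uniform alternative for the rows of X: more weight on row 1
ps[[1]] <- c(0.4, 0.2, 0.2, 0.2)

# Squared-error loss, vectorized over x and y
sq_loss <- function(x, y) (x - y)^2

# Sanity checks: one marginal per plan, each summing to 1
stopifnot(length(ps) == length(unique(f)), length(qs) == length(unique(f)))
stopifnot(all(abs(sapply(ps, sum) - 1) < 1e-12),
          all(abs(sapply(qs, sum) - 1) < 1e-12))

sq_loss(3, 5)
```

Under these assumptions the objects would be passed as `OTT(X = X, Y = Y, f = f, ps = ps, qs = qs, loss = sq_loss)`.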
Tips:
- A larger num.sample gives more accurate gradients but is slower
- Start with a small num.iter (e.g., 50) for exploration, then increase it for final results
- Use loss = function(x, y) (x - y)^2 for squared error

The next vignette (otTensor-2: Optimal Tensor Transport) reproduces the experiments from the original paper (Kerdoncuff 2022), demonstrating OTT under all six f configurations.
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-conda-linux-gnu (64-bit)
## Running under: Rocky Linux 9.5 (Blue Onyx)
##
## Matrix products: default
## BLAS: /home/koki/miniconda3/lib/libblas.so.3.9.0
## LAPACK: /home/koki/miniconda3/lib/liblapack.so.3.9.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] rTensor_1.4.8 otTensor_0.99.0
##
## loaded via a namespace (and not attached):
## [1] digest_0.6.31 R6_2.5.1 jsonlite_1.8.4 evaluate_0.20
## [5] highr_0.10 rlang_0.4.11 jquerylib_0.1.4 bslib_0.3.1
## [9] rmarkdown_2.11 tools_3.6.3 xfun_0.38 yaml_2.3.7
## [13] fastmap_1.1.1 compiler_3.6.3 htmltools_0.5.5 knitr_1.42
## [17] sass_0.4.0