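The softmax function is a smooth approximation of the max function, and is used in many machine learning models. Similarly, we can define the soft-argmax function, which is a smooth approximation of the argmax function. (In the machine learning literature, the name softmax is commonly given to what this note calls soft-argmax.)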
Definition and notation
First, let us define the probability simplex in \mathbb{R}^K, \Delta_{K} = \{p \in \mathbb{R}^K : p_k \geq 0 \text{ for all } k, \ \sum_{k=1}^{K} p_k = 1\}, the uniform distribution u_K = (1/K, \dots, 1/K)^\top \in \Delta_{K}, and the standard basis vectors (\delta_k)_{k=1,\dots,K}, where \delta_k = (0,\dots,\underbrace{1}_{k-\rm{th}},\dots,0)^\top \in \mathbb{R}^K. We also write [K] = \{1, \dots, K\}.
Formally, the standard soft-argmax function \sigma \colon \mathbb{R}^{K} \to \Delta_{K}, where K \geq 1, takes a vector z = (z_{1}, \dots, z_{K})^\top \in \mathbb{R}^{K} and computes each component of the vector \sigma(z) \in \Delta_{K} by \left(\sigma(z)\right)_k = \frac{\exp(z_k)}{\sum_{k'=1}^{K} \exp(z_{k'})}\enspace, \quad \text{for } k = 1, \dots, K.
We also define, for any \beta > 0 and any q \in \Delta_{K} with positive coordinates (see footnote 1), i.e., q_k>0 for all k\in [K], the function \sigma_{q,\beta} \colon \mathbb{R}^{K} \to \Delta_{K} as
\left(\sigma_{q,\beta}(z)\right)_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)} \enspace, \quad \text{for } k = 1, \dots, K.
Note in particular that \sigma(z)=\sigma_{u_K, 1}(z) for any z \in \mathbb{R}^K.
Last but not least, we introduce the real-valued log-sum-exp function (or weighted softmax), defined for the same q and \beta and for any vector z \in \mathbb{R}^K by
\text{logsumexp}_{q,\beta}(z) = \beta \cdot \log\left(\sum_{k=1}^{K} q_k \cdot \exp(z_k/\beta)\right).
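The two definitions above can be evaluated directly, but \exp(z_k/\beta) overflows for large z_k/\beta. The sketch below is one possible implementation in plain JavaScript, not the code used for the plot further down: the names logSumExp and softArgmax are hypothetical, and subtracting the largest coordinate before exponentiating is a standard stabilisation that relies on the shift invariance stated in Theorem 2.

```javascript
// Weighted log-sum-exp: beta * log(sum_k q_k * exp(z_k / beta)).
// The largest coordinate is factored out of the sum to avoid overflow.
function logSumExp(z, q, beta) {
  const zMax = Math.max(...z);
  let sum = 0;
  for (let k = 0; k < z.length; k++) {
    sum += q[k] * Math.exp((z[k] - zMax) / beta);
  }
  return zMax + beta * Math.log(sum);
}

// Weighted soft-argmax: q_k * exp(z_k / beta), normalised to sum to one.
function softArgmax(z, q, beta) {
  const zMax = Math.max(...z);
  const weights = z.map((zk, k) => q[k] * Math.exp((zk - zMax) / beta));
  const total = weights.reduce((a, b) => a + b, 0);
  return weights.map((w) => w / total);
}

// Illustrative values (assumptions): Math.exp(1000) alone would be Infinity.
const uniform = [1 / 3, 1 / 3, 1 / 3];
const z = [1000, 1001, 1002];
const p = softArgmax(z, uniform, 1);
const value = logSumExp(z, uniform, 1);
console.log(p, value);
```

With q = u_K and \beta = 1, softArgmax reduces to the standard soft-argmax \sigma.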
Variational formulation
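In this note we show that the log-sum-exp function is the convex conjugate of the Kullback-Leibler divergence to q (scaled by \beta and restricted to the simplex), and that the soft-argmax function is the corresponding maximizer.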
Theorem 1 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0 and let z \in \mathbb{R}^K. Then, with the convention 0 \log 0 = 0,
\begin{align*} \text{logsumexp}_{q,\beta}(z) & = \max_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) \\ \sigma_{q,\beta}(z) & = \argmax_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k)\enspace. \end{align*}
Proof. Using Lagrange multipliers (see for instance Ch. 5, Boyd and Vandenberghe 2004), we introduce \mu \in \mathbb{R} for the constraint \sum_{k} p_k = 1 and \Lambda=(\lambda_1,\dots,\lambda_K)^\top for the constraints p_k \geq 0, and obtain the Lagrangian function:
\mathcal{L}(p,\mu,\Lambda) = \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) - \mu\left(\sum_{k=1}^{K} p_k - 1\right) - \sum_{k=1}^{K} \lambda_k p_k.
The objective is strictly concave in p and the constraints are affine, so the Slater condition is satisfied, the KKT conditions are necessary and sufficient for optimality, and the maximizer is unique. The KKT conditions are
\begin{align*} \frac{\partial \mathcal{L}}{\partial p_k} &= z_k - \beta\log(p_k / q_k) - \beta - \mu - \lambda_k = 0, \quad k = 1, \dots, K,\\ \sum_{k=1}^{K} p_k &= 1,\\ \lambda_k p_k &= 0, \quad k = 1, \dots, K,\\ p_k &\geq 0, \quad k = 1, \dots, K. \end{align*}
From the first KKT condition we get p_k = q_k\exp\left(\frac{z_k-\beta-\mu-\lambda_k}{\beta}\right), \quad k = 1, \dots, K. Since q_k>0 for all k, this gives p_k>0 for all k, and the condition \lambda_k p_k = 0 then forces \lambda_k = 0 for all k. The remaining factor \exp(-(\beta+\mu)/\beta) does not depend on k, so we get after normalisation (using \sum_{k} p_k = 1) that
p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}, \quad k = 1, \dots, K,
that is, p = \sigma_{q,\beta}(z). Note that
\begin{align*} & p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ \iff & \log( p_k) = \log(q_k) + \frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z). \end{align*}
Finally, using again that \sum_{k} p_k = 1, we have that
\begin{align*} \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) &= \beta\sum_{k=1}^{K} p_k \left(\frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z)\right)\\ &= \sum_{k=1}^{K} p_k z_k - \text{logsumexp}_{q,\beta}(z). \end{align*}
Hence, at the maximizer, \sum_{k=1}^{K} p_k z_k - \beta \sum_{k=1}^{K} p_k \log(p_k / q_k) = \text{logsumexp}_{q,\beta}(z), which proves both identities.
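Theorem 1 can be checked numerically. The sketch below is only an illustration under assumed values of z, q and \beta; the function names are hypothetical and are not those of the plotting code. It evaluates the objective at \sigma_{q,\beta}(z), compares it with the log-sum-exp value, and verifies that randomly drawn points of the simplex never do better.

```javascript
// Direct transcriptions of the definitions (fine for moderate z / beta).
function logSumExp(z, q, beta) {
  return beta * Math.log(z.reduce((s, zk, k) => s + q[k] * Math.exp(zk / beta), 0));
}
function softArgmax(z, q, beta) {
  const w = z.map((zk, k) => q[k] * Math.exp(zk / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map((wk) => wk / total);
}

// Objective of Theorem 1: <z, p> - beta * sum_k p_k log(p_k / q_k), with 0 log 0 = 0.
function objective(p, z, q, beta) {
  let value = 0;
  for (let k = 0; k < p.length; k++) {
    value += z[k] * p[k];
    if (p[k] > 0) value -= beta * p[k] * Math.log(p[k] / q[k]);
  }
  return value;
}

// Draw a point of the simplex by normalising nonnegative random numbers.
function randomSimplexPoint(K) {
  const raw = Array.from({ length: K }, () => -Math.log(1 - Math.random()));
  const total = raw.reduce((a, b) => a + b, 0);
  return raw.map((r) => r / total);
}

// Illustrative values (assumptions).
const z = [0.4, 0.2, -1.0];
const q = [0.5, 0.3, 0.2];
const beta = 0.7;

const optimum = logSumExp(z, q, beta);
const gap = Math.abs(objective(softArgmax(z, q, beta), z, q, beta) - optimum);

let bestRandom = -Infinity;
for (let trial = 0; trial < 10000; trial++) {
  bestRandom = Math.max(bestRandom, objective(randomSimplexPoint(3), z, q, beta));
}
console.log({ gap, bestRandom, optimum });
```

The gap should be at the level of floating-point rounding, and bestRandom should approach optimum from below without exceeding it.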
The following limit properties as \beta decreases to 0 explain the naming and the regularizing role of the (temperature) parameter \beta:
Proposition 1 Recall that u_K=(1/K,\dots,1/K)^\top and that \delta_{k} is the k-th standard basis vector. For any z \in \mathbb{R}^K whose maximum is attained at a single index (the second limit holds without this assumption), we have that
\begin{align*} \sigma_{u_K,\beta}(z) & \xrightarrow[\beta \to 0]{} \delta_{k_0}, \text{ where } k_0=\argmax_{k\in [K]} z_k \\ \text{logsumexp}_{u_K,\beta}(z) & \xrightarrow[\beta \to 0]{} \max_{k\in [K]} z_k \enspace. \end{align*}
The first limit shows that the soft-argmax function is a smooth approximation of the argmax function (encoded as a basis vector), while the second shows that the log-sum-exp function is a smooth approximation of the max function.
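The limits in Proposition 1 can be observed numerically. For q = u_K the definition gives the elementary sandwich \max_k z_k - \beta \log K \leq \text{logsumexp}_{u_K,\beta}(z) \leq \max_k z_k, obtained by bounding the sum of exponentials below by its largest term and above by K times that term. The sketch below checks this bound for decreasing \beta; the input values and function names are assumptions chosen for illustration.

```javascript
// Overflow-safe versions: the largest coordinate is factored out first.
function logSumExp(z, q, beta) {
  const zMax = Math.max(...z);
  const sum = z.reduce((s, zk, k) => s + q[k] * Math.exp((zk - zMax) / beta), 0);
  return zMax + beta * Math.log(sum);
}
function softArgmax(z, q, beta) {
  const zMax = Math.max(...z);
  const w = z.map((zk, k) => q[k] * Math.exp((zk - zMax) / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map((wk) => wk / total);
}

const K = 4;
const uniform = Array(K).fill(1 / K);
const z = [0.3, 1.5, -0.2, 1.1]; // unique maximum at index 1 (0-based)
const zMax = Math.max(...z);

// One row per temperature: log-sum-exp, its lower bound, and the soft-argmax.
const rows = [1, 0.1, 0.01, 0.001].map((beta) => ({
  beta,
  lse: logSumExp(z, uniform, beta),
  lowerBound: zMax - beta * Math.log(K),
  sigma: softArgmax(z, uniform, beta),
}));
console.log(rows);
```

As \beta decreases, lse is squeezed towards zMax and sigma concentrates on the coordinate of the maximum.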
Invariance properties
The soft-argmax function is invariant to the addition of a constant to each component of the input vector. More precisely, we have the following result:
Theorem 2 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0 and let z \in \mathbb{R}^K. Then, for any c \in \mathbb{R}, we have that \sigma_{q,\beta}(z) = \sigma_{q,\beta}(z+c), where z+c denotes the vector (z_1+c, \dots, z_K+c)^\top.
Proof. We have that \begin{align*} \sigma_{q,\beta}(z+c)_k &= \frac{q_k\exp((z_k+c)/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp((z_{k'}+c)/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)\exp(c/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)\exp(c/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ &= \sigma_{q,\beta}(z)_k. \end{align*}
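A quick numerical check of Theorem 2, written as a sketch with assumed values and hypothetical function names. It uses the unstabilised formulas on purpose, so that the invariance is observed rather than built in. As a side remark, the same factorisation of \exp(c/\beta) shows that the log-sum-exp is not invariant but shifts by c, which the second quantity below checks.

```javascript
// Direct transcriptions of the definitions, without any stabilisation.
function softArgmax(z, q, beta) {
  const w = z.map((zk, k) => q[k] * Math.exp(zk / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map((wk) => wk / total);
}
function logSumExp(z, q, beta) {
  return beta * Math.log(z.reduce((s, zk, k) => s + q[k] * Math.exp(zk / beta), 0));
}

// Illustrative values (assumptions).
const z = [0.4, 0.2, -1.0];
const q = [0.5, 0.3, 0.2];
const beta = 0.7;
const c = 123.4;
const shifted = z.map((zk) => zk + c);

// Largest coordinate-wise difference between sigma(z) and sigma(z + c).
const before = softArgmax(z, q, beta);
const after = softArgmax(shifted, q, beta);
const sigmaGap = Math.max(...before.map((pk, k) => Math.abs(pk - after[k])));

// logsumexp(z + c) should equal logsumexp(z) + c.
const lseGap = Math.abs(logSumExp(shifted, q, beta) - (logSumExp(z, q, beta) + c));
console.log({ sigmaGap, lseGap });
```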
Let us also consider the effect of rescaling the input vector by a positive constant, which amounts to dividing the temperature by that constant. We have the following result:
Theorem 3 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0 and let z \in \mathbb{R}^K. Then, for any \alpha >0, we have that \sigma_{q,\beta}(\alpha z) = \sigma_{q,\tfrac{\beta}{\alpha}}(z).
Proof. We have that \begin{align*} \left(\sigma_{q,\beta}(\alpha z)\right)_k &= \frac{q_k\exp(\alpha z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(\alpha z_{k'}/\beta)}\\ &= \frac{q_k\exp(z_k/\frac{\beta}{\alpha})}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\frac{\beta}{\alpha})}\\ &= \left(\sigma_{q,\tfrac{\beta}{\alpha}}(z)\right)_k. \end{align*}
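Theorem 3 can be checked in the same spirit. The sketch below uses assumed values and hypothetical function names; it compares \sigma_{q,\beta}(\alpha z) with \sigma_{q,\beta/\alpha}(z), and also checks the companion identity \text{logsumexp}_{q,\beta}(\alpha z) = \alpha \cdot \text{logsumexp}_{q,\beta/\alpha}(z), which follows from the definition by writing \beta = \alpha \cdot (\beta/\alpha).

```javascript
// Direct transcriptions of the definitions.
function softArgmax(z, q, beta) {
  const w = z.map((zk, k) => q[k] * Math.exp(zk / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map((wk) => wk / total);
}
function logSumExp(z, q, beta) {
  return beta * Math.log(z.reduce((s, zk, k) => s + q[k] * Math.exp(zk / beta), 0));
}

// Illustrative values (assumptions).
const z = [0.4, 0.2, -1.0];
const q = [0.5, 0.3, 0.2];
const beta = 0.7;
const alpha = 2.5;
const scaled = z.map((zk) => alpha * zk);

// Rescaling the input by alpha versus dividing the temperature by alpha.
const lhs = softArgmax(scaled, q, beta);
const rhs = softArgmax(z, q, beta / alpha);
const sigmaGap = Math.max(...lhs.map((pk, k) => Math.abs(pk - rhs[k])));

const lseGap = Math.abs(logSumExp(scaled, q, beta) - alpha * logSumExp(z, q, beta / alpha));
console.log({ sigmaGap, lseGap });
```

In particular, multiplying the input by a large \alpha has the same effect as lowering the temperature, which connects this result to the limits of Proposition 1.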
Visualization
We consider the case K=3 for visualization.
You can click on the plot to move z (the purple dot), and see the corresponding soft-argmax \sigma_{q,\beta}(z) (the red dot). You can modify q (the black dot) and \beta with the sliders. The level sets displayed are those of the objective function maximized in Theorem 1, i.e., for a fixed z\in \mathbb{R}^3, q\in \Delta_K and \beta>0, the level sets of the function \begin{align*} \Delta_K &\to \mathbb{R} \\ p & \mapsto \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k). \end{align*}
Code
d3 = require("d3@v6", "d3-hexbin@0.2")
math = require("mathjs")
digamma = require('https://cdn.jsdelivr.net/gh/stdlib-js/math-base-special-digamma@umd/browser.js')
import {legend} from "@d3/color-legend"
densityResolution = 80
margin = ({ left: 30, top: 30, right: 30, bottom: 30 })
height = 400
// One-element array; it is coerced to a number in the arithmetic below.
beta = [inputs.beta]
// lbd = [0.3,0.3,0.4]
viewof q1 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_1`, step: 0.01})
viewof q2 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_2`, step: 0.01})
viewof q3 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_3`, step: 0.01})
viewof z1 = Inputs.range([-10, 10], {value: 0.4, label: tex`z_1`, step: 0.01})
viewof z2 = Inputs.range([-10, 10], {value: 0.2, label: tex`z_2`, step: 0.01})
viewof z3 = Inputs.range([-10, 10], {value: 0.2, label: tex`z_3`, step: 0.01})
viewof inputs = Inputs.form({
  z1: viewof z1,
  z2: viewof z2,
  z3: viewof z3,
  beta: Inputs.range([0.01, 10], {value: 1, label: tex`\beta`, step: 0.01}),
  q1: viewof q1,
  q2: viewof q2,
  q3: viewof q3,
})
q = [q1, q2, q3]
z = [z1, z2, z3]
// Weighted log-sum-exp: beta * log(sum_i alpha_i * exp(x_i / beta)).
logsumexp = (x, alpha, beta) => {
  let result = 0;
  for(let i = 0; i < x.length; i++) result += alpha[i] * Math.exp(x[i] / beta);
  return beta * Math.log(result);
}
// Objective of Theorem 1 at x: <x, z> - beta * sum_i x_i * log(x_i / alpha_i).
objective = (x, z, alpha, beta) => {
  let result = 0;
  for(let i = 0; i < x.length; i++) result -= beta * x[i] * Math.log(x[i] / alpha[i]);
  for(let i = 0; i < x.length; i++) result += x[i] * z[i];
  return result;
}
// Negative Kullback-Leibler divergence scaled by beta: -beta * sum_i x_i * log(x_i / alpha_i).
kull_leibler = (x, alpha, beta) => {
  let result = 0;
  for(let i = 0; i < x.length; i++) result -= x[i] * Math.log(x[i] / alpha[i]);
  return beta * result;
}
// Weighted soft-argmax: alpha_i * exp(x_i / beta), normalised to sum to one.
softmax = (x, alpha, beta) => {
  let result = [];
  let sum = 0;
  for(let i = 0; i < x.length; i++) {
    result.push(alpha[i] * Math.exp(x[i] / beta));
    sum += result[i];
  }
  for(let i = 0; i < x.length; i++) result[i] /= sum;
  return result;
}
function cartesianToBarycentric(x, y, corners) {
const A = corners[0], B = corners[1], C = corners[2];
const v0 = { x: C.x - A.x, y: C.y - A.y };
const v1 = { x: B.x - A.x, y: B.y - A.y };
const v2 = { x: x - A.x, y: y - A.y };
const d00 = dot(v0, v0);
const d01 = dot(v0, v1);
const d11 = dot(v1, v1);
const d20 = dot(v2, v0);
const d21 = dot(v2, v1);
const denom = d00 * d11 - d01 * d01;
const lambda1 = (d11 * d20 - d01 * d21) / denom;
const lambda2 = (d00 * d21 - d01 * d20) / denom;
const lambda3 = 1.0 - lambda1 - lambda2;
return [lambda1, lambda2, lambda3];
}
function barycentricToCartesian(q, corners) {
  return [
    corners[0].x * q[0] + corners[1].x * q[1] + corners[2].x * q[2],
    corners[0].y * q[0] + corners[1].y * q[1] + corners[2].y * q[2]
  ];
}
function dot(v1, v2) {
return v1.x * v2.x + v1.y * v2.y;
}
ternaryDensity = (density, resolution, options, lbd) => {
  options = Object.assign({
    size: 400,
    margin: { left: 30, top: 30, right: 30, bottom: 30 },
    colorScheme: d3.interpolateBlues
  }, options);
  const fillScale = d3.scaleSequential(options.colorScheme).domain([0, Math.log(Math.max(q[0],q[1],q[2]))]);
  // const fillScale = d3.scaleSequential(options.colorScheme).domain(d3.extent(density.map(d => d.density)))
  const svg = d3.create("svg")
    .attr("width", options.size)
    .attr("height", options.size);
  const x = d3.scaleLinear()
    .domain([0, 1])
    .range([options.margin.left, options.size - options.margin.right]);
  const y = d3.scaleLinear()
    .domain([0, 1])
    .range([options.size - options.margin.bottom, options.margin.top]);
  // Pixel -> unit-square coordinates (inverses of the x and y scales).
  const inverseScaleX = value => x.invert(value);
  const inverseScaleY = value => y.invert(value);
  // Unit-square -> pixel coordinates (these coincide with the x and y scales).
  const inverseInverseScaleX = value => x(value);
  const inverseInverseScaleY = value => y(value);
  const axisBottom = g => g.call(d3.axisBottom(x).ticks(4))
  const axisTop = g => g.call(d3.axisTop(x).ticks(4));
  const hexagonSize = ((x(1 / (2 * resolution)) - x(0))) / Math.cos(Math.PI / 6) + 1;
  const corners = [{ x: 0.5, y: Math.sqrt(3)/2 }, { x: 0, y: 0 }, {x: 1, y: 0 }];
  const corners_expanded = [{ x: 0.5, y: Math.sqrt(3)/2 + 0.1 }, { x: -0.1, y: -0.05 }, {x: 1.1, y: -0.05 }];
  const line = d3.line().x(d => x(d.x)).y(d => y(d.y));
  // Clip path restricting the hexagons to the triangle
  svg.append('defs')
    .append('clipPath')
    .attr('id', 'triangle')
    .append('path')
    .attr('x', options.size / 2)
    .attr('y', options.size / 2)
    .attr('d', line(corners));
  // Border of the triangle
  svg.selectAll('.border')
    .data([[corners[0], corners[1]],
           [corners[1], corners[2]],
           [corners[2], corners[0]]])
    .enter()
    .append('line')
    .attr('x1', (d) => x(d[0].x))
    .attr('x2', (d) => x(d[1].x))
    .attr('y1', (d) => y(d[0].y))
    .attr('y2', (d) => y(d[1].y))
    .attr('class', 'border')
    .attr('stroke', 'black');
  // One hexagon per grid point, colored by the density value
  const hexbin = d3.hexbin();
  svg.append('g')
    .attr('clip-path', 'url(#triangle)')
    .selectAll(".point")
    .data(density)
    .join('path')
    .attr('transform', d => `translate(${x(corners[0].x * d.x[0] + corners[1].x * d.x[1] + corners[2].x * d.x[2])}, ${y(corners[0].y * d.x[0] + corners[1].y * d.x[1] + corners[2].y * d.x[2])})`)
    .attr('fill', d => fillScale(d.density))
    .attr('d', hexbin.hexagon(hexagonSize));
  // Edge labels
  svg.selectAll('text')
    .data([[corners_expanded[0], corners_expanded[1]],
           [corners_expanded[1], corners_expanded[2]],
           [corners_expanded[2], corners_expanded[0]]])
    .join("text")
    .attr("text-anchor", "middle")
    .attr("alignment-baseline", "middle")
    .attr("font-style", "italic")
    .attr("x", d => (x(d[0].x) + x(d[1].x)) / 2)
    .attr("y", d => (y(d[0].y) + y(d[1].y)) / 2)
    .text((d, i) => {
      if(i == 0) return "x₂" ;
      if(i == 1) return "x₃";
      if(i == 2) return "x₁";
    })
  // Pixel positions of z, q and the soft-argmax of z
  let xy = barycentricToCartesian(z, corners)
  let xyr = [inverseInverseScaleX(xy[0]), inverseInverseScaleY(xy[1])];
  let xy_q = barycentricToCartesian(q, corners)
  let xyr_q = [inverseInverseScaleX(xy_q[0]), inverseInverseScaleY(xy_q[1])];
  let sftmax = softmax([z1, z2, z3], q, beta)
  let xy_s = barycentricToCartesian(sftmax, corners)
  let xyr_s = [inverseInverseScaleX(xy_s[0]), inverseInverseScaleY(xy_s[1])];
  // Red circle: the soft-argmax
  svg.append("circle")
    .attr("class", "red-circle") // Add a class for easy selection
    .attr("cx", xyr_s[0])
    .attr("cy", xyr_s[1])
    .attr("z-index", 1000)
    .attr("r", 7) // Radius of the circle
    .attr("fill", "red"); // Fill color of the circle
  // Purple circle: z
  svg.append("circle")
    .attr("class", "purple-circle") // Add a class for easy selection
    .attr("cx", xyr[0])
    .attr("cy", xyr[1])
    .attr("z-index", 1000)
    .attr("r", 7) // Radius of the circle
    .attr("fill", "purple"); // Fill color of the circle
  // Black circle: q
  svg.append("circle")
    .attr("class", "black-circle") // Add a class for easy selection
    .attr("cx", xyr_q[0])
    .attr("cy", xyr_q[1])
    .attr("z-index", 1000)
    .attr("r", 7) // Radius of the circle
    .attr("fill", "black"); // Fill color of the circle
  // Text for the coordinates of q (fixed position in the upper right)
  svg.append("text")
    .attr("class", "coord-text") // Add a class for easy selection
    // .attr("x", xyr[0] + 10) // Position the text right next to the circle
    // .attr("y", xyr[1])
    .attr("x", 300)
    .attr("y", 80)
    .attr("fill", "black") // Text color
    .style("font-size", "10px")
    .text(`q=(${q[0].toFixed(2)}, ${q[1].toFixed(2)}, ${q[2].toFixed(2)})`); // Display coordinates
  // Text for the coordinates of the soft-argmax
  svg.append("text")
    .attr("class", "coord-text") // Add a class for easy selection
    // .attr("x", xyr[0] + 10) // Position the text right next to the circle
    // .attr("y", xyr[1])
    .attr("x", 300)
    .attr("y", 100)
    .attr("fill", "red") // Text color
    .style("font-size", "10px")
    .text(`σ_{q,β}=(${sftmax[0].toFixed(2)}, ${sftmax[1].toFixed(2)}, ${sftmax[2].toFixed(2)})`); // Display coordinates
  // Text for the coordinates of z
  svg.append("text")
    .attr("class", "coord-text") // Add a class for easy selection
    .attr("x", 300)
    .attr("y", 120)
    .attr("fill", "purple") // Text color
    .style("font-size", "10px")
    .text(`z=(${z1.toFixed(2)}, ${z2.toFixed(2)}, ${z3.toFixed(2)})`); // Display coordinates
  const removeLine = g => g.select(".domain").remove();
  // Set the value of an input and notify Observable of the change
  function set(input, value) {
    input.value = value;
    input.dispatchEvent(new Event("input", {bubbles: true}));
  }
  // Attach a click event listener to the SVG
  svg.on("click", function(event) {
    // Extract the click coordinates
    const [clickX, clickY] = d3.pointer(event);
    lbd = cartesianToBarycentric(inverseScaleX(clickX), inverseScaleY(clickY), corners);
    set(viewof z1, lbd[2]);
    set(viewof z2, lbd[1]);
    set(viewof z3, lbd[0]);
  })
  // display axis values
  svg.append("g")
    .attr("transform", `translate(0, ${options.size - options.margin.bottom})`)
    .call(axisBottom)
    .call(removeLine)
  svg.append("g")
    .attr("transform", `translate(${x(1) + options.margin.right / 2 }, ${y(0) + 26}) rotate(-120)`)
    .call(axisBottom)
    .call(removeLine)
    .call(g => g.selectAll("text").attr("transform", "translate(11, 22) rotate(120)"));
  svg.append("g")
    .attr("transform", `translate(${x(0.5) + options.margin.left / 2}, ${y(Math.sqrt(3) / 2) - 26}) rotate(120)`)
    .call(axisBottom)
    .call(removeLine)
    .call(g => g.selectAll("text").attr("transform", "translate(-11, 22) rotate(-120)"));
  return svg.node();
}
// Values of the log-sum-exp on a grid of the simplex
levelLogsumexp = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];
      density.push({ x, density: logsumexp(x, q, beta) });
    }
  }
  return density;
}
// Values of the (negative, scaled) Kullback-Leibler term on the same grid
levelKL = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];
      // density.push({ x, density: logsumexp(x, q, beta) / Math.log(Math.max(q[0],q[1], q[2]) * Math.exp(1/beta)) });
      density.push({ x, density: kull_leibler(x, q, beta) });
    }
  }
  return density;
}
// Values of the full objective of Theorem 1 on the same grid
levelKL_full = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];
      // density.push({ x, density: logsumexp(x, q, beta) / Math.log(Math.max(q[0],q[1], q[2]) * Math.exp(1/beta)) });
      density.push({ x, density: objective(x, z, q, beta) });
    }
  }
  return density;
}
ternaryDensity(levelKL_full, densityResolution, { size: 500, colorScheme: d3.interpolateBlues})
logsumexpColorScale = legend({color: d3.scaleSequential(d3.extent([0, logsumexp(z, q, beta)]), d3.interpolateBlues), title: "Intensity"})
Credit
The Observable code was written with the help of François-David Collin.
The plot is loosely inspired by Herb Susmann’s code on the Dirichlet distribution.
The terminology is inspired by posts from Gabriel Peyré: https://x.com/gabrielpeyre/status/1830470713041354968, https://x.com/gabrielpeyre/status/1680804520862056448
References
Footnotes
1. The case with q_k=0 for some k can be reduced to a case with fewer coordinates.