Soft-max and Soft-argmax

Back to exponential weights and log-sum-exp functions
statistics
optimization
machine learning
Authors

Joseph Salmon

François-David Collin

Published

September 3, 2024

The softmax function is a smooth approximation of the max function, and is used in many machine learning models. Similarly, we can define the soft-argmax function, which is a smooth approximation of the argmax function.

Definition and notation

First, let us define the probability simplex in \mathbb{R}^K, \Delta_{K} = \{p \in \mathbb{R}^K : p_k \geq 0 \text{ for all } k, \ \sum_{k} p_k = 1\}, the uniform distribution u_K = (1/K, \dots, 1/K)^\top \in \Delta_{K}, and the standard basis vectors (\delta_k)_{k=1,\dots,K}, where \delta_k = (0,\dots,\underbrace{1}_{k-\rm{th}},\dots,0)^\top \in \mathbb{R}^K. We also write [K] = \{1, \dots, K\}.

Formally, the standard soft-argmax function \sigma \colon \mathbb{R}^{K} \to (0, 1)^{K}, where K \geq 2, takes a vector z = (z_{1}, \dots, z_{K})^\top \in \mathbb{R}^{K} and computes each component of the vector \sigma(z) \in (0, 1)^{K} by \left(\sigma(z)\right)_k = \frac{\exp(z_k)}{\sum_{k'=1}^{K} \exp(z_{k'})}\enspace, \quad \text{for } k = 1, \dots, K.

We also define, for any \beta > 0 and any q \in \Delta_{K} with positive coordinates¹, i.e., q_k>0 for all k\in [K], the function \sigma_{q,\beta} \colon \mathbb{R}^{K} \to \mathbb{R}^K as

\left(\sigma_{q,\beta}(z)\right)_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)} \enspace, \quad \text{for } k = 1, \dots, K.

Note in particular that \sigma(z)=\sigma_{u_K, 1}(z) for any z \in \mathbb{R}^K.

Last but not least, we introduce the real-valued log-sum-exp function (or weighted softmax), defined for any vector z \in \mathbb{R}^K by

\text{logsumexp}_{q,\beta}(z) = \beta \cdot \log\left(\sum_{k=1}^{K} q_k \cdot \exp(z_k/\beta)\right).
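The two definitions above translate almost line by line into code. The following is a minimal sketch in plain JavaScript; the names `logSumExp` and `softArgmax` are hypothetical helpers chosen for illustration. Subtracting the largest coordinate before exponentiating is an implementation choice assumed here, not part of the definitions: it leaves both results unchanged mathematically but avoids overflow in `Math.exp` when z_k/\beta is large.

```javascript
// Weighted log-sum-exp: beta * log(sum_k q[k] * exp(z[k] / beta)).
// The maximum m of z is factored out of the sum and added back at the end.
function logSumExp(z, q, beta) {
  const m = Math.max(...z);
  let s = 0;
  for (let k = 0; k < z.length; k++) s += q[k] * Math.exp((z[k] - m) / beta);
  return m + beta * Math.log(s);
}

// Weighted soft-argmax: q[k] * exp(z[k] / beta), normalized to sum to 1.
// The common factor exp(-m / beta) cancels in the ratio.
function softArgmax(z, q, beta) {
  const m = Math.max(...z);
  const w = z.map((zk, k) => q[k] * Math.exp((zk - m) / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map(wk => wk / total);
}

const uniformWeights = [1 / 3, 1 / 3, 1 / 3];
const scores = [0.4, 0.2, 0.2];

// With uniform weights and beta = 1 this is the standard soft-argmax sigma(z).
console.log(softArgmax(scores, uniformWeights, 1));
console.log(logSumExp(scores, uniformWeights, 1));

// Large inputs stay finite thanks to the shift by the maximum.
console.log(softArgmax([1000, 999, 998], uniformWeights, 1));
```

A direct transcription without the shift would return NaN for the last call, since exp(1000) overflows to Infinity.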

Variational formulation

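In this note we show that the log-sum-exp function is the convex conjugate of the Kullback-Leibler divergence to q (scaled by \beta and restricted to the simplex), and that the soft-argmax function is the point at which the corresponding maximum is attained.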

Theorem 1 Let q \in \Delta_{K} be with positive coordinates, i.e., q_k>0, for all k\in [K], \beta>0 and let z \in \mathbb{R}^K. Then,

\begin{align*} \text{logsumexp}_{q,\beta}(z) & = \max_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) \\ \sigma_{q,\beta}(z) & = \argmax_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k)\enspace. \end{align*}

Proof. Using Lagrange multipliers (see for instance Ch. 5, Boyd and Vandenberghe 2004), we introduce \mu \in \mathbb{R} for the constraint \sum_{k=1}^{K} p_k = 1 and \Lambda=(\lambda_1,\dots,\lambda_K)^\top with \lambda_k \geq 0 for the constraints p_k \geq 0, and get the Lagrangian function:

\mathcal{L}(p,\mu,\Lambda) = \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) - \mu\left(\sum_{k=1}^{K} p_k - 1\right) + \sum_{k=1}^{K} \lambda_k p_k.

The objective is concave and the constraints are affine, so the Slater condition is satisfied and the KKT conditions are necessary and sufficient for optimality. The KKT conditions are \begin{align*} \frac{\partial \mathcal{L}}{\partial p_k} &= z_k - \beta\log(p_k / q_k) - \beta - \mu + \lambda_k = 0, \quad k = 1, \dots, K,\\ \sum_{k=1}^{K} p_k &= 1,\\ \lambda_k p_k &= 0, \quad k = 1, \dots, K,\\ \lambda_k &\geq 0, \quad k = 1, \dots, K,\\ p_k &\geq 0, \quad k = 1, \dots, K. \end{align*}

From the first KKT condition we get p_k = q_k\exp\left(\frac{z_k-\beta-\mu+\lambda_k}{\beta}\right), \quad k = 1, \dots, K. Since q_k>0 for all k, this gives p_k > 0, and the complementary slackness condition \lambda_k p_k = 0 then yields \lambda_k = 0 for all k. The multiplier \mu is fixed by the constraint \sum_{k=1}^{K} p_k = 1, so we get after normalisation that p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}, \quad k = 1, \dots, K, that is, p = \sigma_{q,\beta}(z).

Note that \begin{align*} & p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ \iff & \log( p_k) = \log(q_k) + \frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z). \end{align*} Finally, using \sum_{k=1}^{K} p_k = 1, we have that \begin{align*} \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) &= \beta\sum_{k=1}^{K} p_k \left(\frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z)\right)\\ &= \sum_{k=1}^{K} p_k z_k - \text{logsumexp}_{q,\beta}(z). \end{align*} Hence, \sum_{k=1}^{K} p_k z_k - \beta \sum_{k=1}^{K} p_k \log(p_k / q_k) = \text{logsumexp}_{q,\beta}(z).
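Theorem 1 can be sanity-checked numerically for K=3. The sketch below is only an illustration under stated assumptions: the helper names (`logSumExp`, `softArgmax`, `objective`) and the test values of q, z and \beta are arbitrary choices, and the grid scan is a crude check rather than a proof. It evaluates the objective at the soft-argmax, compares it with the log-sum-exp, and verifies that no grid point of the simplex does better.

```javascript
function logSumExp(z, q, beta) {
  const m = Math.max(...z);
  let s = 0;
  for (let k = 0; k < z.length; k++) s += q[k] * Math.exp((z[k] - m) / beta);
  return m + beta * Math.log(s);
}

function softArgmax(z, q, beta) {
  const m = Math.max(...z);
  const w = z.map((zk, k) => q[k] * Math.exp((zk - m) / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map(wk => wk / total);
}

// Objective of Theorem 1: <z, p> - beta * sum_k p[k] * log(p[k] / q[k]),
// with the usual convention 0 * log(0) = 0 on the boundary of the simplex.
function objective(p, z, q, beta) {
  let value = 0;
  for (let k = 0; k < p.length; k++) {
    value += p[k] * z[k];
    if (p[k] > 0) value -= beta * p[k] * Math.log(p[k] / q[k]);
  }
  return value;
}

// Arbitrary test values (assumptions made for this illustration).
const qTest = [0.2, 0.3, 0.5];
const zTest = [0.4, -1.0, 2.0];
const betaTest = 0.7;

const pStar = softArgmax(zTest, qTest, betaTest);
const valueAtPStar = objective(pStar, zTest, qTest, betaTest);
// The two printed values should agree up to rounding errors.
console.log(valueAtPStar, logSumExp(zTest, qTest, betaTest));

// Scan a regular grid of the simplex and record the best objective value found.
let gridMax = -Infinity;
const gridSize = 50;
for (let i = 0; i <= gridSize; i++) {
  for (let j = 0; j <= gridSize - i; j++) {
    const p = [i / gridSize, j / gridSize, (gridSize - i - j) / gridSize];
    gridMax = Math.max(gridMax, objective(p, zTest, qTest, betaTest));
  }
}
// No grid point should exceed the value reached at the soft-argmax.
console.log(gridMax <= valueAtPStar + 1e-12);
```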

The following limit properties as \beta \to 0^+ explain the naming and the regularizing role of the (temperature) parameter \beta:

Proposition 1 Recall that u_K=(1/K,\dots,1/K)^\top and that \delta_{k} is the k-th standard basis vector. For any z \in \mathbb{R}^K, we have that

\begin{align*} \sigma_{u_K,\beta}(z) & \xrightarrow[\beta \to 0^+]{} \delta_{k_0}, \text{ where } k_0=\argmax_{k\in [K]} z_k \\ \text{logsumexp}_{u_K,\beta}(z) & \xrightarrow[\beta \to 0^+]{} \max_{k\in [K]} z_k \enspace, \end{align*} where the first limit assumes that the maximum of z is attained at a single index k_0.

The first limit shows that the soft-argmax function is a smooth approximation of the argmax function, while the second shows that the log-sum-exp function is a smooth approximation of the max function.
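The effect of the temperature can be observed numerically. In the sketch below (helper names and test values are assumptions made for illustration), the soft-argmax with uniform weights concentrates on the index of the largest coordinate as \beta decreases, while the log-sum-exp approaches the maximum. With uniform weights the log-sum-exp always lies between \max_k z_k - \beta \log K and \max_k z_k, which gives the rate of the second limit.

```javascript
function logSumExp(z, q, beta) {
  const m = Math.max(...z);
  let s = 0;
  for (let k = 0; k < z.length; k++) s += q[k] * Math.exp((z[k] - m) / beta);
  return m + beta * Math.log(s);
}

function softArgmax(z, q, beta) {
  const m = Math.max(...z);
  const w = z.map((zk, k) => q[k] * Math.exp((zk - m) / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map(wk => wk / total);
}

const zLimit = [0.4, 0.2, 0.2];          // unique maximum at index 0
const uK = [1 / 3, 1 / 3, 1 / 3];        // uniform weights u_K

for (const beta of [1, 0.1, 0.01, 0.001]) {
  const p = softArgmax(zLimit, uK, beta);
  const lse = logSumExp(zLimit, uK, beta);
  // As beta shrinks, p approaches (1, 0, 0) and lse approaches max(z) = 0.4.
  console.log(`beta=${beta}`, p.map(v => v.toFixed(4)), lse.toFixed(4));
}
```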

Invariance properties

The soft-argmax function is invariant to the addition of a constant to each component of the input vector. More precisely, we have the following result:

Theorem 2 Let q \in \Delta_{K} be with positive coordinates, i.e., q_k>0, for all k\in [K], \beta>0 and let z \in \mathbb{R}^K. Then, for any c \in \mathbb{R}, we have that \sigma_{q,\beta}(z) = \sigma_{q,\beta}(z+c\mathbf{1}_K), where \mathbf{1}_K = (1,\dots,1)^\top \in \mathbb{R}^K.

Proof. We have that \begin{align*} \left(\sigma_{q,\beta}(z+c\mathbf{1}_K)\right)_k &= \frac{q_k\exp((z_k+c)/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp((z_{k'}+c)/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)\exp(c/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)\exp(c/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ &= \left(\sigma_{q,\beta}(z)\right)_k. \end{align*}

Let us also consider the effect of rescaling the input vector by a positive constant. We have the following result:

Theorem 3 Let q \in \Delta_{K} be with positive coordinates, i.e., q_k>0, for all k\in [K], \beta>0 and let z \in \mathbb{R}^K. Then, for any \alpha >0, we have that \sigma_{q,\beta}(\alpha z) = \sigma_{q,\tfrac{\beta}{\alpha}}(z).

Proof. We have that \begin{align*} \left(\sigma_{q,\beta}(\alpha z)\right)_k &= \frac{q_k\exp(\alpha z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(\alpha z_{k'}/\beta)}\\ &= \frac{q_k\exp(z_k/\frac{\beta}{\alpha})}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\frac{\beta}{\alpha})}\\ &= \left(\sigma_{q,\tfrac{\beta}{\alpha}}(z)\right)_k. \end{align*}
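Both invariance results can be checked numerically. The sketch below uses a hypothetical helper `directSoftArgmax` that transcribes the definition of \sigma_{q,\beta} literally, without any stabilizing shift, so that the shift-invariance check is not built into the implementation. The values of q, z, \beta, c and \alpha are arbitrary assumptions, kept moderate to avoid overflow.

```javascript
// Literal transcription of the definition of sigma_{q,beta}.
function directSoftArgmax(z, q, beta) {
  const w = z.map((zk, k) => q[k] * Math.exp(zk / beta));
  const total = w.reduce((a, b) => a + b, 0);
  return w.map(wk => wk / total);
}

// Largest coordinate-wise gap between two vectors of equal length.
function maxAbsDiff(a, b) {
  return Math.max(...a.map((ak, k) => Math.abs(ak - b[k])));
}

const qInv = [0.2, 0.3, 0.5];
const zInv = [0.4, -1.0, 2.0];
const betaInv = 0.7;
const c = 5;       // additive shift
const alpha = 3;   // positive rescaling

const reference = directSoftArgmax(zInv, qInv, betaInv);

// Shift invariance: adding c to every coordinate leaves the output unchanged.
const shifted = directSoftArgmax(zInv.map(zk => zk + c), qInv, betaInv);
const shiftGap = maxAbsDiff(shifted, reference);

// Rescaling: multiplying z by alpha is the same as dividing beta by alpha.
const scaled = directSoftArgmax(zInv.map(zk => alpha * zk), qInv, betaInv);
const scaleGap = maxAbsDiff(scaled, directSoftArgmax(zInv, qInv, betaInv / alpha));

// Both gaps should be at the level of floating-point rounding.
console.log(shiftGap, scaleGap);
```

The first identity is what justifies subtracting the largest coordinate of z before exponentiating in practical implementations.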

Visualization

We consider the case K=3 for visualization.

You can click on the plot to move z (the purple dot) and see the corresponding soft-argmax \sigma_{q,\beta}(z) (the red dot). You can modify q (the black dot) and \beta with the sliders. The level sets displayed are those of the function being maximized in Theorem 1, i.e., for fixed z\in \mathbb{R}^3, q\in \Delta_K and \beta>0, the function \begin{align*} \Delta_K &\to \mathbb{R} \\ p & \mapsto \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) \end{align*}

Code
d3 = require("d3@v6", "d3-hexbin@0.2")
math = require("mathjs")
digamma = require( 'https://cdn.jsdelivr.net/gh/stdlib-js/math-base-special-digamma@umd/browser.js' )

import {legend} from "@d3/color-legend"


densityResolution = 80
margin = ({ left: 30, top: 30, right: 30, bottom: 30 })
height = 400



beta = inputs.beta
// lbd = [0.3,0.3,0.4]

viewof q1 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_1`, step: 0.01})
viewof q2 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_2`, step: 0.01})
viewof q3 = Inputs.range([0.01, 1], {value: 1/3, label: tex`q_3`, step: 0.01})
viewof z1 = Inputs.range([-10, 10], {value: 0.4, label: tex`z_1`, step: 0.01})
viewof z2 = Inputs.range([-10, 10], {value: 0.2, label: tex`z_2`, step: 0.01})
viewof z3 = Inputs.range([-10, 10], {value: 0.2, label: tex`z_3`, step: 0.01})


viewof inputs = Inputs.form({
  z1: viewof z1,
  z2: viewof z2,
  z3: viewof z3,
  beta: Inputs.range([0.01, 10], {value: 1, label: tex`\beta`, step: 0.01}),
  q1: viewof q1,
  q2: viewof q2,
  q3: viewof q3,
})

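// Normalize the slider values so that q lies on the probability simplex.
q = [q1, q2, q3].map(v => v / (q1 + q2 + q3))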
z = [z1, z2, z3]

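logsumexp = (x, alpha, beta) => {
  // Factor out the max of x for numerical stability: exp(x[i] / beta) overflows for small beta.
  const m = Math.max(...x);
  let result = 0;
  for(let i = 0; i < x.length; i++) result += alpha[i] * Math.exp((x[i] - m) / beta);
  return m + beta * Math.log(result);
}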


objective = (x, z, alpha, beta) => {
  let result = 0;
  for(let i = 0; i < x.length; i++) result -= beta * x[i] * Math.log(x[i] / alpha[i]);
  for(let i = 0; i < x.length; i++) result += x[i] * z[i];
  return result;
}


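// Returns -beta * KL(x || alpha), i.e., the negative scaled Kullback-Leibler
// divergence (the entropic part of the objective above).
kull_leibler = (x, alpha, beta) => {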
  let result = 0;
  for(let i = 0; i < x.length; i++) result -= x[i] * Math.log(x[i] / alpha[i]);
  return beta * result;
}


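softmax = (x, alpha, beta) => {
  // Subtract the max before exponentiating: the result is unchanged (shift
  // invariance), but Math.exp no longer overflows for small beta.
  const m = Math.max(...x);
  let result = [];
  let sum = 0;
  for(let i = 0; i < x.length; i++) {
    result.push(alpha[i] * Math.exp((x[i] - m) / beta));
    sum += result[i];
  }
  for(let i = 0; i < x.length; i++) result[i] /= sum;
  return result;
}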

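// Barycentric coordinates of the point (x, y) with respect to the triangle `corners`,
// returned in the order [weight of corners[2], weight of corners[1], weight of corners[0]].
function cartesianToBarycentric(x, y, corners) {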
  const A = corners[0], B = corners[1], C = corners[2];
  const v0 = { x: C.x - A.x, y: C.y - A.y };
  const v1 = { x: B.x - A.x, y: B.y - A.y };
  const v2 = { x: x - A.x, y: y - A.y };

  const d00 = dot(v0, v0);
  const d01 = dot(v0, v1);
  const d11 = dot(v1, v1);
  const d20 = dot(v2, v0);
  const d21 = dot(v2, v1);
  const denom = d00 * d11 - d01 * d01;

  const lambda1 = (d11 * d20 - d01 * d21) / denom;
  const lambda2 = (d00 * d21 - d01 * d20) / denom;
  const lambda3 = 1.0 - lambda1 - lambda2;

  return [lambda1, lambda2, lambda3];
}

function barycentricToCartesian(q, corners) {
  return [
    corners[0].x * q[0] + corners[1].x * q[1] + corners[2].x * q[2],
    corners[0].y * q[0] + corners[1].y * q[1] + corners[2].y * q[2]
  ];
}

function dot(v1, v2) {
  return v1.x * v2.x + v1.y * v2.y;
}

ternaryDensity = (density, resolution, options, lbd) => {
  options = Object.assign({
    size: 400,
    margin: { left: 30, top: 30, right: 30, bottom: 30 },
    colorScheme: d3.interpolateBlues
  }, options);


  // Color scale spanning the range of the values actually plotted.
  const fillScale = d3.scaleSequential(options.colorScheme)
    .domain(d3.extent(density, d => d.density));

  const svg = d3.create("svg")
    .attr("width", options.size)
    .attr("height", options.size);

  const x = d3.scaleLinear()
    .domain([0, 1])
    .range([options.margin.left, options.size - options.margin.right])

  const y = d3.scaleLinear()
    .domain([0, 1])
    .range([options.size - options.margin.bottom, options.margin.top]);


  // d3 linear scales already provide both directions: x(v) and y(v) map data
  // coordinates to pixels, x.invert(px) and y.invert(px) map pixels back to data.
  const inverseScaleX = value => x.invert(value);
  const inverseScaleY = value => y.invert(value);
  const inverseInverseScaleX = value => x(value);
  const inverseInverseScaleY = value => y(value);



  const axisBottom = g => g.call(d3.axisBottom(x).ticks(4))
  const axisTop = g => g.call(d3.axisTop(x).ticks(4));

  const hexagonSize = ((x(1 / (2 * resolution)) - x(0))) / Math.cos(Math.PI / 6) + 1;

  const corners = [{ x: 0.5, y: Math.sqrt(3)/2 }, { x: 0, y: 0 }, {x: 1, y: 0 }];
  const corners_expanded = [{ x: 0.5, y: Math.sqrt(3)/2 + 0.1 }, { x: -0.1, y: -0.05 }, {x: 1.1, y: -0.05 }];

  const line = d3.line().x(d => x(d.x)).y(d => y(d.y));

  svg.append('defs')
    .append('clipPath')
    .attr('id', 'triangle')
    .append('path')
    .attr('x', options.size / 2)
    .attr('y', options.size / 2)
    .attr('d', line(corners))

  svg.selectAll('.border')
    .data([[corners[0], corners[1]],
           [corners[1], corners[2]],
           [corners[2], corners[0]]])
    .enter()
    .append('line')
    .attr('x1', (d) => x(d[0].x))
    .attr('x2', (d) => x(d[1].x))
    .attr('y1', (d) => y(d[0].y))
    .attr('y2', (d) => y(d[1].y))
    .attr('class', 'border')
    .attr('stroke', 'black');

  const hexbin = d3.hexbin();

  svg.append('g')
    .attr('clip-path', 'url(#triangle)')
    .selectAll(".point")
    .data(density)
    .join('path')
    .attr('transform', d => `translate(${x(corners[0].x * d.x[0] + corners[1].x * d.x[1] + corners[2].x * d.x[2])}, ${y(corners[0].y * d.x[0] + corners[1].y * d.x[1] + corners[2].y * d.x[2])})`)
    .attr('fill', d => fillScale(d.density))
    .attr('d', hexbin.hexagon(hexagonSize));

  svg.selectAll('text')
    .data([[corners_expanded[0], corners_expanded[1]],
           [corners_expanded[1], corners_expanded[2]],
           [corners_expanded[2], corners_expanded[0]]])
    .join("text")
    .attr("text-anchor", "middle")
    .attr("alignment-baseline", "middle")
    .attr("font-style", "italic")
    .attr("x", d => (x(d[0].x) + x(d[1].x)) / 2)
    .attr("y", d => (y(d[0].y) + y(d[1].y)) / 2)
    .text((d, i) => {
      if(i == 0) return "x₂" ;
      if(i == 1) return "x₃";
      if(i == 2) return "x₁";
    })

    let xy = barycentricToCartesian(z, corners)
    let xyr = [inverseInverseScaleX(xy[0]), inverseInverseScaleY(xy[1])];

    let xy_q = barycentricToCartesian(q, corners)
    let xyr_q = [inverseInverseScaleX(xy_q[0]), inverseInverseScaleY(xy_q[1])];

    let sftmax = softmax([z1, z2, z3], q, beta)
    let xy_s = barycentricToCartesian(sftmax, corners)
    let xyr_s = [inverseInverseScaleX(xy_s[0]), inverseInverseScaleY(xy_s[1])];

    // Red circle: the soft-argmax of z. SVG paints elements in document order,
    // so the circles appended last are drawn on top.
    svg.append("circle")
      .attr("class", "red-circle")
      .attr("cx", xyr_s[0])
      .attr("cy", xyr_s[1])
      .attr("r", 7)
      .attr("fill", "red");

    // Purple circle: the input z, read as barycentric coordinates.
    svg.append("circle")
      .attr("class", "purple-circle")
      .attr("cx", xyr[0])
      .attr("cy", xyr[1])
      .attr("r", 7)
      .attr("fill", "purple");

    // Black circle: the reference distribution q.
    svg.append("circle")
      .attr("class", "black-circle")
      .attr("cx", xyr_q[0])
      .attr("cy", xyr_q[1])
      .attr("r", 7)
      .attr("fill", "black");


    // Text legend at a fixed position: coordinates of q (black),
    // of the soft-argmax (red) and of z (purple).
    svg.append("text")
      .attr("class", "coord-text")
      .attr("x", 300)
      .attr("y", 80)
      .attr("fill", "black")
      .style("font-size", "10px")
      .text(`q=(${q[0].toFixed(2)}, ${q[1].toFixed(2)}, ${q[2].toFixed(2)})`);

    svg.append("text")
      .attr("class", "coord-text")
      .attr("x", 300)
      .attr("y", 100)
      .attr("fill", "red")
      .style("font-size", "10px")
      .text(`σ_{q,β}=(${sftmax[0].toFixed(2)}, ${sftmax[1].toFixed(2)}, ${sftmax[2].toFixed(2)})`);

    svg.append("text")
      .attr("class", "coord-text")
      .attr("x", 300)
      .attr("y", 120)
      .attr("fill", "purple")
      .style("font-size", "10px")
      .text(`z=(${z1.toFixed(2)}, ${z2.toFixed(2)}, ${z3.toFixed(2)})`);


  const removeLine = g => g.select(".domain").remove();

  // Set the value of an input and notify Observable so that dependent cells re-run.
  function set(input, value) {
    input.value = value;
    input.dispatchEvent(new Event("input", {bubbles: true}));
  }

  // On click, move z to the clicked point, expressed in barycentric coordinates.
  svg.on("click", function(event) {
    const [clickX, clickY] = d3.pointer(event);
    // The weights come back in the order corners[2], corners[1], corners[0].
    const lbd = cartesianToBarycentric(inverseScaleX(clickX), inverseScaleY(clickY), corners);

    set(viewof z1, lbd[2]);
    set(viewof z2, lbd[1]);
    set(viewof z3, lbd[0]);
  });

  // Display axis values
  svg.append("g")
    .attr("transform", `translate(0, ${options.size - options.margin.bottom})`)
    .call(axisBottom)
    .call(removeLine)

  svg.append("g")
    .attr("transform", `translate(${x(1) + options.margin.right / 2 }, ${y(0) + 26}) rotate(-120)`)
    .call(axisBottom)
    .call(removeLine)
    .call(g => g.selectAll("text").attr("transform", "translate(11, 22) rotate(120)"));

  svg.append("g")
    .attr("transform", `translate(${x(0.5) + options.margin.left / 2}, ${y(Math.sqrt(3) / 2) - 26}) rotate(120)`)
    .call(axisBottom)
    .call(removeLine)
    .call(g => g.selectAll("text").attr("transform", "translate(-11, 22) rotate(-120)"));

  return svg.node();
}


levelLogsumexp = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];

      density.push({ x, density: logsumexp(x, q, beta) });
    }
  }
  return density;
}

levelKL = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];

      // density.push({ x, density: logsumexp(x, q, beta) / Math.log(Math.max(q[0],q[1], q[2]) * Math.exp(1/beta)) });
      density.push({ x, density: kull_leibler(x, q, beta) });
    }
  }
  return density;
}


levelKL_full = {
  const density = [];
  for(let i = 1; i < densityResolution; i++) {
    for(let j = 1; j < densityResolution - i; j++) {
      const x1 = i / densityResolution ;
      const x2 = j / densityResolution;
      const x3 = Math.max(0, 1 - x1 - x2);
      const x = [ x1, x2, x3 ];

      // density.push({ x, density: logsumexp(x, q, beta) / Math.log(Math.max(q[0],q[1], q[2]) * Math.exp(1/beta)) });
      density.push({ x, density: objective(x, z, q, beta) });
    }
  }
  return density;
}

ternaryDensity(levelKL_full, densityResolution, { size: 500, colorScheme: d3.interpolateBlues})
logsumexpColorScale = legend({color: d3.scaleSequential(d3.extent(levelKL_full, d => d.density), d3.interpolateBlues), title: "Intensity"})

Credit

References

Boyd, S., and L. Vandenberghe. 2004. Convex Optimization. Cambridge University Press.

Footnotes

  1. The case with q_k=0 for some k can be reduced to a case with fewer coordinates.