The softmax function is a smooth approximation of the max function, and is used in many machine learning models. Similarly, we can define the soft-argmax function, which is a smooth approximation of the argmax function.
Definition and notation
First, let us define \Delta_{K} = \{p \in \mathbb{R}^K : p_k \geq 0 \text{ for all } k, \ \sum_{k} p_k = 1\} the probability simplex in \mathbb{R}^K, u_K = (1/K, \dots, 1/K)^\top the uniform distribution in \Delta_{K}, and (\delta_k)_{k=1,\dots,K} the standard basis vectors, where \delta_k = (0,\dots,\underbrace{1}_{k-\rm{th}},\dots,0)^\top \in \mathbb{R}^K. We also write [K] = \{1, \dots, K\}.
Formally, the standard soft-argmax function \sigma \colon \mathbb{R}^{K} \to \Delta_{K}, where K \geq 1, takes a vector z = (z_{1}, \dots, z_{K})^\top \in \mathbb{R}^{K} and computes each component of the vector \sigma(z) \in \Delta_{K} by \left(\sigma(z)\right)_k = \frac{\exp(z_k)}{\sum_{k'=1}^{K} \exp(z_{k'})}\enspace, \quad \text{for } k = 1, \dots, K.
We also define, for any \beta > 0 and any q \in \Delta_{K} with positive coordinates¹, i.e., q_k>0 for all k\in [K], the function \sigma_{q,\beta} \colon \mathbb{R}^{K} \to \mathbb{R}^K as
\left(\sigma_{q,\beta}(z)\right)_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)} \enspace, \quad \text{for } k = 1, \dots, K.
Note in particular that \sigma(z)=\sigma_{u_K, 1}(z) for any z \in \mathbb{R}^K.
Finally, we introduce the real-valued log-sum-exp function (or weighted softmax), defined for any vector z \in \mathbb{R}^K by
\text{logsumexp}_{q,\beta}(z) = \beta \cdot \log\left(\sum_{k=1}^{K} q_k \cdot \exp(z_k/\beta)\right).
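The two definitions above can be transcribed almost literally. The following is a minimal sketch in plain Python, not a reference implementation: the function names `soft_argmax` and `logsumexp` and the example values are assumptions made for illustration. Both functions subtract \max_k z_k before exponentiating, a standard trick that leaves the result unchanged while avoiding overflow.

```python
import math

def logsumexp(z, q, beta):
    """Weighted log-sum-exp: beta * log(sum_k q_k * exp(z_k / beta))."""
    m = max(z)  # shift by the max so that every exponent is <= 0
    s = sum(qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q))
    return m + beta * math.log(s)

def soft_argmax(z, q, beta):
    """Weighted soft-argmax sigma_{q,beta}(z), returned as a list of K floats."""
    m = max(z)
    w = [qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q)]
    total = sum(w)
    return [wk / total for wk in w]

# Illustrative values (hypothetical): K = 3, uniform q, beta = 1,
# which recovers the standard soft-argmax sigma(z).
z = [1.0, 2.0, 3.0]
u = [1.0 / 3.0] * 3
p = soft_argmax(z, u, 1.0)
value = logsumexp(z, u, 1.0)
```

Here `p` is a point of the simplex, and `value` is the weighted softmax of `z`. Because q_k > 0 for all k, the sum inside the logarithm is always positive.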
Variational formulation
In this note we show that the log-sum-exp function is the convex conjugate of a scaled Kullback-Leibler divergence restricted to the simplex, and that the soft-argmax function is the corresponding maximizer.
Theorem 1 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0, and let z \in \mathbb{R}^K. Then, with the convention 0 \log 0 = 0,
\begin{align*} \text{logsumexp}_{q,\beta}(z) & = \max_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) \\ \sigma_{q,\beta}(z) & = \argmax_{ p \in \Delta_K} \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k)\enspace. \end{align*}
Proof. The objective is concave in p (a linear term minus \beta times the Kullback-Leibler divergence between p and q), so maximizing it over \Delta_K is a convex optimization problem. Introducing multipliers \Lambda=(\lambda_1,\dots,\lambda_K)^\top with \lambda_k \geq 0 for the constraints p_k \geq 0, and \mu \in \mathbb{R} for the constraint \sum_{k} p_k = 1 (see for instance Ch. 5, Boyd and Vandenberghe 2004), we get the Lagrangian function:
\mathcal{L}(p,\mu,\Lambda) = \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) - \mu\left(\sum_{k=1}^{K} p_k - 1\right) + \sum_{k=1}^{K} \lambda_k p_k.
The Slater condition is satisfied (for instance, q is strictly feasible), so the KKT conditions are necessary and sufficient for optimality. The KKT conditions are
\begin{align*} \frac{\partial \mathcal{L}}{\partial p_k} &= z_k - \beta\log(p_k / q_k) - \beta - \mu + \lambda_k = 0, \quad k = 1, \dots, K,\\ \sum_{k=1}^{K} p_k &= 1,\\ \lambda_k p_k &= 0, \quad k = 1, \dots, K,\\ p_k \geq 0, \ \lambda_k &\geq 0, \quad k = 1, \dots, K. \end{align*}
From the first KKT condition we get
p_k = q_k\exp\left(\frac{z_k-\beta-\mu+\lambda_k}{\beta}\right), \quad k = 1, \dots, K.
Since q_k>0 for all k, this gives p_k>0, and the complementary slackness condition \lambda_k p_k = 0 then forces \lambda_k = 0 for all k. Hence p_k is proportional to q_k\exp(z_k/\beta), and the constraint \sum_{k} p_k = 1 fixes the normalisation:
p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}, \quad k = 1, \dots, K.
This is exactly \sigma_{q,\beta}(z), which proves the second identity. Note that
\begin{align*} & p_k = \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ \iff & \log( p_k) = \log(q_k) + \frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z). \end{align*}
Finally, using \sum_{k} p_k = 1, we have that
\begin{align*} \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) &= \beta\sum_{k=1}^{K} p_k \left(\frac{z_k}{\beta} - \frac{1}{\beta} \cdot \text{logsumexp}_{q,\beta}(z)\right)\\ &= \sum_{k=1}^{K} p_k z_k - \text{logsumexp}_{q,\beta}(z). \end{align*}
Hence, \sum_{k=1}^{K} p_k z_k - \beta \sum_{k=1}^{K} p_k \log(p_k / q_k) = \text{logsumexp}_{q,\beta}(z), which proves the first identity.
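Theorem 1 can be sanity-checked numerically. The sketch below is only an illustration under assumed values (the names `objective` and `random_simplex_point`, and the choices of z, q and \beta, are hypothetical): it evaluates the objective at \sigma_{q,\beta}(z), compares it with the log-sum-exp value, and checks that randomly drawn points of the simplex never do better.

```python
import math
import random

def soft_argmax(z, q, beta):
    """Weighted soft-argmax sigma_{q,beta}(z)."""
    m = max(z)
    w = [qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q)]
    total = sum(w)
    return [wk / total for wk in w]

def logsumexp(z, q, beta):
    """Weighted log-sum-exp: beta * log(sum_k q_k * exp(z_k / beta))."""
    m = max(z)
    s = sum(qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q))
    return m + beta * math.log(s)

def objective(p, z, q, beta):
    """<z, p> - beta * sum_k p_k log(p_k / q_k), with 0 log 0 = 0."""
    lin = sum(zk * pk for zk, pk in zip(z, p))
    kl = sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)
    return lin - beta * kl

def random_simplex_point(K, rng):
    """Normalized i.i.d. exponential draws give a random point of the simplex."""
    e = [rng.expovariate(1.0) for _ in range(K)]
    s = sum(e)
    return [x / s for x in e]

rng = random.Random(0)  # fixed seed so the check is reproducible
z, q, beta = [0.3, -1.2, 2.0], [0.5, 0.3, 0.2], 0.7

p_star = soft_argmax(z, q, beta)
best = objective(p_star, z, q, beta)

# Gap between the objective at the soft-argmax and the log-sum-exp value.
gap = abs(best - logsumexp(z, q, beta))

# Largest amount by which a random simplex point beats the soft-argmax.
worst_violation = max(
    objective(random_simplex_point(3, rng), z, q, beta) - best
    for _ in range(1000)
)
```

If the theorem holds, `gap` should be at the level of floating-point error and `worst_violation` should not be positive. A random search is of course not a proof, only a quick consistency check.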
The following limits as \beta \to 0 explain the naming and the regularizing role of the (temperature) parameter \beta:
Proposition 1 Recall that u_K=(1/K,\dots,1/K)^\top and that \delta_{k} is the k-th standard basis vector. For any z \in \mathbb{R}^K whose largest coordinate is attained at a single index, we have that
\begin{align*} \sigma_{u_K,\beta}(z) & \xrightarrow[\beta \to 0]{} \delta_{k_0}, \text{ where } k_0=\argmax_{k\in [K]} z_k \\ \text{logsumexp}_{u_K,\beta}(z) & \xrightarrow[\beta \to 0]{} \max_{k\in [K]} z_k \enspace. \end{align*}
The first limit shows that the soft-argmax function is a smooth approximation of the argmax function, while the second shows that the log-sum-exp function is a smooth approximation of the max function.
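The effect of the temperature can be observed numerically. The sketch below uses assumed example values (the vector z and the grid of \beta are hypothetical choices) and prints \sigma_{u_K,\beta}(z) and \text{logsumexp}_{u_K,\beta}(z) for decreasing \beta. It also relies on the elementary bound \max_k z_k - \beta \log K \leq \text{logsumexp}_{u_K,\beta}(z) \leq \max_k z_k, which follows from keeping only the largest term of the sum (lower bound) or bounding every term by the largest one (upper bound).

```python
import math

def soft_argmax(z, q, beta):
    """Weighted soft-argmax sigma_{q,beta}(z)."""
    m = max(z)
    w = [qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q)]
    total = sum(w)
    return [wk / total for wk in w]

def logsumexp(z, q, beta):
    """Weighted log-sum-exp: beta * log(sum_k q_k * exp(z_k / beta))."""
    m = max(z)
    s = sum(qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q))
    return m + beta * math.log(s)

z = [0.5, 2.0, 1.0]   # unique maximum at (0-based) index 1
K = len(z)
u = [1.0 / K] * K     # uniform q

# As beta decreases, the mass concentrates on the argmax coordinate
# and the log-sum-exp value approaches max(z).
for beta in (1.0, 0.1, 0.01):
    p = soft_argmax(z, u, beta)
    v = logsumexp(z, u, beta)
    print(beta, [round(pk, 4) for pk in p], round(v, 4))

beta_small = 1e-3
p_small = soft_argmax(z, u, beta_small)
v_small = logsumexp(z, u, beta_small)
```

When the maximum is attained at several indices, the soft-argmax instead tends to the uniform distribution over those indices, which is why a unique maximizer is used in this example.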
Invariance properties
The soft-argmax function is invariant to the addition of a constant to each component of the input vector. More precisely, we have the following result:
Theorem 2 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0, and let z \in \mathbb{R}^K. Then, for any c \in \mathbb{R}, we have that \sigma_{q,\beta}(z) = \sigma_{q,\beta}(z+c), where z+c denotes the vector (z_1+c, \dots, z_K+c)^\top.
Proof. For any k \in [K], we have that \begin{align*} \left(\sigma_{q,\beta}(z+c)\right)_k &= \frac{q_k\exp((z_k+c)/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp((z_{k'}+c)/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)\exp(c/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)\exp(c/\beta)}\\ &= \frac{q_k\exp(z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\beta)}\\ &= \left(\sigma_{q,\beta}(z)\right)_k. \end{align*}
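A practical consequence of Theorem 2 is the usual stabilization trick: taking c = -\max_k z_k makes every exponent non-positive without changing the output. The sketch below is an illustration with hypothetical names (`soft_argmax_naive`, `soft_argmax_shifted`) and assumed example values. It checks the invariance on a small example and then evaluates the shifted version on inputs for which the direct formula would overflow.

```python
import math

def soft_argmax_naive(z, q, beta):
    """Direct transcription of the definition, with no numerical safeguard."""
    w = [qk * math.exp(zk / beta) for zk, qk in zip(z, q)]
    total = sum(w)
    return [wk / total for wk in w]

def soft_argmax_shifted(z, q, beta):
    """Same function evaluated at z - max(z); equal to the above by Theorem 2."""
    m = max(z)
    return soft_argmax_naive([zk - m for zk in z], q, beta)

z, q, beta, c = [0.3, -1.2, 2.0], [0.5, 0.3, 0.2], 0.7, 5.0

# Shift invariance: sigma(z) and sigma(z + c) agree up to rounding.
p = soft_argmax_naive(z, q, beta)
p_shift = soft_argmax_naive([zk + c for zk in z], q, beta)
max_diff = max(abs(a - b) for a, b in zip(p, p_shift))

# For these inputs math.exp(1000.0) raises OverflowError in the naive
# version, whereas the shifted version only exponentiates -2, -1 and 0.
big = [1000.0, 1001.0, 1002.0]
p_big = soft_argmax_shifted(big, q, 1.0)
```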
Let us also consider the effect of rescaling the input vector by a positive constant. We have the following result:
Theorem 3 Let q \in \Delta_{K} have positive coordinates, i.e., q_k>0 for all k\in [K], let \beta>0, and let z \in \mathbb{R}^K. Then, for any \alpha >0, we have that \sigma_{q,\beta}(\alpha z) = \sigma_{q,\tfrac{\beta}{\alpha}}(z).
Proof. We have that \begin{align*} \left(\sigma_{q,\beta}(\alpha z)\right)_k &= \frac{q_k\exp(\alpha z_k/\beta)}{\sum_{k'=1}^{K} q_{k'}\exp(\alpha z_{k'}/\beta)}\\ &= \frac{q_k\exp(z_k/\frac{\beta}{\alpha})}{\sum_{k'=1}^{K} q_{k'}\exp(z_{k'}/\frac{\beta}{\alpha})}\\ &= \left(\sigma_{q,\tfrac{\beta}{\alpha}}(z)\right)_k. \end{align*}
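Theorem 3 says that scaling the input by \alpha has the same effect as dividing the temperature by \alpha. A quick numerical check, under assumed example values (z, q, \beta and \alpha below are hypothetical choices), might look as follows.

```python
import math

def soft_argmax(z, q, beta):
    """Weighted soft-argmax sigma_{q,beta}(z)."""
    m = max(z)
    w = [qk * math.exp((zk - m) / beta) for zk, qk in zip(z, q)]
    total = sum(w)
    return [wk / total for wk in w]

z, q, beta, alpha = [0.3, -1.2, 2.0], [0.5, 0.3, 0.2], 0.7, 3.0

lhs = soft_argmax([alpha * zk for zk in z], q, beta)  # sigma_{q,beta}(alpha z)
rhs = soft_argmax(z, q, beta / alpha)                 # sigma_{q,beta/alpha}(z)
max_diff = max(abs(a - b) for a, b in zip(lhs, rhs))
```

The two vectors should agree up to floating-point rounding, so `max_diff` is expected to be tiny but not necessarily exactly zero.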
We consider the case K=3 for visualization.
You can click on the plot to move z (the purple dot) and see the corresponding soft-argmax \sigma_{q,\beta}(z) (the red dot). You can also modify q (the black dot), and \beta with the slider. The level sets displayed are those of the objective maximized in the variational formulation of the log-sum-exp function (Theorem 1), i.e., for fixed z\in \mathbb{R}^3, q\in \Delta_K and \beta>0, the level sets of the function \begin{align*} \Delta_K &\to \mathbb{R} \\ p & \mapsto \langle z, p \rangle - \beta\sum_{k=1}^{K} p_k \log(p_k / q_k) \end{align*}
The Observable code was made with the help of François-David Collin.
The plot is loosely inspired by Herb Susmann’s code on the Dirichlet distribution.
The terminology is inspired by posts from Gabriel Peyré: https://x.com/gabrielpeyre/status/1830470713041354968, https://x.com/gabrielpeyre/status/1680804520862056448
1. The case with q_k=0 for some k can be reduced to a case with fewer coordinates.