A Simple Explanation of the Softmax Function
What Softmax is, how it's used, and how to implement it in Python.
Softmax turns arbitrary real values into probabilities, which are often useful in Machine Learning. The math behind it is pretty simple: given some numbers,
- Raise e (the mathematical constant) to the power of each of those numbers.
- Sum up all the exponentials (powers of e). This result is the denominator.
- Use each number’s exponential as its numerator.
- Each number's probability is its numerator divided by the denominator.
Written more fancily, Softmax performs the following transform on $n$ numbers $x_1, \ldots, x_n$:

$$s(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$
The outputs of the Softmax transform are always in the range $[0, 1]$ and add up to 1. Hence, they form a probability distribution.
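To make those four steps concrete, here's a minimal plain-Python sketch (using only the standard math module; the function name softmax_plain is just for illustration):

```python
import math

def softmax_plain(values):
    # Step 1: raise e to the power of each value
    exps = [math.exp(v) for v in values]
    # Step 2: sum the exponentials; this is the shared denominator
    denominator = sum(exps)
    # Steps 3 and 4: each value's probability is its exponential over the denominator
    return [e / denominator for e in exps]

print(softmax_plain([-1, 0, 3, 5]))
# roughly [0.0021657, 0.0058870, 0.1182430, 0.8737043]
```

The NumPy implementation later in the post does the same thing more compactly.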
A Simple Example
Say we have the numbers -1, 0, 3, and 5. First, we calculate the denominator:

$$\text{Denominator} = e^{-1} + e^0 + e^3 + e^5 = 169.87$$
Then, we can calculate the numerators and probabilities:
| $x$ | Numerator ($e^x$) | Probability ($\frac{e^x}{169.87}$) |
| --- | --- | --- |
| -1 | 0.368 | 0.002 |
| 0 | 1 | 0.006 |
| 3 | 20.09 | 0.118 |
| 5 | 148.41 | 0.874 |
The bigger the $x$, the higher its probability. Also, notice that the probabilities all add up to 1, as mentioned before.
Implementing Softmax in Python
Using numpy makes this super easy:
import numpy as np
def softmax(xs):
    # Exponentiate each value, then divide by the sum of all the exponentials
    return np.exp(xs) / sum(np.exp(xs))
xs = np.array([-1, 0, 3, 5])
print(softmax(xs))  # [0.0021657, 0.00588697, 0.11824302, 0.87370431]
Note: for more advanced users, you’ll probably want to implement this using the LogSumExp trick to avoid underflow/overflow problems.
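As a rough sketch of that idea (this is the usual max-subtraction form; the name stable_softmax is just illustrative): subtracting the maximum input before exponentiating doesn't change the result, because the shift cancels in the ratio, but it keeps every exponent at or below 0 so np.exp can't overflow:

```python
import numpy as np

def stable_softmax(xs):
    # Shifting by the max is mathematically a no-op for Softmax,
    # but it keeps the exponentials in a numerically safe range.
    shifted = xs - np.max(xs)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

big = np.array([1000.0, 1001.0, 1002.0])
print(stable_softmax(big))  # ~[0.09, 0.245, 0.665]; the naive version overflows to nan here
```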
Why is Softmax useful?
Imagine building a Neural Network to answer the question: Is this picture of a dog or a cat?
A common design for this neural network would have it output 2 real numbers, one representing dog and the other cat, and apply Softmax on these values. For example, let's say the network outputs $[-1, 2]$:
| Animal | $x$ | $e^x$ | Probability |
| --- | --- | --- | --- |
| Dog | -1 | 0.368 | 0.047 |
| Cat | 2 | 7.39 | 0.953 |
This means our network is 95.3% confident that the picture is of a cat. Softmax lets us answer classification questions with probabilities, which are more useful than simpler answers (e.g. binary yes/no).
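As a final sketch, here's how those numbers could be produced in code; the label list and the argmax decision rule are illustrative choices, not part of the network itself:

```python
import numpy as np

logits = np.array([-1.0, 2.0])  # raw network outputs: [dog, cat]
probs = np.exp(logits) / np.sum(np.exp(logits))

for label, p in zip(["dog", "cat"], probs):
    print(f"{label}: {p:.3f}")  # dog: 0.047, cat: 0.953

print("prediction:", ["dog", "cat"][int(np.argmax(probs))])  # prediction: cat
```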