Capsule Networks - Better CNNs?
“The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster.”
This is a quote from Geoffrey Hinton, who is considered one of the most important researchers in AI and Deep Learning. In 2012, his student Alex Krizhevsky won the ImageNet challenge with the Convolutional Neural Network (CNN) AlexNet and laid the groundwork for the success of CNNs in Computer Vision. So why does Hinton think that CNNs are bad?
What CNNs are bad at:
If you want a recap of CNNs, I can recommend this article, or take a look at this blog post about CNNs for image classification.
Although CNNs achieve the best results in image classification, they are actually pretty bad at detecting objects in different poses. To achieve good performance, they need a lot of training data and techniques like data augmentation.
This is because CNNs are designed to be invariant, rather than equivariant, to pose changes. Equivariance means that when the input changes, the output changes in the same way. Invariance, in contrast, means that when the input changes, the output stays the same. This invariance makes CNNs robust against translations, but other pose changes (like rotation or a change in orientation) are not handled well.
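To make the two terms concrete, here is a small NumPy sketch (not taken from the article's notebook). Strictly speaking, the convolution itself is translation equivariant; the invariance comes from pooling on top of it. The toy signal, the kernel and the helper `circular_conv1d` are illustrative assumptions, and the convolution wraps around at the borders so that shifts are exact.

```python
import numpy as np

def circular_conv1d(x, w):
    """1D cross-correlation (what a CNN layer computes) with wrap-around borders."""
    n = len(x)
    return np.array([sum(x[(i + k) % n] * w[k] for k in range(len(w))) for i in range(n)])

x = np.array([0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0])  # a small "feature" in a 1D signal
w = np.array([1.0, 2.0, 1.0])                            # toy kernel
x_shifted = np.roll(x, 2)                                # translate the input by 2 positions

feat = circular_conv1d(x, w)
feat_shifted = circular_conv1d(x_shifted, w)

# Equivariance: the feature map is translated by the same 2 positions
print(np.allclose(feat_shifted, np.roll(feat, 2)))
# Invariance: a global max-pool returns the same value, so the position is gone
print(feat.max() == feat_shifted.max())
```

Both checks print `True`. Note that this only covers translation; a rotated input would not simply produce a rotated feature map.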
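The reason is that CNNs use pooling (e.g. max-pooling) to pass information from one layer to the next. CNNs detect certain features in an image, but a pooling layer keeps only the strongest activation in each window, so valuable information, such as the precise position of a feature and its spatial relationship to other features, gets lost.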
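A minimal NumPy sketch of this information loss, assuming non-overlapping 2x2 max-pooling (the helper `max_pool_2x2` and the two feature maps are made up for illustration): both maps contain the same two detected features, but at different exact positions and with a different relative arrangement, and pooling turns them into identical outputs.

```python
import numpy as np

def max_pool_2x2(fmap):
    """Non-overlapping 2x2 max-pooling of a 2D feature map with even height and width."""
    h, w = fmap.shape
    return fmap.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

# Feature map a: features at (0, 0) and (2, 3), i.e. 2 rows and 3 columns apart
a = np.zeros((4, 4))
a[0, 0] = 1.0
a[2, 3] = 1.0

# Feature map b: features at (1, 1) and (3, 2), i.e. 2 rows and 1 column apart
b = np.zeros((4, 4))
b[1, 1] = 1.0
b[3, 2] = 1.0

print(max_pool_2x2(a))
# Identical pooled outputs: the positions inside each window are discarded
print(np.array_equal(max_pool_2x2(a), max_pool_2x2(b)))
```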
What do we want instead:
What we want instead is equivariance. When the input changes, the output should change accordingly.
So how do Capsule Nets work?
If you are more hands-on and want to see an implementation right away, check out this Jupyter notebook, in which I implemented a capsule network with PyTorch.
In computer graphics, a picture is rendered from a lot of stored parameters. Capsule networks go the other way round: they infer instantiation parameters from an input image. So instead of only providing a probability (as CNNs do), each capsule outputs a vector: the length of the vector reflects the probability that a specific object is present, while its orientation encodes the instantiation parameters.
In the previous picture there were two different capsules, each responsible for detecting a specific shape or object. Additionally, capsule networks have a hierarchy: simple shapes are detected in lower levels, whereas more complex objects are detected in higher levels.
The instantiation parameters encoded in the output vector can be, for example, the pose of an object (position, size, orientation), its deformation, velocity, albedo, hue or texture.
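Reading such an output could look like the following NumPy sketch. The three 4D capsule vectors are made-up values that already have lengths below 1; in a real network they come out of the squashing step.

```python
import numpy as np

# Hypothetical output of a capsule layer: one 4D vector per capsule (3 capsules)
v = np.array([
    [0.05, -0.02, 0.01, 0.03],   # capsule 0: short vector -> object unlikely
    [0.60, 0.30, -0.40, 0.20],   # capsule 1: long vector -> object likely present
    [0.10, 0.00, 0.05, -0.10],   # capsule 2: fairly short vector
])

lengths = np.linalg.norm(v, axis=1)        # presence "probability" of each object
orientation = v / lengths[:, np.newaxis]   # unit vectors: the direction encodes the instantiation parameters
predicted = int(np.argmax(lengths))        # the object whose capsule has the longest vector

print(lengths.round(3))
print(predicted)
```

Here capsule 1 has by far the longest vector (about 0.81), so its object is the most likely one to be present.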
Let's take a deeper look at Capsules
We have seen that capsule networks output vectors which store instantiation parameters.
Let's compare “classical” neurons and capsules:
The main differences here are:
- capsules work with vectors instead of scalars
- they use a special activation function called squashing
Squashing is a special kind of normalization: it rescales the vector to a length between 0 and 1 while preserving its orientation. Short vectors get shrunk to almost zero length, while long vectors get rescaled to a length slightly below 1.
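In Sabour et al., the squash function maps a vector s to v = (‖s‖² / (1 + ‖s‖²)) · (s / ‖s‖). Here is a minimal NumPy sketch of it; the small `eps` term is an assumption added to avoid a division by zero for s = 0.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    """Squash: keep the direction of s, map its length ||s|| to ||s||^2 / (1 + ||s||^2)."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / np.sqrt(sq_norm + eps)  # eps avoids division by zero for s = 0

short = squash(np.array([0.1, 0.0]))
long_ = squash(np.array([10.0, 0.0]))
print(np.linalg.norm(short))  # about 0.0099: short vectors shrink to almost zero
print(np.linalg.norm(long_))  # about 0.990: long vectors end up slightly below 1
```

Because only the length is rescaled, `long_` still points in the direction of `[1, 0]`.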
Another difference between CNNs and Capsule Networks is the use of routing by agreement instead of pooling. Routing by agreement (also called dynamic routing) is the algorithm introduced by Sabour et al. in 2017 that made it possible to implement Capsule Networks that perform well.
How does routing by agreement work?
We have seen that there is a hierarchy: lower-level capsules report to higher-level capsules. Routing by agreement determines which lower-level capsule reports to which higher-level capsule.
The output u_i of each lower-level capsule i is multiplied by an affine transformation matrix W_ij. The result, the transformed input û_i, is a prediction of what the higher-level capsule j will output. The matrices W_ij are learned with classic backpropagation. The coupling coefficients c_ij, on the other hand, are determined through routing by agreement. As seen before, the coupling coefficients weight the sum of the transformed inputs: the higher c_ij, the more the input from lower-level capsule i contributes to higher-level capsule j.
First, all routing logits b_ij, from which the coupling coefficients are derived, are initialized with 0. In each iteration, we normalize them with a softmax, so that the coupling coefficients c_ij of each lower-level capsule sum to 1, and then compute the weighted sum s_j of all lower-level capsule inputs for each higher-level capsule j.
After normalizing s_j with the squash function, we compute the scalar product between the result v_j and each transformed lower-level capsule input û_i. The more alike v_j and û_i are, i.e. the more û_i agrees with v_j, the bigger the product. This product is then added to the current logit b_ij. So, step by step, the coupling coefficients between capsules that agree get increased.
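As a sketch of the shapes involved, here is how the transformed inputs and their weighted sum could be computed with NumPy. The number of capsules is a hypothetical choice, the 8D and 16D capsule sizes follow Sabour et al., the matrices W_ij are random instead of learned, and the coupling coefficients are uniform, which is what a softmax over all-zero logits gives.

```python
import numpy as np

rng = np.random.default_rng(0)

num_lower, num_higher = 6, 3   # hypothetical numbers of lower- and higher-level capsules
dim_lower, dim_higher = 8, 16  # capsule dimensions (8D and 16D, as in Sabour et al.)

u = rng.normal(size=(num_lower, dim_lower))                          # lower-level outputs u_i
W = rng.normal(size=(num_lower, num_higher, dim_higher, dim_lower))  # one matrix W_ij per pair (i, j)

# Prediction of lower capsule i for higher capsule j: u_hat[i, j] = W_ij @ u_i
u_hat = np.einsum('ijab,ib->ija', W, u)

# Coupling coefficients c_ij (uniform here) and the weighted sum s_j per higher capsule
c = np.full((num_lower, num_higher), 1.0 / num_higher)
s = np.einsum('ij,ija->ja', c, u_hat)

print(u_hat.shape)  # (6, 3, 16)
print(s.shape)      # (3, 16)
```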
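Written out as equations, the procedure looks like this, following the notation of Sabour et al., where û_{j|i} denotes the transformed input û_i that lower-level capsule i sends to higher-level capsule j, and r is the number of routing iterations:

```latex
\begin{aligned}
&\text{initialize } b_{ij} = 0 \text{ for all } i, j \\
&\text{repeat } r \text{ times:} \\
&\quad c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})} \\
&\quad s_j = \sum_i c_{ij}\,\hat{u}_{j|i}, \qquad \hat{u}_{j|i} = W_{ij}\,u_i \\
&\quad v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2}\,\frac{s_j}{\lVert s_j \rVert} \\
&\quad b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j
\end{aligned}
```

The softmax runs over the higher-level capsules k, so the coupling coefficients of each lower-level capsule i sum to 1.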
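Here is an example implementation with PyTorch (you can find the whole implementation in this Jupyter notebook):
```python
import torch
import torch.nn.functional as F

# Method of the capsule layer class from the notebook.
# u_hat: transformed inputs; the operations below expect the shape
# (batch, previous_capsules, num_capsules, out_dim, 1)
def routing_by_agreement(self, u_hat, num_iterations):
    batch_size = u_hat.size(0)
    # routing logits, initialize b_ij = 0
    b_ij = torch.zeros(1, self.previous_capsules, self.num_capsules, 1, device=u_hat.device)
    for iteration in range(num_iterations):
        # softmax over the higher-level capsules (dim=2), so that the coupling
        # coefficients of each lower-level capsule sum to 1
        c_ij = F.softmax(b_ij, dim=2)
        c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
        # transformed vectors get multiplied element-wise with the coupling coefficients;
        # for each higher-level capsule, the sum over all lower-level capsule inputs is calculated
        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
        v_j = self.squash(s_j)  # normalize with the squash function
        # no need to calculate new b_ij in the last iteration
        if iteration < num_iterations - 1:
            # agreement: scalar product between each u_hat and v_j
            a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.previous_capsules, dim=1))
            # the agreement is averaged over the batch here
            b_ij = b_ij + a_ij.mean(dim=0, keepdim=True).squeeze(4)
    return v_j
```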
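Because the method above depends on the rest of the capsule layer class, here is a self-contained NumPy sketch of the same loop for a single example (so without averaging over a batch), assuming the transformed inputs `u_hat` have already been computed. The helper names `softmax` and `route` and the toy predictions are illustrative.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return sq_norm / (1.0 + sq_norm) * s / np.sqrt(sq_norm + eps)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def route(u_hat, num_iterations=3):
    """Routing by agreement for one example; u_hat has shape (lower, higher, dim)."""
    b = np.zeros(u_hat.shape[:2])              # routing logits b_ij
    for iteration in range(num_iterations):
        c = softmax(b, axis=1)                 # coupling coefficients of each lower capsule sum to 1
        s = np.einsum('ij,ijd->jd', c, u_hat)  # weighted sum s_j
        v = squash(s)                          # higher-level outputs v_j
        if iteration < num_iterations - 1:
            b = b + np.einsum('ijd,jd->ij', u_hat, v)  # agreement: scalar product u_hat . v_j
    return c, v

# Toy predictions: both lower capsules agree about higher capsule 0,
# but point in opposite directions for higher capsule 1.
u_hat = np.array([
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [-1.0, 0.0]],
])
c, v = route(u_hat)
print(c.round(2))
print(np.linalg.norm(v, axis=1).round(2))
```

After three iterations, both lower-level capsules send roughly 75% of their output to higher-level capsule 0, where their predictions agree, while the conflicting predictions for capsule 1 cancel out and leave it with a zero-length output.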
What makes routing by agreement special?
It is good at handling crowded scenes by “explaining away” ambiguities. What does that mean?
In the image above, we can see either a house and a boat, or only a house with a triangle on top and a rectangle below.
Since interpreting the shapes as both a house and a boat agrees better with the predictions made by the lower-level capsules, the ambiguity is resolved in favor of that interpretation.
Architecture and Implementation
If you are still interested in learning more about Capsule Networks and implementing them yourself, keep on reading!
This is the architecture used in the paper by Sabour et al.:
This is the encoder part. We start with a convolutional layer with 256 kernels of size 9x9, resulting in a 20x20x256 output. The next layer is called PrimaryCaps, but it is actually only another convolution (9x9 kernels, stride 2), whose 6x6x256 output is grouped into 1152 capsules with 8 dimensions each. The second capsule layer, DigitCaps, is a proper capsule layer as we have seen before, with one 16-dimensional capsule per class. Each input capsule gets transformed with an affine transformation matrix, and then dynamic routing takes place, which determines which lower-level capsule reports to which higher-level capsule.
After the encoder, a fully connected network is used as a decoder to reconstruct the original image from the instantiation parameters of the capsule that corresponds to the correct class.
If you are interested in trying out a capsule network for yourself, check out this Jupyter notebook. There, I explain and implement the architecture step by step with the FashionMNIST dataset and PyTorch.
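The layer sizes follow from simple arithmetic. This sketch assumes 28x28 inputs (as in MNIST and FashionMNIST), convolutions without padding, and the settings from Sabour et al.: 9x9 kernels with stride 1 for the first layer and stride 2 for PrimaryCaps, 32 channels of 8D primary capsules, and 10 DigitCaps capsules of 16 dimensions. The helper `conv_output_size` is hypothetical.

```python
def conv_output_size(size, kernel, stride):
    """Spatial output size of a convolution without padding."""
    return (size - kernel) // stride + 1

image_size = 28
conv1 = conv_output_size(image_size, kernel=9, stride=1)  # first layer: 256 kernels of 9x9
primary = conv_output_size(conv1, kernel=9, stride=2)     # PrimaryCaps: 9x9 kernels, stride 2
num_primary_caps = primary * primary * 32                 # 32 channels of 8D capsules

# DigitCaps: one 16x8 transformation matrix W_ij per (primary capsule, class) pair
num_weights = num_primary_caps * 10 * 16 * 8

print(conv1, primary, num_primary_caps)  # 20 6 1152
print(num_weights)                       # 1474560
```

Most of the capsule-specific parameters therefore sit in the transformation matrices between PrimaryCaps and DigitCaps.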
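In Sabour et al., all DigitCaps outputs except the one of the correct class are masked out during training, and the result is fed into three fully connected layers (512 and 1024 ReLU units, then 784 sigmoid units for a 28x28 image). Here is a NumPy sketch of that forward pass, with a made-up DigitCaps output and random, untrained weights without biases; the helpers `dense` and `relu` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
num_classes, caps_dim = 10, 16

# Hypothetical DigitCaps output for one image: one 16D vector per class
digit_caps = rng.normal(scale=0.1, size=(num_classes, caps_dim))

# Mask out every capsule except the one of the correct class
label = 3
mask = np.zeros((num_classes, 1))
mask[label] = 1.0
decoder_input = (digit_caps * mask).reshape(-1)  # 10 * 16 = 160 values

def dense(x, n_out):
    """Fully connected layer with random, untrained weights (no bias, for brevity)."""
    W = rng.normal(scale=0.05, size=(n_out, x.shape[0]))
    return W @ x

def relu(z):
    return np.maximum(0.0, z)

h1 = relu(dense(decoder_input, 512))
h2 = relu(dense(h1, 1024))
reconstruction = 1.0 / (1.0 + np.exp(-dense(h2, 784)))  # sigmoid: pixel intensities in (0, 1)
image = reconstruction.reshape(28, 28)
print(image.shape)  # (28, 28)
```

In training, the reconstruction error is added to the loss, which pushes the 16D capsule vectors to encode the instantiation parameters of the object.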
So are Capsule Networks the new CNNs?
No, not yet.
Capsule networks achieve results comparable to CNNs on MNIST while requiring less training data. They are robust against small affine transformations, and they are good at detecting overlapping objects, which is promising for image segmentation.
But because a capsule network tries to “explain” everything in an image, including background clutter, it performs poorly on more complex datasets such as CIFAR-10. Additionally, Capsule Networks take longer to train.
A possible use case is the integration of capsule layers into CNNs.
They are also nice for visualization:
Wanna learn more about capsules?
These are nice explanations of Capsule Networks:
- Aurélien Géron, 2017, https://www.youtube.com/watch?v=pPN8d0E3900
- Max Pechyonkin, 2017, https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b
- Nick Bourdakos, 2018, https://www.freecodecamp.org/news/understanding-capsule-networks-ais-alluring-new-architecture-bdb228173ddc/
- a comprehensive graphic explaining the architecture in more detail
And this is the original paper:
Sara Sabour et al., 2017, https://arxiv.org/pdf/1710.09829.pdf