Summary of the paper on ‘Learning to classify images without labels’
The paper can be summarized in three stages:
Self-supervised learning → Clustering → Self labelling
Self-supervised learning (mining K nearest neighbors):
In a typical image classification task, labels guide the features the network learns through a loss function. But when there are no labels to drive backpropagation, how do we get the network to learn meaningful features from the images?
Self-supervised representation learning uses a predefined pretext task (objective) to make the network learn meaningful features without labels. Take a network of 5 layers: once we have a representation of each image (a D-dimensional feature vector from the 5th layer), we could cluster the images by using the Euclidean distance between these vectors as the clustering loss. However, nothing guarantees that the resulting clusters are semantically meaningful. Moreover, when the features and the clusters are learned together this way, the network tends to latch onto low-level features during backpropagation, which makes the outcome dependent on how the network is initialized.
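The paper addresses this with a pretext task whose objective is to minimize the distance between the embedding of an image and the embedding of a transformed version of it:
min_θ d( Φ_θ(X), Φ_θ(T[X]) )
Here Φ_θ is the network with parameters θ, X is an image and T[X] is the transformed image, e.g. after a rotation, an affine or a perspective transformation. Because the original and the transformed image pass through the same network and their embeddings are pulled together, the network has to learn features that are invariant to these transformations, so the learned representations are much more meaningful. Once this network is trained, the K nearest neighbors of every image are mined in its embedding space; since the features are semantically meaningful, these neighbors usually belong to the same class as the image itself.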
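To make the mining step concrete, here is a minimal NumPy sketch. It assumes the stage-1 network has already produced one embedding per image and, as an assumption, ranks neighbors by cosine similarity on L2-normalized embeddings. The function name `mine_nearest_neighbors` and the toy data are hypothetical.

```python
import numpy as np

def mine_nearest_neighbors(embeddings: np.ndarray, k: int) -> np.ndarray:
    """For each embedding, return the indices of its k most similar other
    embeddings (cosine similarity), most similar first."""
    # L2-normalize so that the dot product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    z = embeddings / np.clip(norms, 1e-12, None)
    sim = z @ z.T
    # An image must not be its own neighbor
    np.fill_diagonal(sim, -np.inf)
    # Sort by decreasing similarity and keep the top k
    return np.argsort(-sim, axis=1)[:, :k]


# Toy usage: two well-separated groups of hypothetical 2-D embeddings
rng = np.random.default_rng(0)
group_a = rng.normal(loc=[5.0, 0.0], scale=0.1, size=(4, 2))
group_b = rng.normal(loc=[0.0, 5.0], scale=0.1, size=(4, 2))
embeddings = np.vstack([group_a, group_b])

neighbors = mine_nearest_neighbors(embeddings, k=3)
print(neighbors)  # rows 0-3 point only into group A, rows 4-7 only into group B
```

With real data the embeddings come from the pretext network and there are many thousands of them, so an approximate nearest-neighbor library would normally replace the dense similarity matrix used here.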
Clustering stage:
Now that we have meaningful embeddings, the next step would be to apply K-means or another clustering algorithm to them. But naively applying K-means to get K clusters can lead to 'cluster degeneracy', where the clustering collapses into a trivial solution. Similarly, a discriminative model can end up assigning all its probability mass to the same cluster, so that one cluster dominates the others. To overcome this, the paper introduces the semantic clustering loss.
The semantic clustering (SC) loss is the crux of this paper.
Let’s break this down
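The idea is to pass each image and its mined neighbors from the previous stage through a network Φ_η that outputs probabilities for C classes (clusters). C has to be chosen from prior knowledge or as a guess; for evaluation, the paper uses the ground-truth number of classes. The network is trained with the following loss:
$$\Lambda = -\frac{1}{|\mathcal{D}|} \sum_{X \in \mathcal{D}} \sum_{k \in \mathcal{N}_X} \log \langle \Phi_\eta(X), \Phi_\eta(k) \rangle + \lambda \sum_{c \in \mathcal{C}} \Phi_\eta'^{c} \log \Phi_\eta'^{c}, \qquad \Phi_\eta'^{c} = \frac{1}{|\mathcal{D}|} \sum_{X \in \mathcal{D}} \Phi_\eta^{c}(X)$$
Here D is the dataset, N_X is the set of mined neighbors of image X, ⟨·,·⟩ is the dot product, and Φ_η'^c is the probability of class c averaged over the whole dataset.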
The purpose of the above loss function is to make the predicted class distribution of an image as close as possible to the class distributions of its K nearest neighbors, mined by solving the pretext task in stage 1.
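This is done by the first term in the above equation, which takes the dot product between the image's vector of class probabilities and that of each of its neighbors. The dot product reaches its maximum of 1 only when both predictions are one-hot (confident) and pick the same class (consistent), so this term directly enforces consistency between neighbors instead of relying on a joint distribution of classes. On its own, however, it admits a trivial solution in which the class distribution becomes skewed towards one class, e.g. every image is assigned to the same cluster.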
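A quick worked example of why the dot product rewards predictions that are both confident and consistent. The probability vectors below are made up purely for illustration.

```python
import numpy as np

C = 4  # hypothetical number of clusters
one_hot_0 = np.array([1.0, 0.0, 0.0, 0.0])  # confident: cluster 0
one_hot_1 = np.array([0.0, 1.0, 0.0, 0.0])  # confident: cluster 1
uniform = np.full(C, 1.0 / C)               # no preference at all

print(one_hot_0 @ one_hot_0)  # 1.0  -> confident and consistent (-log 1 = 0)
print(one_hot_0 @ one_hot_1)  # 0.0  -> confident but inconsistent (-log 0 is unbounded)
print(uniform @ uniform)      # 0.25 -> consistent but not confident (= 1/C)
```

Note that two identical one-hot vectors score 1 whichever cluster they pick, which is exactly why this term alone cannot stop all images from collapsing into one cluster.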
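To prevent this, the second term is used. It is the negative entropy of the class distribution averaged over the dataset: minimizing it maximizes that entropy, which is highest when the images are spread uniformly across the classes. The weight λ controls how strongly this is enforced.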
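Putting both terms together, the loss can be sketched in NumPy as below. This is a hedged mini-batch approximation, not the authors' code: it assumes one sampled neighbor per image (row i of `neighbor_probs`), averages over the batch instead of the whole dataset, and computes the entropy term from the batch's mean prediction. `scan_loss` is a hypothetical name, and `entropy_weight` plays the role of λ.

```python
import numpy as np

def scan_loss(anchor_probs: np.ndarray, neighbor_probs: np.ndarray,
              entropy_weight: float = 2.0, eps: float = 1e-8):
    """Mini-batch sketch of the semantic clustering loss.

    anchor_probs, neighbor_probs: (batch, C) softmax outputs; row i of
    neighbor_probs is the prediction for one mined neighbor of image i.
    Returns (total, consistency_term, neg_entropy_term).
    """
    # First term: -mean over the batch of log <p(X), p(k)>
    dots = np.sum(anchor_probs * neighbor_probs, axis=1)
    consistency = -np.mean(np.log(np.clip(dots, eps, None)))

    # Second term: sum_c p'_c log p'_c, where p' is the mean prediction.
    # This is the negative entropy, so minimizing it spreads samples out.
    mean_probs = anchor_probs.mean(axis=0)
    neg_entropy = np.sum(mean_probs * np.log(np.clip(mean_probs, eps, None)))

    total = consistency + entropy_weight * neg_entropy
    return total, consistency, neg_entropy


# Collapsed solution: every image (and neighbor) confidently in cluster 0
collapsed = np.tile([1.0, 0.0, 0.0], (6, 1))
# Balanced solution: neighbors still agree, but all 3 clusters are used
balanced = np.repeat(np.eye(3), 2, axis=0)

print(scan_loss(collapsed, collapsed))
print(scan_loss(balanced, balanced))
```

Both toy solutions are perfectly consistent, so their first term is 0. The collapsed one gets a total loss of 0, while the balanced one gets 2 · (−log 3) ≈ −2.2, so it is the entropy term that makes the balanced clustering preferable.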
Fine-tuning using self-labelling:
The SC loss ensures consistency, but some mined neighbors are false positives (they belong to a different class), and this stage takes care of them. It filters data points by confidence: samples whose highest predicted probability exceeds a threshold are kept and assigned their predicted cluster as a pseudo-label. A cross-entropy loss on these confident samples (computed on strongly augmented versions of them) then updates the network weights, which makes the predictions more certain.
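The self-labelling step can be sketched as follows. This is a simplified, hypothetical version: `probs_weak` and `probs_strong` stand for the network's predictions on a weakly and a strongly augmented view of the same images, both function names are illustrative, and any class balancing of the cross-entropy is left out.

```python
import numpy as np

def select_confident(probs: np.ndarray, threshold: float = 0.99):
    """Indices of samples whose top probability exceeds `threshold`,
    together with their pseudo-labels (the predicted cluster)."""
    idx = np.flatnonzero(probs.max(axis=1) > threshold)
    return idx, probs[idx].argmax(axis=1)

def self_label_loss(probs_weak: np.ndarray, probs_strong: np.ndarray,
                    threshold: float = 0.99, eps: float = 1e-8):
    """Cross-entropy between predictions on the strongly augmented views
    and pseudo-labels taken from confident weak-view predictions."""
    idx, pseudo_labels = select_confident(probs_weak, threshold)
    if idx.size == 0:
        return 0.0, idx  # nothing is confident enough yet
    picked = probs_strong[idx, pseudo_labels]
    return float(-np.mean(np.log(np.clip(picked, eps, None)))), idx


probs_weak = np.array([
    [0.995, 0.003, 0.002],  # confident -> pseudo-label 0
    [0.600, 0.300, 0.100],  # below threshold -> ignored
    [0.001, 0.004, 0.995],  # confident -> pseudo-label 2
])
probs_strong = np.array([
    [0.90, 0.05, 0.05],
    [0.20, 0.50, 0.30],
    [0.10, 0.10, 0.80],
])
loss, used = self_label_loss(probs_weak, probs_strong)
print(used)  # [0 2]
print(loss)  # -(log 0.9 + log 0.8) / 2, about 0.164
```

Only rows 0 and 2 contribute; the loss pushes the strong-view predictions toward the pseudo-labels obtained from the confident weak-view predictions.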
Training setup:
5 nearest neighbors are mined from the self-supervised step (stage 1).
The weights learned in stage 1 are transferred to the clustering step.
Batch size = 128; weight of the entropy term (second term) in the SC loss: λ = 2.
Self-labelling (fine-tuning) step: confidence threshold = 0.99, cross-entropy loss, Adam optimizer.
Standard data augmentations are random flips, random crops and jitter.
Strong augmentations are composed of four randomly selected transformations from AutoAugment.
Experimental results:
The above results (last three rows) show the accuracy obtained after each stage. The SCAN loss clearly makes a significant contribution, and so do the augmentation techniques, which help the model generalize better. The need to tune these choices and hyperparameters is, however, one source of complexity in this approach.
As can be seen, the method achieves accuracy that compares well with supervised training and is significantly better than prior unsupervised methods. The higher the number of classes, the lower the accuracy, which is also the case for supervised methods.
Link to the paper : https://arxiv.org/pdf/2005.12320.pdf