Transfer Learning and Domain Adaptation

Note

We will use the abbreviations TL for transfer learning and DA for domain adaptation.

When do we need TL/DA?

Suppose you were born and raised in the US, you have never learned a foreign language, and you have never been to another country. Now, all of a sudden, you are dropped into a small town in Russia. How would you cope?

Machine learning models feel the same way when you feed them completely different inputs. For example, when a model has only ever seen MNIST digits and you ask it to classify colorful, high-resolution printed numbers taken from a photo, it is no wonder the model stops working.

However, those high-resolution numbers are still numbers, right? Surely they have something in common with MNIST? Indeed, any two images of the digit 1 share some traits, and any two images of the digit 2 are roughly similar (just like Russian and English, which are both human languages). But the domains are still different enough that the machine learning model gives up.

TL/DA aims to solve this. What TL/DA tries to do is make a model work across several environments that are similar yet different. By applying TL/DA techniques, a model can perform well in an environment that is related to, but not the same as, the environment where it was trained.

Different techniques in TL/DA.

Most TL/DA methods fall into one of the following three categories.

Discrepancy-based methods.

Discrepancy-based methods use a feature extractor followed by a very simple classifier. These methods align statistical measures of the extracted features across domains.

For example, suppose we have a text feature extractor trained on news articles. When reading news, the extracted features have mean \( \mu \) and standard deviation \( \sigma \); when reading medical documents, they have mean \( \mu' \) and standard deviation \( \sigma' \). A discrepancy-based method rescales a feature vector \( x' \) extracted from a medical document to \( \frac{\sigma}{\sigma'} (x' - \mu') + \mu \) and then passes the rescaled features through the classifier.
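
To make this concrete, below is a minimal NumPy sketch of such a rescaling. The function name `align_features` and the toy feature matrices are hypothetical stand-ins for real extractor outputs.

```python
import numpy as np

def align_features(target_feats, source_mean, source_std):
    """Rescale target-domain features to match source-domain statistics.

    Applies (sigma / sigma') * (x' - mu') + mu per feature dimension.
    """
    target_mean = target_feats.mean(axis=0)        # mu'
    target_std = target_feats.std(axis=0) + 1e-8   # sigma' (epsilon avoids division by zero)
    return (source_std / target_std) * (target_feats - target_mean) + source_mean

# Hypothetical usage: features extracted from news (source) and medical (target) text.
rng = np.random.default_rng(0)
news_feats = rng.normal(loc=0.0, scale=1.0, size=(1000, 64))
medical_feats = rng.normal(loc=2.0, scale=3.0, size=(200, 64))

aligned = align_features(medical_feats,
                         news_feats.mean(axis=0),
                         news_feats.std(axis=0))
# `aligned` now has roughly the same per-dimension mean and stddev as the
# news features, so a classifier trained on news sees inputs in a familiar range.
```

This aligns only per-dimension means and standard deviations; methods such as CORAL go further and align full covariance matrices across domains.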

Adversarial-based methods.

Adversarial-based methods use a feature extractor and a discriminator. These methods aim to train the feature extractor to produce features that are common to different domains, i.e., domain-invariant features.

For example, suppose we have an audio feature extractor trained on both classical music and funk. Tracks from both genres are encoded into feature vectors, and the discriminator's job is to tell classical music and funk apart from those vectors. We then train the pair like a GAN: the feature extractor tries to fool the discriminator, which pushes it toward features that no longer reveal the genre.
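
Below is a minimal PyTorch sketch of this GAN-style alternating training, with random tensors standing in for batches of audio input; the layer sizes, learning rate, and the label convention (classical = 1, funk = 0) are illustrative assumptions, not a prescribed recipe.

```python
import torch
import torch.nn as nn

# Hypothetical toy networks; a real audio model would be much larger.
feature_extractor = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))
discriminator = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1))

opt_f = torch.optim.Adam(feature_extractor.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
bce = nn.BCEWithLogitsLoss()

for step in range(1000):
    classical = torch.randn(32, 128)  # stand-in batch for classical music
    funk = torch.randn(32, 128)       # stand-in batch for funk

    # 1) Discriminator step: learn to tell the two domains apart.
    with torch.no_grad():  # freeze the extractor during this step
        f_c, f_f = feature_extractor(classical), feature_extractor(funk)
    d_loss = (bce(discriminator(f_c), torch.ones(32, 1))
              + bce(discriminator(f_f), torch.zeros(32, 1)))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2) Extractor step: fool the discriminator by flipping the labels,
    #    pushing the features toward being domain-invariant.
    f_c, f_f = feature_extractor(classical), feature_extractor(funk)
    g_loss = (bce(discriminator(f_c), torch.zeros(32, 1))
              + bce(discriminator(f_f), torch.ones(32, 1)))
    opt_f.zero_grad()
    g_loss.backward()
    opt_f.step()
```

In practice the feature extractor is also trained with a task loss on labeled source data so the features stay useful for the actual task, and methods such as DANN fold the two adversarial steps into a single backward pass via a gradient-reversal layer.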

Reconstruction-based methods.

Reconstruction-based methods use one shared encoder and multiple decoders, one per domain. These methods train the encoder to produce features from which each decoder can reconstruct inputs in its own domain.

For example, suppose we have a mammal image encoder paired with a gorilla decoder and a monkey decoder. During training, the encoder encodes images from both the gorilla domain and the monkey domain. The gorilla decoder tries to reconstruct gorillas from the features of gorilla images, and the monkey decoder tries to reconstruct monkeys from the features of monkey images.
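
Below is a minimal PyTorch sketch of this one-encoder, two-decoder setup, assuming flattened images and toy fully connected networks; the module names and shapes are illustrative.

```python
import torch
import torch.nn as nn

# Hypothetical toy modules; a real image model would use convolutional layers.
encoder = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 32))
gorilla_decoder = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 784))
monkey_decoder = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 784))

params = (list(encoder.parameters())
          + list(gorilla_decoder.parameters())
          + list(monkey_decoder.parameters()))
opt = torch.optim.Adam(params, lr=1e-3)
mse = nn.MSELoss()

for step in range(1000):
    gorillas = torch.rand(16, 784)  # stand-ins for flattened gorilla images
    monkeys = torch.rand(16, 784)   # stand-ins for flattened monkey images

    # Shared encoder, domain-specific decoders: each decoder must
    # reconstruct its own domain from the shared representation.
    loss = (mse(gorilla_decoder(encoder(gorillas)), gorillas)
            + mse(monkey_decoder(encoder(monkeys)), monkeys))
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Because both reconstruction losses backpropagate through the same encoder, the shared representation is forced to capture what the two domains have in common, while each decoder handles the domain-specific details.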