논문 리뷰/경량화 논문 스터디

[논문 리뷰] Distilling the Knowledge in a Neural Network (NIPS 2014 Workshop)

공부중인학생 2021. 10. 26. 21:39

이 논문은 처음으로 knowledge distillation이라는 개념을 제시한 논문입니다. pruning과 비슷하게 이전에 있던 모델을 사용할 수 있다는 점에서 활용도가 높을 거 같아 관심을 가지게 됐습니다.

 

 

2021년 10월 25일 기준으로 7705회 인용됐습니다.

 

 

  • 모델을 만들 때 성능을 쉽게 올리는 방법 중 하나는 Ensemble입니다. 각각의 모델들이 test error가 비슷하다고 하더라도 ensemble 한 결과는 더 낮은 test error를 보여줍니다.
  • 하지만 ensemble 모델을 사용하기에는 너무 많은 컴퓨터 비용이 필요하기 때문에 일반적인 상황에서는 사용하기가 쉽지 않습니다.
  • 그래서 본 논문에서는 ensemble을 통해 얻을 수 있는 generalization 능력을 더 작은 규모의 모델에 전달하는 방법을 제안합니다.

 

Distilling the Ensemble Knowledge

 

 

  • Knowledge distillation이란 많은 parameter가 사용되는 teacher 모델로부터 불필요한 parameter를 제거하고 generalization 성능을 기존 대비 향상 시킬 수 있는 knowledge들을 분류하는 방법을 말합니다.
  • 여기서는 전달되는 knowledge는 softmax 결과에 함축되어 담겨있다고 말합니다.

사진과 같이 큰 모델에서 단순한 모델로 지식을 전달하여 적은 추론 시간을 가지는 단순한 모델을 만드는 것이 본 논문의 목적입니다.

 

- 여기서 크고 복잡한 모델을 Teacher model, 작고 단순한 모델을 Student 모델이라고 합니다.

 

이 논문에서는 knowledge distillation을 진행할 때 주의할 점은 hard target을 사용하면 안 된다는 것입니다.

 

 

Hard target

 

 

  • 논문에서는 Teacher model이 Hard target(one - hot encoding)을 사용하면 안 된다고 말합니다. hard target을 사용할 시 가장 높은 확률의 정보를 제외하고는 나머지 정보들은 0으로 전부 무시하게 됩니다.

hard target을 적용하지 않은 방법은 soft target입니다. soft target는 softmax의 output을 모델의 최종 output으로 사용하는데 이를 통해 모든 class에 대한 확률 정보를 얻어 정보의 손실을 줄일 수 있다는 장점이 있습니다. = regularizaition 효과를 얻을 수 있습니다.

 

  • 고양이의 경우 소보다는 강아지에 더 가까운 형상을 가진다는 정보 또한 얻을 수 있습니다.

 

Soft target

 

하지만 기존의 softmax를 사용하게 되면 입력값의 작은 차이에도 출력에 큰 변화를 줍니다.

  • 가장 큰 입력값은 1과 가까운 값을 가지고 나머지는 0에 가까운 값으로 mapping이 됩니다.
  • 작은 입력 정보들을 잘 활용하지 못하는 문제가 발생, regularizaiton 효과가 적어짐

이러한 문제점을 해결하고자 본 논문에서는 기존의 softmax에 Temperature라는 scaling 역할의 하이퍼 파라미터를 추가합니다.

$Softmax(z_i) = \frac {exp(z_i)} {\Sigma_j (z_j)}$ $\rightarrow$ $Softmax(z_i) = \frac {exp(z_i / \tau)} {\Sigma_j (z_j / \tau)}$

$\tau$가 1일 때는 기존 softmax와 동일하고 값이 커질수록 더 soft 한 확률분포를 얻게 됩니다.

 

 

Distillation

Distillation은 offline에서 진행이 되었고 실제 label 값과 teacher model에서 얻은 soft target 값을 둘 다 사용하여 학습을 진행합니다.

  • $f_{t}(x_i) : Teacher$ 모델의 logit 값
  • $f_{s}(x_i) : Student$ 모델의 logit 값
  • $\tau :$ Scailing 역할의 하이퍼 파라미터
  • $KL:$ Kullback–Leibler divergence로 두 확률분포의 차이를 계산

 

 

  1. training set (x, hard target)을 사용해 large model을 학습한다.
  2. large model이 충분히 학습된 뒤에, large model의 output을 soft target으로 하는 transfer set(x, soft target)을 생성해낸다. 이때 soft target의 $T$ 는 1이 아닌 높은 값을 사용한다.
  3. transfer set을 사용해 small model을 학습한다. $T$ 는 soft target을 생성할 때와 같은 값을 사용한다. (soft target)
  4. training set을 사용해 small model을 학습한다. $T$ 는 1로 고정한다 (hard target)

 

Result

JFT data와 JFT 기본 모델을 사용하여 학습을 진행

 

 

  • hard target으로 모든 데이터로 학습했을 때 test accuracy는 58.9%, 그리고 3%의 데이터로 학습했을 때는 44.5%가 나왔다. (early stopping을 사용했음에도 overfitting 발생)
  • soft target으로 3%의 데이터로 학습했을 때 test accuracy는 57%가 나왔다. (early stopping 사용 안 함)

soft target을 사용한 경우 3%의 데이터만 가지고 학습하더라도 모든 데이터를 사용한 hard target과 비슷한 정확도를 가집니다.

 

 

한계점

  • 더욱 복잡한 문제와 deeper network 환경에서는 softmax output으로만 knowledge를 전달하기에는 무리가 있다는 의견이 다수 존재

 

 

Reference

https://www.youtube.com/watch?v=pgfsxe8sROQ&t=460s