Pseudo Lab/4기 Mobile AI crew

[논문 리뷰] Knowledge distillation: A good teacher is patient and consistent

공부중인학생 2022. 5. 10. 01:34

시간에 지남에 따라 더 좋은 성능을 가진 모델들이 등장하고 있지만 논문의 저자가 Tensorflow Hub 사이트에서 사람들이 사용하는 모델의 다운로드 횟수를 찾아본 결과, 성능이 좋은 큰 모델보다 Resnet-50 모델과 같이 비교적 작고 저렴한 모델을 사용한 경우가 대부분이었습니다.

 

- 실제 문제에 적용하기 위해서 모델을 고를 때 단순히 큰 모델이 적합하지 않을 수 있다.

 

 

큰 모델을 현실에서 사용하기 적합하도록 하는 방법 두 가지가 있는데 첫 번째는 pruning이고 두 번째가 knowledge distillation입니다. 여기서 pruning 같은 경우는 knowledge distillation보다 간단하게 적용할 수 있지만 model family 변경이 어렵다는 단점이 있습니다.

 

- knowledge distillation은 teacher model은 ResNet으로 설정하고 student model은 mobilenet으로 설정할 수 있지만 pruning 같은 경우는 전과 후 전부 같은 모델을 사용해야 하므로 모델 구조에 dependent 하다. (모델을 바꾸는 것이 불가능하다.)

 

 

본 논문은 knowledge distillation에 대한 새로운 방법이 아닌 큰 모델의 성능을 저하시키지 않고 크기를 줄일 때 설계 절차가 있다는 것을 발견하여 정리한 논문입니다. (knowledge distillation을 잘 활용하면 좋은 성능의 큰 모델과 저렴한 모델의 격차를 크게 해소할 수 있다.)

  • knowledge distillation을 실제로 잘 활용할 수 있는 방법을 전달하는 것이 논문의 목적

 

방법은 다음과 같습니다.

 

  1. teacher와 student가 완전히 동일한 이미지를 input으로 받아야 합니다. (crop, augmentation 같은 노이즈들도 전부 동일하게 맞춰야 합니다.)
  2. generalization을 잘하기 위해서 많은 양의 support point를 매칭 할 수 있는 함수를 찾아야 합니다. (teacher와 strudent의 prediction을 많이 매칭 하는 함수를 찾아야 한다, 매칭되는 예측 분포가 넓어야 한다.)
  3. 실제로 knowledge distillation으로 지식이 잘 전달되려면 매우 오랫동안 학습을 진행해야 한다.

 

조금 더 자세히 알아보겠습니다.

 

 

 

 

위에 사진에서 왼쪽 두 개의 경우가 우리가 knowledge distillation을 할 때 하지 말아야 하는 방법들이고 오른쪽 두 개의 경우가 지향해야 될 방법들입니다.

 

 

 

Fixed teacher

 

 

이 부분이 우리가 쉽게 실수할 수 있는 부분입니다. student 모델을 학습시키면서 knowledge distillation을 같이 수행할 때 학습 도중에 teacher 모델의 예측 분포를 추론하는 것이 쉽지가 않습니다. teacher도 메모리 상에 올려서 forward를 진행해야 하므로 시간도 오래 걸리고 메모리 측면에서도 큰 부담이 오기 때문입니다. 

 

그래서 student 모델을 학습하기 전에 미리 teacher 모델의 prediction 값을 따로 파일로 저장하게 될 때, 이 경우에 crop, augmentation이 다르게 적용된 이미지들도 모두 같은 teacher의 prediction값으로 매칭이 되기 때문에 좋은 성능을 낼 수 없습니다.

 

- 각각의 노이즈들이 적용된 teacher의 predcition 값에 매칭 되어된다.

- 서로 다른 input이 들어갔는데 같은 결과와 매칭이 됨 (디테일한 정보들이 무시됨, 노이즈에 대한 정보들이 버려짐)

- knowledge distillation에서는 디테일한 부분이 중요한거 같다, 초기 knowledge distillation에서는 soft target을 사용해서 다른 클래스의 정보들까지 얻을려고 했던 것 처럼

- 작은 모델이 큰 모델을 모방하는게 어려운 문제이니까 자잘한 정보들까지 전부 사용해줘야한다는 느낌이 듬

 

 

 

independent noise

 

 

위와 비슷한 케이스인데 이 경우에는 하나의 prediction 값으로 매칭 되지는 않지만 teacher 모델의 input noise와 student 모델의 input noise가 서로 다르게 진행되어서 나타나는 문제입니다.

 

- 사진을 보면 서로 다른 노이즈가 들어갔기 때문에 teacher와 student의 input이 다르다는 것을 표현하고 있습니다. (서로 사진의 다른 부분을 보고 있는데 매칭이 되었다.)

 

 

 

consistent teaching, function matching

 

 

두 케이스 모두 teacher와 student가 동일한 input을 가졌다는 공통점이 있지만 자세히 보면 consistent teaching의 경우 function matching보다 매칭되는 예측 분포가 더 넓다는 것을 알 수 있습니다. 

 

 

왼쪽은 consistent teaching 오른쪽은 function matching

 

그 이유는 input에 차이에 있습니다. function matching의 경우에는 mix up과 같이 공격적이 augmentation을 진행하거나 아니면 다른 도메인 이미지를 input으로 사용합니다.

 

- 이전에 리뷰했던 Distilling the Knowledge in a Neural Network에서 soft target을 사용한 경우 다른 클래스들의 확률에서도 정보를 얻을 수 있다고 했는데 그 부분의 확장인 거 같습니다.

 

- 다른 도메인에 대해서 같은 결과가 나오려면 모델의 구조가 서로 같아야 한다. 그리고 이를 통해서 teacher의 디테일한 부분까지 학습할 수 있고 모델의 표현 범위도 늘어난다. (표현 범위가 늘어난다. = 다른 도메인에서도 서로 매칭이 된다.)

 

 

 

Test

 

Dataset

  •  flower102, pets, food101, sun397, ILSVRC-2012(ImageNet) 총 5개의 data set을 사용
  •  작은 사이즈와 큰 사이즈 데이터를 사용함

 

평가 지표는 classification accuracy이고 사용 모델은 

 

  • Teacher: BIT-ResNet-152x2 (BIT transfer을 사용한 ResNet-152 model)
  • Studnet: BIT-ResNet-50

 

flower 102 data

distillation loss on train을 본다면 fixed teacher가 loss가 낮게 나오지만 일반화 성능은 떨어지는 것을 알 수 있습니다. 

 

student validation accuracy를 볼 경우 independent noise와 fixed teacher가 가장 아래 존재하고 input을 동일하게 준 경우가 중간, 그리고 동일한 input에 mix up을 사용한 경우가 정확도가 가장 높게 나왔습니다.

 

그런데 epoch부분을 보면 10000 epoch이 학습한 것을 알 수 있습니다. (batch size = 4096)

 

이전에 말했던 것 처럼 knowledge distillation을 잘 적용하려면 오랫동안 학습해야 된다고 했습니다. 본 논문에서는 파라미터가 많은 teacher 모델을 student가 모방하는 게 쉬운 일이 아니기 때문에 굉장히 오랜 시간 동안 distillation을 진행해야 한다고 주장합니다.

 

 

Image Net data

 

이전 data set보다 더 큰 data set의 경우도 살펴보겠습니다.

 

1. 첫 번째 이미지에서는 epoch수가 올라갈수록 fixed teacher 정확도가 떨어지는 것을 볼 수 있습니다.

 

2. 두 번째 이미지에서는 consistent teaching과 Shampoo optimizer를 사용할 경우 더 빠르게 최적화된다는 것을 알 수 있습니다.

 

3. 마지막 세 번째는 BIT방법으로 pretrain 된 파라미터를 불러오는 방법입니다. scratch부터 학습한 모델과 정확도 차이를 보시면 됩니다. (초반에는 더 좋은 성능을 기대할 수 있지만 나중에 가면 서로 비슷해짐) 

 

 

 

요약

 

- teacher와 student가 완전히 동일한 입력을 받아야 한다.

- input의 manifold를 좀 더 풍부하게 만들어주기 위해서 공격적인 augmentation을 진행(mix up)

- 매우 오랫동안 학습해야 한다.

 

Q1. 다른 도메인 데이터나 공격적인 mix up으로 인해서 훼손된 이미지를 넣어도 성능이 떨어지지 않을까?

 

  • 다른 도메인 데이터나 mix up 된 데이터를 사용하더라도 성능이 떨어지지 않는 이유가 teacher 모델이 학습되는 것이 아니라 예측만 진행하기 때문이라고 합니다. (forward만 진행하기 때문에) 그래서 여러가지 데이터를 넣어봄으로써 teacher 모델의 입출력 관계에 대해서 더 자세히 알 수 있고 그 정보가 student 모델에 도움이 된다고 합니다.