시간에 지남에 따라 더 좋은 성능을 가진 모델들이 등장하고 있지만 논문의 저자가 Tensorflow Hub 사이트에서 사람들이 사용하는 모델의 다운로드 횟수를 찾아본 결과, 성능이 좋은 큰 모델보다 Resnet-50 모델과 같이 비교적 작고 저렴한 모델을 사용한 경우가 대부분이었습니다.
- 실제 문제에 적용하기 위해서 모델을 고를 때 단순히 큰 모델이 적합하지 않을 수 있다.
큰 모델을 현실에서 사용하기 적합하도록 하는 방법 두 가지가 있는데 첫 번째는 pruning이고 두 번째가 knowledge distillation입니다. 여기서 pruning 같은 경우는 knowledge distillation보다 간단하게 적용할 수 있지만 model family 변경이 어렵다는 단점이 있습니다.
- knowledge distillation은 teacher model은 ResNet으로 설정하고 student model은 mobilenet으로 설정할 수 있지만 pruning 같은 경우는 전과 후 전부 같은 모델을 사용해야 하므로 모델 구조에 dependent 하다. (모델을 바꾸는 것이 불가능하다.)
본 논문은 knowledge distillation에 대한 새로운 방법이 아닌 큰 모델의 성능을 저하시키지 않고 크기를 줄일 때 설계 절차가 있다는 것을 발견하여 정리한 논문입니다. (knowledge distillation을 잘 활용하면 좋은 성능의 큰 모델과 저렴한 모델의 격차를 크게 해소할 수 있다.)
- knowledge distillation을 실제로 잘 활용할 수 있는 방법을 전달하는 것이 논문의 목적
방법은 다음과 같습니다.
- teacher와 student가 완전히 동일한 이미지를 input으로 받아야 합니다. (crop, augmentation 같은 노이즈들도 전부 동일하게 맞춰야 합니다.)
- generalization을 잘하기 위해서 많은 양의 support point를 매칭 할 수 있는 함수를 찾아야 합니다. (teacher와 strudent의 prediction을 많이 매칭 하는 함수를 찾아야 한다, 매칭되는 예측 분포가 넓어야 한다.)
- 실제로 knowledge distillation으로 지식이 잘 전달되려면 매우 오랫동안 학습을 진행해야 한다.
조금 더 자세히 알아보겠습니다.
위에 사진에서 왼쪽 두 개의 경우가 우리가 knowledge distillation을 할 때 하지 말아야 하는 방법들이고 오른쪽 두 개의 경우가 지향해야 될 방법들입니다.
Fixed teacher
이 부분이 우리가 쉽게 실수할 수 있는 부분입니다. student 모델을 학습시키면서 knowledge distillation을 같이 수행할 때 학습 도중에 teacher 모델의 예측 분포를 추론하는 것이 쉽지가 않습니다. teacher도 메모리 상에 올려서 forward를 진행해야 하므로 시간도 오래 걸리고 메모리 측면에서도 큰 부담이 오기 때문입니다.
그래서 student 모델을 학습하기 전에 미리 teacher 모델의 prediction 값을 따로 파일로 저장하게 될 때, 이 경우에 crop, augmentation이 다르게 적용된 이미지들도 모두 같은 teacher의 prediction값으로 매칭이 되기 때문에 좋은 성능을 낼 수 없습니다.
- 각각의 노이즈들이 적용된 teacher의 predcition 값에 매칭 되어된다.
- 서로 다른 input이 들어갔는데 같은 결과와 매칭이 됨 (디테일한 정보들이 무시됨, 노이즈에 대한 정보들이 버려짐)
- knowledge distillation에서는 디테일한 부분이 중요한거 같다, 초기 knowledge distillation에서는 soft target을 사용해서 다른 클래스의 정보들까지 얻을려고 했던 것 처럼
- 작은 모델이 큰 모델을 모방하는게 어려운 문제이니까 자잘한 정보들까지 전부 사용해줘야한다는 느낌이 듬
independent noise
위와 비슷한 케이스인데 이 경우에는 하나의 prediction 값으로 매칭 되지는 않지만 teacher 모델의 input noise와 student 모델의 input noise가 서로 다르게 진행되어서 나타나는 문제입니다.
- 사진을 보면 서로 다른 노이즈가 들어갔기 때문에 teacher와 student의 input이 다르다는 것을 표현하고 있습니다. (서로 사진의 다른 부분을 보고 있는데 매칭이 되었다.)
consistent teaching, function matching
두 케이스 모두 teacher와 student가 동일한 input을 가졌다는 공통점이 있지만 자세히 보면 consistent teaching의 경우 function matching보다 매칭되는 예측 분포가 더 넓다는 것을 알 수 있습니다.
그 이유는 input에 차이에 있습니다. function matching의 경우에는 mix up과 같이 공격적이 augmentation을 진행하거나 아니면 다른 도메인 이미지를 input으로 사용합니다.
- 이전에 리뷰했던 Distilling the Knowledge in a Neural Network에서 soft target을 사용한 경우 다른 클래스들의 확률에서도 정보를 얻을 수 있다고 했는데 그 부분의 확장인 거 같습니다.
- 다른 도메인에 대해서 같은 결과가 나오려면 모델의 구조가 서로 같아야 한다. 그리고 이를 통해서 teacher의 디테일한 부분까지 학습할 수 있고 모델의 표현 범위도 늘어난다. (표현 범위가 늘어난다. = 다른 도메인에서도 서로 매칭이 된다.)
Test
Dataset
- flower102, pets, food101, sun397, ILSVRC-2012(ImageNet) 총 5개의 data set을 사용
- 작은 사이즈와 큰 사이즈 데이터를 사용함
평가 지표는 classification accuracy이고 사용 모델은
- Teacher: BIT-ResNet-152x2 (BIT transfer을 사용한 ResNet-152 model)
- Studnet: BIT-ResNet-50
distillation loss on train을 본다면 fixed teacher가 loss가 낮게 나오지만 일반화 성능은 떨어지는 것을 알 수 있습니다.
student validation accuracy를 볼 경우 independent noise와 fixed teacher가 가장 아래 존재하고 input을 동일하게 준 경우가 중간, 그리고 동일한 input에 mix up을 사용한 경우가 정확도가 가장 높게 나왔습니다.
그런데 epoch부분을 보면 10000 epoch이 학습한 것을 알 수 있습니다. (batch size = 4096)
이전에 말했던 것 처럼 knowledge distillation을 잘 적용하려면 오랫동안 학습해야 된다고 했습니다. 본 논문에서는 파라미터가 많은 teacher 모델을 student가 모방하는 게 쉬운 일이 아니기 때문에 굉장히 오랜 시간 동안 distillation을 진행해야 한다고 주장합니다.
이전 data set보다 더 큰 data set의 경우도 살펴보겠습니다.
1. 첫 번째 이미지에서는 epoch수가 올라갈수록 fixed teacher 정확도가 떨어지는 것을 볼 수 있습니다.
2. 두 번째 이미지에서는 consistent teaching과 Shampoo optimizer를 사용할 경우 더 빠르게 최적화된다는 것을 알 수 있습니다.
3. 마지막 세 번째는 BIT방법으로 pretrain 된 파라미터를 불러오는 방법입니다. scratch부터 학습한 모델과 정확도 차이를 보시면 됩니다. (초반에는 더 좋은 성능을 기대할 수 있지만 나중에 가면 서로 비슷해짐)
요약
- teacher와 student가 완전히 동일한 입력을 받아야 한다.
- input의 manifold를 좀 더 풍부하게 만들어주기 위해서 공격적인 augmentation을 진행(mix up)
- 매우 오랫동안 학습해야 한다.
Q1. 다른 도메인 데이터나 공격적인 mix up으로 인해서 훼손된 이미지를 넣어도 성능이 떨어지지 않을까?
- 다른 도메인 데이터나 mix up 된 데이터를 사용하더라도 성능이 떨어지지 않는 이유가 teacher 모델이 학습되는 것이 아니라 예측만 진행하기 때문이라고 합니다. (forward만 진행하기 때문에) 그래서 여러가지 데이터를 넣어봄으로써 teacher 모델의 입출력 관계에 대해서 더 자세히 알 수 있고 그 정보가 student 모델에 도움이 된다고 합니다.
'Pseudo Lab > 4기 Mobile AI crew' 카테고리의 다른 글
[모두콘] 실용적인 딥러닝 모델 경량화 & 최적화 (0) | 2022.06.27 |
---|