기존 pruning 기법들이 성능에 있어 불투명한 hyperparmeter를 사용하거나 여러 번에 거쳐서 pruning을 진행하는 것을 지적하며 single-shot pruning을 제안!
1. 기존 pruning 기법의 한계
- Dataset $\mathcal{D}=\{(\mathbf{x}_i, \mathbf{y}i)\}{i=1}^{n}$가 주어지고, 목표로 하는 sparsity level이 $k$인 NN Pruning을 constrained optimization 문제로 표현할 수 있다.
$$\begin{aligned} \min_{\mathbf{c}, \mathbf{w}}L( \mathbf{w};\mathcal{D})&=\min \frac{1}{n}\sum_{i=1}^{n}\mathcal{l}(\mathbf{w};(\mathbf{x}_i,\mathbf{y}_i)) \\ \text{s.t.} \, \mathbf{w} &\in \mathbb{R}^m, ||\mathbf{\mathbf{w}}||_0 \leq \mathcal{k}, \end{aligned}$$
여기서 $\mathcal{l}(\cdot)$은 cross entropy와 같은 일반적인 loss function이며, $\mathbf{w}$는 NN의 set of paramrter, $m$은 총 parameter 수이다.
- 정의된 최적화 문제는 보통 미리 학습한 network를 pruning하고 fine-tuning하는 과정을 반복하는 iterative하고 휴리스틱한 방법이 주로 사용됨.
- 대부분의 pruning 기법이 FCN, CNN 등 architecture에 의존성이 있으며, 또한 pruning 과정에 사용되는 hyper-parameter가 많이 사용하는데 이를 구하는 과정이 휴리스틱한 경우가 많다.
2. One-Shot Pruning Method
- 본 논문에서는 training 전에 한번의 pruning을 수행하여 parameter를 줄이는 방법을 제안한다.
- Network의 각 element의 삭제 여부를 나타내는 auxiliary indicator variance $c\in\{0,1\}^m$를 정의하고 최적화 문제를 다음과 같이 수정함.
$$\begin{aligned} \min_{\mathbf{c}, \mathbf{w}}L(\mathbf{c}\odot \mathbf{w};\mathcal{D})&=\min \frac{1}{n}\sum_{i=1}^{n}\mathcal{l}(\mathbf{c}\odot \mathbf{w};(\mathbf{x}_i,\mathbf{y}_i)) \\ \text{s.t.} \, \mathbf{w} &\in \mathbb{R}^m, \\ \mathbf{c} &\in \{0,1\}^m, \, ||\mathbf{c}||_0 \leq \mathcal{k}, \end{aligned}$$
- parameter $c$로 인해 학습해야할 parameter가 2배가 되어 바로 optimize가 어려워짐. 하지만 pruning 여부, 즉 $c$에 따른 성능 변화에 대해서 loss를 정의하면 weight와 무관하게 최적화 할 수 있음.
- Pruning의 효과를 loss의 차이를 이용하여 표현할 수 있다. m개의 connection이 각각 loss에 미치는 영향을 얻기 위해 m번의 forward pass를 계산해야 함. → 연산이 복잡함
$$\begin{aligned}\Delta L_j(\mathbf{w};\mathcal{D})&=L(1\odot \mathbf{w};\mathcal{D})-L((1-\mathbf{e}_j)\odot \mathbf{w};\mathcal{D})\end{aligned}$$
weight가 아닌 c에 대한 효과로 다시 표현 가능함. (index j의 효과를 제거)
$$\begin{aligned} \Delta L_j(\mathbf{w};\mathcal{D}) &\approx g_i(\mathbf{w},\mathcal{D}) \\ &=\left. \frac{\partial L(\mathbf{c}\odot \mathbf{w};\mathcal{D})}{\partial c_j} \right|_{\mathbf{c}=1} \\ &=\left. \lim_{\delta \rightarrow0 }\frac{L(\mathbf{c}\odot \mathbf{w};\mathcal{D})-L((\mathbf{c}-\delta \mathbf{e}_j)\odot \mathbf{w};\mathcal{D})}{\delta}\right|_{\mathbf{c}=1} \end{aligned}$$
여기서 $c\in\{0,1\}^m$는 미분 불가하므로, 극소 변화에 대한 변화량으로 근사화 함.
- weight에 dependency가 적고 한번의 forward pass로 모든 connection을 평가할 수 있는 "connection sensitivity"를 정의하고, 한번의 forward pass를 통해 모든 connection의 sensitivity를 계산함.
$$s_j=\frac{|g_j(\mathbf{w};\mathcal{D})|}{\sum_{k=1}^{m}|g_k(\mathbf{w};\mathcal{D})|}$$
3. Experimental Result
- 다음과 같은 순서로 pruning을 수행함.
- network의 parameter를 초기화
- mini-batch sampling $\mathcal{D}^b=\{(\mathbf{x}_i, \mathbf{y}i)\}^{b}{j=1} \sim\mathcal{D}$
- Connection sensitivity를 계산 $s_j\forall j\in\{1,...,m\}$
- Top-$\mathcal{k}$의 parameter만 남기고 pruning
- pruned network를 학습
- 다른 pruning 기법 대비 간단한 방법으로 좋은 성능을 나타냄.
random label을 적용한 결과 pruned network의 경우 loss가 감소하지 않음.
- network의 memorization 문제를 방지할 수 있다!