심드렁하게 저장

정규화 - Batch Normalization (배치 정규화) 본문

Artificial intelligence/Deep Learning

정규화 - Batch Normalization (배치 정규화)

Ggoosae 2025. 3. 11. 23:18

1. Batch Normalization이란?

Batch Normalizaion 하면 제일 먼저 나오는 이미지

 

Batch Normalization은 신경망에서 각 층의 입력 분포를 정규화(Normalization)하여 학습을 안정화하고 가속하는 기법이다. 딥러닝에서 각 층의 입력이 계속 변하면(Internal Covariate Shift), 학습이 불안정 해질 수 있다. 이때 Batch Normalization을 사용하면, 미니 배치 단위로 평균과 분산을 정규화하여 학습을 안정화 할 수 있다.

  • BN 적용시 장점
    • 학습 속도 증가 → 더 큰 학습률 사용 가능
    • 기울기 소실 문제 완화
    • 초깃값 초기화에 덜 민감
    • Dropout 없이도 과적합 방지 효과
  • BN의 단점
    • 배치 크기에 의존적 : 작은 배치에서는 평균/분산 추정이 불안정
    • 추론 시 moving average 사용 필요
    • RNN, 시퀀스 모델에는 부적합 -> 대신 LayerNorm, GroupNorm 사용 

2. Batch Normalization 동작 과정

BN은 각 미니배치에서 평균과 분산을 계산하여 정규화 한다.

  • 1) 입력 정규화(Normalizaion)
    • 입력 $x$가 미니배치 $B$에 속해 있을 때:
    •   
      Batch의 평균과 분산
    • $\mu_{B}$ : 미니배치 평균
    • $\sigma_{B}^{2}$: 미니배치 분산
    • $m$ : 미니 배치 크기
    • 정규화된 값:
    • 정규화 수식
    • $\epsilon$: 분모가 0이 되는 것을 방지하는 작은 값
  • 2) 학습 가능한 스케일 조정 (Affine Transformation)
  • $\gamma$ (Scale Parameter) : 표준화된 값의 스케일을 조절하는 학습 가능한 매개 변수
  • $\beta$ (Shift Parameter): 표준화된 값의 이동을 조절하는 학습 가능한 매개 변수
    • 감마와 베타가 필요한 이유는 단순 정규화만 하면 표현력이 제한될 수 있으므로 BN이 원하는 분포로 조정할 수 있도록 스케일과 이동을 추가하는 것
    • Batch Normalization을 적용하면 입력이 평균 0, 분산 1로 정규화 되면서도, 학습가능한 $\gamma$, $\beta$를 통해 표현력을 유지할 수 있음 

3. Batch Normalization의 적용 위치

BN은 Conv -> BatchNorm -> Activation 순서로 사용한다.

x = self.bn(self.fc(x))  # 선형 변환 → 정규화
x = F.relu(x)            # 활성화 함수 적용

 

4. Batch Normalization Pytorch로 구현하기

Batch Normailzation은 pytorch의 nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 등을 활용할 수 있지만 이 글에서는 학습을 위해 pytorch로 수식을 직접 구현해본다.

import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    def __init__(self, num_features, epsilon=1e-5, momentum=0.1,is_train=True):
        '''
        num_features: 채널 수
        epsilon: 0으로 나누는 것을 방지하는 작은 값
        momentum: 지수 가중 이동 평균을 위한 모멘텀
        '''
        super(BatchNorm,self).__init__()

        # 학습 가능한 Scale(gamma) & Shift(beta) 파라미터
        # torch.nn.Parameter 클래스는 자동미분이 되는(requires_grad=True) torch.Tensor
        self.gamma = nn.Parameter(torch.ones(num_features)) # 기본값 1
        self.beta = nn.Parameter(torch.zeros(num_features)) # 기본값 0

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        self.epsilon = epsilon
        self.momentum = momentum

    def forward(self,x):
        if self.training:
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0,unbiased=False)

            # 업데이트된 running mean/var
            self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1-self.momentum) * self.running_var + self.momentum * batch_var

            x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.epsilon)
        else:
            # 추론 시에는 running 평균/분산 사용
            x_hat = (x - self.running_mean) / torch.sqrt(self.running_var + self.epsilon)

        # Scale & Shift
        return self.gamma * x_hat + self.beta
#
  • 위에서 torch.nn.Parameter 클래스는 자동미분이 되는(requires_grad=True) torch.Tensor이다.
  • 또한 self.register_buffer는 모델의 상태(state)로서 관리하고 싶은 텐서를 등록하는 데 사용된다. 즉, 이 메서드는 state_dict에 포함되어서, torch.nn.Module.state_dict()에 함께 저장되어, torch.save을 할 때, 함께 저장된다. 또한, register_buffer으로 등록된 텐서는 기본적으로 기울기를 계산하지 않는다.
  • 학습 시에는 미니배치의 평균과 분산으로 정규화하며 추론 시에는 저장된 running mean/var로 정규화한다. 이는 model.eval() 상태에서 동작한다.

5. 정리 및 요약

  • Batch Normalization은 각 층의 입력 분포를 정규화하여 학습 안정성과 속도를 향상
  • 학습률을 크게 설정할 수 있어 빠른 수렴 가능
  • CNN, MLP 모델에서 성능향상에 효과적
  • 작은 배치나 RNN에는 적합하지 않음