심드렁하게 저장

Convolution - Dilated Convolution, Deformable Convolution 본문

Artificial intelligence/Deep Learning

Convolution - Dilated Convolution, Deformable Convolution

Ggoosae 2025. 5. 6. 16:23

Dilated Convolution

Dialted Conv

 

Dilated Convolution은 커널의 필터 사이를 띄워서 적용하는 연산이다. 즉, 샘플 간 간격(dilation rate) 을 넓게 벌려서 더 넓은 receptive field(CNN에서 출력 레이어의 뉴런 하나에 영향을 미치는 입력 뉴런들의 공간 크기)를 확보 할 수 있다. 기존 Convolution은 인접한 픽셀에만 반응하지만 Dilated Conv는 픽셀 사이를 건너뛰며 계산하므로 멀리 떨어진 정보까지 한번에 수용 가능하다. 또한 해상도 손실 없이 더 넓은 문맥을 담을 수 있다. Dilated Conv의 활용사례는 다음과같다.

DeepLab v3+ 여러 dilation rate로 multi-scale context 학습
WaveNet 오디오 생성 시, 시계열에서 넓은 의존성 모델링
TCN (Temporal) 시간 기반 장기 의존성 학습
ResNet 변형 다운샘플링 없이 receptive field 확장

 

Dilated Conv의 수식은 다음과 같이 정리된다.

2D로 일반화:

$X(i,j)$ 입력의 위치 (i,j)의 픽셀값
$W(m,n)$ 커널의 위치 (m,n)에서의 가중치
$r$ dilation rate. 커널 위치 간 간격을 결정
$X(i + r \cdot m, j + r \cdot n)$ 입력에서 띄워진 위치의 값
$Y(i,j)$ 출력 feature map의 위치 (i,j)에서의 결과 값

 

Pytorch에서는 dilation param을 설정하여 구현할 수 있다.

import torch
import torch.nn as nn

class DilatedConv2d(nn.Module):
    def __init__(self,in_channels, out_channels,kernel_size=3, dilation=1, padding=2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels,kernel_size,
            dilation=dilation,padding=padding
        )

    def forward(self,x):
        return self.conv(x)

 

일반적인 Covolution과 비교하면 다음과 같은 결과를 얻을 수 있다.

from conv_module import DilatedConv2d,NormalConv2d
import torch
from torchvision.datasets import CIFAR10
from utils import normalize
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Load Lenna.png
transform = transforms.Compose([
    transforms.ToTensor()
    ])
image = cv2.imread('./Lenna.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = transform(image)

normal_conv = NormalConv2d(3,3)
dilated_conv = DilatedConv2d(3,3)

with torch.no_grad():
    normal = normal_conv(image)
    dilated = dilated_conv(image)

images = [normalize(image),normalize(normal) ,normalize(dilated)]
titles = ['Input Image','Normal Convolution Output', 'Dilated Convolution Output']

plt.figure(figsize=(10, 5))
for i,img in enumerate(images):
    plt.subplot(1, len(images), i + 1)
    plt.imshow(np.transpose(img.cpu().numpy(),(1,2,0)))
    plt.title(titles[i])
    plt.axis('off')
plt.tight_layout()
plt.show()

Normal Conv VS Dilated Conv

Deformable Convolution

Deformable Covolution 출처:https://production-media.paperswithcode.com/methods/newest.png

 

기존의 Convolution은 고정된 커널-3x3,5x5 등- 위치에서만 정보를 샘플링하지만 Deformable Convolution은 각 커널 위치에 대해 학습된 이동량-offset-을 더해서 입력 위치를 유연하게 조절할 수 있는 Conv 연산이다. 즉, 커널이 딱 정사각형이 아니라 비틀고 늘어지고 움직일 수 있는 커널이다. 이를 적용하면 기하학적으로 복잡하거나 비정형적인 물체 구조에도 잘 대응한다. 특히 객체 탐지, Segmentation, 키포인트 검출 등에 효과적이다. Deformable Covolution 활용모델은 다음과같다:

Deformable ConvNet (DCNv1, DCNv2) Faster R-CNN, Mask R-CNN에서 bbox 성능 향상에 사용됨
HRNet 포즈 추정, 키포인트 기반 검출에서 강력한 표현력 제공
Multi-scale Vision Tasks 다양한 scale의 객체나 국소 위치 이동 대응에 적합

 

수식은 다음과 같이 정의된다:

$y(p_0)$ 출력 feature map의 위치 $p_0$에서의 결과값
$w_k$ 커널의 k번째 위치에 해당하는 필터 weight
$x(p_0 + p_k + \Delta p_k)$ 입력 feature map에서 샘플링하는 위치. 일반 conv와 다르게 실수 좌표일 수 있음
$p_k$ 고정된 커널 위치. 예: $p_k \in [-1,0,1]^2$
$\Delta p_k$ 학습된 offset (각 위치마다 다름, 실수값)

 

전체 구조를 요약하면,

  1. 일반 Convolution처럼 입력 feature를 받는다
  2. offset 전용 conv layer가 위치별로 $\Delta p_k$를 예측한다.
  3. 입력 위치를 offset만큼 이동시켜 실수 좌표에서 샘플링 : bilinear interpolation 필요
  4. 필터 weight와 곱해서 출력 생성

개념을 pytorch로 구현해보면 다음과 같다.

class DeformableConv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.offset_conv = nn.Conv2d(3, 18, kernel_size=3, padding=1)  # 2 * 3x3
        self.weight = nn.Parameter(torch.randn(3, 1, 3, 3))  # Depthwise 구조

    def forward(self, x):
        B, C, H, W = x.shape
        offset = self.offset_conv(x).view(B, 9, 2, H, W)
        base_y, base_x = torch.meshgrid(
            torch.arange(H), torch.arange(W), indexing='ij'
        )
        base_grid = torch.stack((base_x, base_y), dim=0).float().to(x.device)
        base_grid = base_grid[None, None, :, :, :].expand(B, 9, -1, -1, -1)
        sampling_grid = base_grid + offset  # (B, 9, 2, H, W)
        sampling_grid = sampling_grid.permute(0, 3, 4, 1, 2)  # (B, H, W, 9, 2)
        sampling_grid = sampling_grid / torch.tensor(
            [W - 1, H - 1], device=x.device
        ) * 2 - 1  # normalize to [-1, 1]

        outputs = []
        for i in range(9):
            grid = sampling_grid[..., i, :]  # (B, H, W, 2)
            sample = F.grid_sample(x, grid, align_corners=True)
            outputs.append(sample)
        out = torch.stack(outputs, dim=2)  # (B, C, 9, H, W)
        return out.sum(dim=2)