안녕하세요. coreml에서 PReLU를 사용하기 위해 우회법을 찾던 도중 Custom PReLU 레이어를 만들었습니다. 아래 코드는 coreml 변환 이후 작동을 보장하지 않습니다.
torch 2.2.2 / python 3.10 사용 중입니다.
# PReLU의 정확한 작동 보장을 위해 C++ Implementation을 그대로 가져옴
# [pytorch/aten/src/ATen/native/Activation.cpp at main · pytorch/pytorch](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Activation.cpp)
def reshape_weight(x, weight):
# Clone the input weight to avoid modifying the original tensor
reshaped_weight = weight.clone()
# Get the dimensions of the input tensor and the weight tensor
self_dim = x.dim()
weight_dim = weight.dim()
if self_dim != weight_dim:
# Create a shape vector filled with 1s, with length equal to self_dim
dim_w = [1] * self_dim
# If self_dim is greater than 1, set the second dimension to the number of elements in the weight
if self_dim > 1:
dim_w[1] = weight.numel()
# Reshape the weight using the calculated shape vector
reshaped_weight = weight.reshape(dim_w)
return reshaped_weight
class CustomPReLU(nn.Module):
"""
PyTorch Custom PReLU that replicates nn.PReLU's behavior.
- num_parameters: 1이면 모든 채널에서 alpha(기울기)를 공유
>1이면 채널 개수만큼 alpha를 독립적으로 가짐
- init: alpha(기울기)의 초기값
"""
def __init__(self, num_parameters=1, init=0.25):
super().__init__()
self.num_parameters = num_parameters
# nn.Parameter로 선언하여 학습될 수 있게 함
self.weight = nn.Parameter(torch.full((num_parameters,), init))
def forward(self, x):
# 채널별 파라미터가 있는 경우 (NCHW 기준)
# self.weight.shape = (num_parameters,)
# x.shape = (N, C, H, W)
if self.num_parameters == 1:
# 모든 채널이 weight[0]을 공유
return torch.where(x >= 0, x, self.weight[0] * x)
else:
# 채널별 파라미터 적용
# weight를 (1, C, 1, 1)로 reshape하여 채널 차원만 broadcast
# w = self.weight.view(1, -1, 1, 1)
# w = self.weight.reshape(1, -1, 1, 1)
w = reshape_weight(x, self.weight)
return torch.where(x >= 0, x, w * x)
def replace_prelu_with_custom(model: nn.Module):
"""
model 내에 있는 nn.PReLU 모듈들을 CustomPReLU로 재귀적으로 교체.
PReLU의 weight를 그대로 복사해온다.
"""
for name, module in model.named_children():
if isinstance(module, nn.PReLU):
# print('converted', name)
# PReLU의 파라미터 개수(num_parameters)와 weight 복사
num_params = module.num_parameters
# 새로 커스텀 PReLU 모듈 생성
custom_module = CustomPReLU(num_parameters=num_params)
# 기존 weight 복사
with torch.no_grad():
if num_params == 1:
# module.weight는 shape (1,)일 수도 있고 스칼라 텐서일 수도 있음
custom_module.weight.copy_(module.weight)
else:
custom_module.weight.copy_(module.weight.data)
# 부모 모듈에 새로 만든 모듈로 교체
setattr(model, name, custom_module)
else:
# 자식 모듈들도 재귀 탐색
replace_prelu_with_custom(module)
CustomPReLU 클래스의 forward 함수에서 self.weight의 shape를 [C] -> [1, C, 1, 1]로 바꾸고 있습니다. view, reshape, reshape_weight 함수 세 가지 케이스에 대해 실험을 진행했습니다.
view와 reshape의 경우 prelu 변경 전후 모델의 추론 결과가 다릅니다.
그러나 위 코드의 reshape_weight함수를 이용하면 torch.allclose(1e-9)를 통과합니다.
reshape의 경우 self.weight와 w.clone().squeeze()의 값이 torch.equal을 통과하는 것을 확인했으나 결과가 달라지는 것을 이해하기 어렵습니다. 돌아가는데 왜 돌아가는지 모르겠습니다. 이유가 무엇일까요?