Hello, I'm posting because I have a question about hook functions.
Overview:
I am using register_forward_hook to apply an intervention to specific modules inside each GPT2 block (= layer). (The GPT2 model is the gpt2 model from the HuggingFace transformers library.)
Intervention target modules (list_m_ids):
(1) transformer.h.{n_layer}.attn.c_attn
(2) transformer.h.{n_layer}.attn.c_proj
(3) transformer.h.{n_layer}.mlp.c_fc
(4) transformer.h.{n_layer}.mlp.c_proj
(To be more specific, for the gpt2 model with 12 layers this means applying the intervention to 48 modules in total.)
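For reference, list_m_ids can be built programmatically, e.g. as in this minimal sketch (the loop variables here are only for illustration):

# 12 GPT2 blocks x 4 target submodules per block = 48 module ids
list_m_ids = [
    f"transformer.h.{n_layer}.{suffix}"
    for n_layer in range(12)
    for suffix in ("attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj")
]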
Problem: even though the intervention is applied, the losses for (1) and (2) within the same layer come out identical, and likewise the losses for (3) and (4). I have confirmed that the hook function's parameters (module, input, output) and the post-intervention new_output are all different across modules (see the log below).
Hook function:
import numpy as np
import torch

def hook_func(m_id):
    def _hook_func(module, _input, _output):
        # Add this module's intervention tensor to its original output
        new_output = _output + dict_intervention[m_id]
        print(f"module: {module}")
        print(f"_input: {_input}")
        print(f"_output: {_output}")
        print(f"new_output: {new_output}")
        # Returning a value from a forward hook replaces the module's output
        return new_output
    return _hook_func
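Here, dict_intervention maps each module id to a tensor that is added (with broadcasting) to that module's output. A minimal sketch with random stand-in values (the shapes are an assumption based on each Conv1D's output dimension; the real intervention values are computed elsewhere):

# Each HF Conv1D stores its weight as (nx, nf), so weight.shape[1] is the
# output feature size; a vector of that size broadcasts over (batch, seq, nf).
dict_intervention = {
    m_id: torch.randn(model.get_submodule(m_id).weight.shape[1], device=model.device)
    for m_id in list_m_ids
}

The per-module loop then looks like this: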
loss = {}
with torch.no_grad():
    for m_id in list_m_ids:
        # Register the hook on the target submodule for this forward pass only
        hook = model.get_submodule(m_id).register_forward_hook(hook_func(m_id))
        outputs = model(inputs, labels=inputs.clone())
        # Per-module perplexity: exp of the scalar LM loss
        loss[m_id] = np.exp(outputs.loss.item())
        hook.remove()
        print("loss:", loss[m_id])
The output log is as follows:
sentence_idx: 0
m_id: transformer.h.0.attn.c_attn
module: Conv1D()
_input: (tensor([[[-0.0615, -0.1206, -0.1294, ..., -0.0864, 0.0021, 0.0195],
[ 0.4367, -0.0484, 0.0263, ..., 0.1548, 0.2173, 0.1258],
[-0.4231, -0.1931, -0.0902, ..., 0.0283, 0.1649, -0.0793],
...,
[ 0.2313, -0.1444, 0.1054, ..., -0.2182, 0.0458, 0.1173],
[-0.3861, -0.1854, 0.0642, ..., 0.2437, 0.0576, -0.1109],
[-0.2373, 0.0031, 0.0299, ..., -0.1507, 0.1637, 0.4065]]],
device='cuda:0'),)
_output: tensor([[[-0.8298, -0.6448, -0.4022, ..., -0.0069, -0.1674, 0.0688],
[-0.7464, 0.9376, -1.3309, ..., 0.1358, 0.4994, 0.2281],
[ 0.6109, 0.1870, -0.8229, ..., 0.2682, 0.0185, 0.1607],
...,
[ 0.7141, -0.3613, -1.9945, ..., 0.4382, 0.0813, 0.2043],
[ 1.6572, 0.6914, -0.4856, ..., 0.0381, -0.3873, -0.0249],
[-0.1834, -0.2926, 0.0105, ..., -0.0664, 0.1361, -0.3783]]],
device='cuda:0')
new_output: tensor([[[-0.0288, -0.3114, 0.0557, ..., 0.0645, -0.1732, 0.0111],
[-0.4986, 0.1701, -1.0163, ..., 0.3744, 0.3668, 0.1815],
[ 0.3047, 0.8006, 0.1433, ..., 0.1129, 0.2242, 0.0319],
...,
[ 1.0819, -0.0811, -0.6493, ..., 0.4032, 0.0125, 0.2486],
[ 1.6623, -0.3283, 0.1893, ..., -0.0898, 0.0466, 0.0422],
[ 0.3081, -1.8928, 0.2130, ..., -0.0200, 0.1305, -0.0184]]],
device='cuda:0')
outputs.logits: tensor([[ -2.4135, -1.2998, -5.3751, ..., -8.8890, -7.1677, -5.2011],
[ 16.5356, 19.7391, 11.9018, ..., 12.2652, 9.1145, 8.4064],
[-18.6066, -15.6845, -23.4903, ..., -20.3913, -25.7245, -26.3299],
...,
[ 1.7284, -0.9057, -3.5867, ..., -6.9615, -10.8864, -9.5238],
[ 33.7493, 30.4189, 26.0778, ..., 24.3718, 20.9329, 20.1168],
[ -9.7913, -8.7155, -11.4811, ..., -17.2810, -22.1441, -17.0313]],
device='cuda:0')
loss: 29.470737374030648
m_id: transformer.h.0.attn.c_proj
module: Conv1D()
_input: (tensor([[[ 0.0118, 0.2083, -0.1307, ..., -0.0069, -0.1674, 0.0688],
[ 0.0144, 0.1823, -0.0851, ..., 0.0623, 0.1559, 0.1461],
[ 0.0451, 0.1538, 0.0384, ..., 0.0944, 0.1008, 0.1417],
...,
[-0.0694, -0.0134, -0.0214, ..., 0.0263, 0.0040, -0.0208],
[-0.1137, -0.0970, -0.1129, ..., 0.0070, -0.0079, -0.0138],
[-0.1048, 0.0253, -0.0466, ..., 0.0157, -0.0279, -0.0329]]],
device='cuda:0'),)
_output: tensor([[[-1.4456, -1.1025, -0.3038, ..., 0.1879, 0.0180, -0.0746],
[-1.5696, -0.5864, -0.9546, ..., 0.1019, 0.0580, -0.0956],
[-1.5708, -0.7926, 0.0503, ..., 0.1538, 0.0582, -0.1794],
...,
[-0.2901, -0.3021, 0.1445, ..., -0.0246, -0.0046, 0.1150],
[ 0.3652, -0.2418, 0.2801, ..., 0.0146, 0.0308, 0.0057],
[ 0.4580, -0.2454, -0.1958, ..., -0.0391, 0.0089, 0.0345]]],
device='cuda:0')
new_output: tensor([[[ 0.1565, -0.4952, -0.1143, ..., 0.0959, 0.0137, -0.0193],
[-0.3473, -0.4503, -0.4558, ..., 0.0347, 0.0646, -0.0503],
[-0.2465, -0.0605, -0.1782, ..., 0.1067, 0.0376, -0.1109],
...,
[-1.3325, -0.7436, 0.2225, ..., -0.0378, 0.0613, 0.1693],
[-1.3401, -0.0249, -0.6910, ..., 0.0268, 0.0343, 0.1071],
[ 0.1038, -0.0552, -0.5469, ..., -0.0292, -0.0282, 0.0712]]],
device='cuda:0')
outputs.logits: tensor([[ -2.4135, -1.2998, -5.3751, ..., -8.8890, -7.1677, -5.2011],
[ 16.5356, 19.7391, 11.9018, ..., 12.2652, 9.1145, 8.4064],
[-18.6066, -15.6845, -23.4903, ..., -20.3913, -25.7245, -26.3299],
...,
[ 1.7284, -0.9057, -3.5867, ..., -6.9615, -10.8864, -9.5238],
[ 33.7493, 30.4189, 26.0778, ..., 24.3718, 20.9329, 20.1168],
[ -9.7913, -8.7155, -11.4811, ..., -17.2810, -22.1441, -17.0313]],
device='cuda:0')
loss: 29.470737374030648
m_id: transformer.h.0.mlp.c_fc
module: Conv1D()
_input: (tensor([[[-0.1610, -0.2409, -0.0824, ..., 0.0338, 0.0478, -0.0263],
[-0.0771, -0.0692, -0.1262, ..., 0.3234, 0.3941, 0.0617],
[-0.1922, -0.1504, 0.0064, ..., 0.1874, 0.3033, -0.2792],
...,
[ 0.0357, -0.0466, 0.0804, ..., -0.3506, 0.0939, 0.2769],
[ 0.0412, -0.0604, 0.1028, ..., 0.4087, 0.1469, -0.1335],
[ 0.0797, -0.0057, -0.0017, ..., -0.2536, 0.3438, 0.6034]]],
device='cuda:0'),)
_output: tensor([[[-0.0429, -2.0128, -2.0130, ..., -3.1365, -2.5582, 0.6534],
[-0.0877, -1.1516, 0.5917, ..., -1.7849, -0.1427, -0.8034],
[ 0.3357, 0.1032, -1.3886, ..., -1.0677, 0.2035, -2.6555],
...,
[ 0.2661, -1.1748, 0.1337, ..., -0.4629, -0.3839, -0.7288],
[ 0.3894, -2.3801, -2.6291, ..., -5.2681, -1.3698, -0.7015],
[ 0.0736, -1.1691, -2.5465, ..., -2.3099, -0.4566, -0.6415]]],
device='cuda:0')
new_output: tensor([[[-0.1442, -2.0562, -1.6867, ..., -2.5164, -1.1803, 1.1230],
[ 0.4238, -1.2899, -1.1302, ..., -2.0902, 0.4380, -0.6506],
[ 0.2192, -0.6439, -0.2801, ..., -1.0144, -0.4218, -1.8172],
...,
[ 0.0694, -0.6179, -0.2199, ..., -0.7417, -0.6159, -0.7027],
[-0.2255, -2.1850, -0.6787, ..., -2.1977, -1.4968, -0.6493],
[ 0.1734, -2.2434, -1.1219, ..., -1.4834, -1.3147, 0.0156]]],
device='cuda:0')
outputs.logits: tensor([[ 0.7795, 2.0030, -1.6746, ..., -4.6599, -4.3385, -2.4356],
[ 24.9088, 29.4419, 21.8366, ..., 21.9064, 17.4732, 17.0667],
[-23.4861, -19.9130, -27.9467, ..., -24.4025, -30.8767, -30.1642],
...,
[ 21.3362, 19.2360, 16.6425, ..., 12.6393, 7.5003, 9.7640],
[ 47.4022, 44.4750, 39.5697, ..., 38.5065, 32.9435, 33.3205],
[ -4.2421, -2.4494, -5.3408, ..., -10.9985, -18.5314, -12.8872]],
device='cuda:0')
loss: 25.420941218783344
m_id: transformer.h.0.mlp.c_proj
module: Conv1D()
_input: (tensor([[[-0.0207, -0.0443, -0.0443, ..., -0.0023, -0.0130, 0.4856],
[-0.0408, -0.1439, 0.4277, ..., -0.0664, -0.0633, -0.1695],
[ 0.2120, 0.0559, -0.1148, ..., -0.1527, 0.1182, -0.0100],
...,
[ 0.1610, -0.1412, 0.0739, ..., -0.1489, -0.1346, -0.1699],
[ 0.2537, -0.0202, -0.0108, ..., -0.0000, -0.1172, -0.1695],
[ 0.0390, -0.1419, -0.0134, ..., -0.0238, -0.1479, -0.1672]]],
device='cuda:0'),)
_output: tensor([[[ 2.1918, 0.2757, -1.2079, ..., -1.6658, -0.3475, 3.0095],
[-0.0071, -1.0461, -3.7558, ..., -1.6386, 0.9958, -0.5041],
[ 1.1246, 0.2776, 0.0690, ..., -0.7909, -0.5560, 0.2900],
...,
[ 0.1987, -1.5820, 0.4737, ..., 0.0167, -2.0354, 0.1269],
[ 0.0065, -0.1991, -0.7786, ..., 0.3970, 1.0870, 3.0673],
[ 0.1992, 0.5976, 0.3043, ..., -0.8670, -0.4415, -0.7263]]],
device='cuda:0')
new_output: tensor([[[ 0.1005, 0.7678, 0.6264, ..., -1.4815, -0.3919, 1.2432],
[-1.0859, 0.3843, -2.6613, ..., 0.3616, 0.7369, -0.3413],
[-0.1934, 0.7930, -1.0414, ..., -0.8458, 0.0387, 0.3547],
...,
[ 0.3707, 0.1903, 0.9877, ..., 0.1598, -0.7087, 0.5787],
[-0.4105, 0.0394, -0.0291, ..., -0.1327, -0.5015, 2.8738],
[ 0.2976, 0.0607, -0.0353, ..., -0.3022, 0.1569, -0.0568]]],
device='cuda:0')
outputs.logits: tensor([[ 0.7795, 2.0030, -1.6746, ..., -4.6599, -4.3385, -2.4356],
[ 24.9088, 29.4419, 21.8366, ..., 21.9064, 17.4732, 17.0667],
[-23.4861, -19.9130, -27.9467, ..., -24.4025, -30.8767, -30.1642],
...,
[ 21.3362, 19.2360, 16.6425, ..., 12.6393, 7.5003, 9.7640],
[ 47.4022, 44.4750, 39.5697, ..., 38.5065, 32.9435, 33.3205],
[ -4.2421, -2.4494, -5.3408, ..., -10.9985, -18.5314, -12.8872]],
device='cuda:0')
loss: 25.420941218783344
Is there something I might have missed in my implementation?
Or is hooking simply not supposed to work on the inner modules of the Attn and MLP blocks?
If anyone knows, I would greatly appreciate an answer.
Thank you.