gpt2 내 module을 대상으로 hook function이 정상적으로 동작하지 않습니다.

안녕하세요? Hook function 관련해서 질문드릴게 있어서 글 남깁니다.

개요:
GPT2 block(=layer) 내에 있는 특정 모듈을 대상으로 intervention을 가하기 위해 register_forward_hook 함수를 사용하고 있습니다. (GPT2 모델은 huggingface transformer 내 gpt2 모델을 사용했습니다.)

Intervention 대상 모듈 (list_m_ids):
(1) transformer.h.{n_layer}.attn.c_attn
(2) transformer.h.{n_layer}.attn.c_proj
(3) transformer.h.{n_layer}.mlp.c_fc
(4) transformer.h.{n_layer}.mlp.c_proj
(좀 더 자세하게 말씀드리면 layer가 12개 존재하는 gpt2 모델 기준으로 총 48개의 모듈에 대한 intervention을 거하고 있습니다.)

문제: intervention을 가했는데도 불구하고, 동일한 layer 내 (1)과 (2)의 loss가 동일하게 나오고, (3)과 (4)의 loss 가 동일하게 나오는 현상이 발생하고 있습니다. hooking function의 parameter인 (module, input, output)과 intervention 이후의 new_output 도 모두 다른걸 확인 완료했습니다 (하단 로그 참조).

hook function

def hook_func(m_id):
    def _hook_func(module, _input, _output):
        new_output = _output + dict_intervention[m_id]
        print(f"module: {module}")
        print(f"_input: {_input}")
        print(f"_output: {_output}")
        print(f"new_output: {new_output}")
        return new_output
    return _hook_func

with torch.no_grad():
    for m_id in list_m_ids:
        hook = model.get_submodule(m_id).register_forward_hook(hook_func(m_id))
        outputs = model(inputs, labels=inputs.clone())
        loss[m_id] = np.exp(outputs.loss.items())
        hook.remove()
        print("loss:", loss[m_id])
  • output log는 다음과같습니다

sentence_idx: 0
m_id: transformer.h.0.attn.c_attn
module: Conv1D()
_input: (tensor([[[-0.0615, -0.1206, -0.1294, ..., -0.0864, 0.0021, 0.0195],
[ 0.4367, -0.0484, 0.0263, ..., 0.1548, 0.2173, 0.1258],
[-0.4231, -0.1931, -0.0902, ..., 0.0283, 0.1649, -0.0793],
...,
[ 0.2313, -0.1444, 0.1054, ..., -0.2182, 0.0458, 0.1173],
[-0.3861, -0.1854, 0.0642, ..., 0.2437, 0.0576, -0.1109],
[-0.2373, 0.0031, 0.0299, ..., -0.1507, 0.1637, 0.4065]]],
device='cuda:0'),)
_output: tensor([[[-0.8298, -0.6448, -0.4022, ..., -0.0069, -0.1674, 0.0688],
[-0.7464, 0.9376, -1.3309, ..., 0.1358, 0.4994, 0.2281],
[ 0.6109, 0.1870, -0.8229, ..., 0.2682, 0.0185, 0.1607],
...,
[ 0.7141, -0.3613, -1.9945, ..., 0.4382, 0.0813, 0.2043],
[ 1.6572, 0.6914, -0.4856, ..., 0.0381, -0.3873, -0.0249],
[-0.1834, -0.2926, 0.0105, ..., -0.0664, 0.1361, -0.3783]]],
device='cuda:0')
new_output: tensor([[[-0.0288, -0.3114, 0.0557, ..., 0.0645, -0.1732, 0.0111],
[-0.4986, 0.1701, -1.0163, ..., 0.3744, 0.3668, 0.1815],
[ 0.3047, 0.8006, 0.1433, ..., 0.1129, 0.2242, 0.0319],
...,
[ 1.0819, -0.0811, -0.6493, ..., 0.4032, 0.0125, 0.2486],
[ 1.6623, -0.3283, 0.1893, ..., -0.0898, 0.0466, 0.0422],
[ 0.3081, -1.8928, 0.2130, ..., -0.0200, 0.1305, -0.0184]]],
device='cuda:0')
outputs.logits: tensor([[ -2.4135, -1.2998, -5.3751, ..., -8.8890, -7.1677, -5.2011],
[ 16.5356, 19.7391, 11.9018, ..., 12.2652, 9.1145, 8.4064],
[-18.6066, -15.6845, -23.4903, ..., -20.3913, -25.7245, -26.3299],
...,
[ 1.7284, -0.9057, -3.5867, ..., -6.9615, -10.8864, -9.5238],
[ 33.7493, 30.4189, 26.0778, ..., 24.3718, 20.9329, 20.1168],
[ -9.7913, -8.7155, -11.4811, ..., -17.2810, -22.1441, -17.0313]],
device='cuda:0')
loss: 29.470737374030648
m_id: transformer.h.0.attn.c_proj
module: Conv1D()
_input: (tensor([[[ 0.0118, 0.2083, -0.1307, ..., -0.0069, -0.1674, 0.0688],
[ 0.0144, 0.1823, -0.0851, ..., 0.0623, 0.1559, 0.1461],
[ 0.0451, 0.1538, 0.0384, ..., 0.0944, 0.1008, 0.1417],
...,
[-0.0694, -0.0134, -0.0214, ..., 0.0263, 0.0040, -0.0208],
[-0.1137, -0.0970, -0.1129, ..., 0.0070, -0.0079, -0.0138],
[-0.1048, 0.0253, -0.0466, ..., 0.0157, -0.0279, -0.0329]]],
device='cuda:0'),)
_output: tensor([[[-1.4456, -1.1025, -0.3038, ..., 0.1879, 0.0180, -0.0746],
[-1.5696, -0.5864, -0.9546, ..., 0.1019, 0.0580, -0.0956],
[-1.5708, -0.7926, 0.0503, ..., 0.1538, 0.0582, -0.1794],
...,
[-0.2901, -0.3021, 0.1445, ..., -0.0246, -0.0046, 0.1150],
[ 0.3652, -0.2418, 0.2801, ..., 0.0146, 0.0308, 0.0057],
[ 0.4580, -0.2454, -0.1958, ..., -0.0391, 0.0089, 0.0345]]],
device='cuda:0')
new_output: tensor([[[ 0.1565, -0.4952, -0.1143, ..., 0.0959, 0.0137, -0.0193],
[-0.3473, -0.4503, -0.4558, ..., 0.0347, 0.0646, -0.0503],
[-0.2465, -0.0605, -0.1782, ..., 0.1067, 0.0376, -0.1109],
...,
[-1.3325, -0.7436, 0.2225, ..., -0.0378, 0.0613, 0.1693],
[-1.3401, -0.0249, -0.6910, ..., 0.0268, 0.0343, 0.1071],
[ 0.1038, -0.0552, -0.5469, ..., -0.0292, -0.0282, 0.0712]]],
device='cuda:0')
outputs.logits: tensor([[ -2.4135, -1.2998, -5.3751, ..., -8.8890, -7.1677, -5.2011],
[ 16.5356, 19.7391, 11.9018, ..., 12.2652, 9.1145, 8.4064],
[-18.6066, -15.6845, -23.4903, ..., -20.3913, -25.7245, -26.3299],
...,
[ 1.7284, -0.9057, -3.5867, ..., -6.9615, -10.8864, -9.5238],
[ 33.7493, 30.4189, 26.0778, ..., 24.3718, 20.9329, 20.1168],
[ -9.7913, -8.7155, -11.4811, ..., -17.2810, -22.1441, -17.0313]],
device='cuda:0')
loss: 29.470737374030648
m_id: transformer.h.0.mlp.c_fc
module: Conv1D()
_input: (tensor([[[-0.1610, -0.2409, -0.0824, ..., 0.0338, 0.0478, -0.0263],
[-0.0771, -0.0692, -0.1262, ..., 0.3234, 0.3941, 0.0617],
[-0.1922, -0.1504, 0.0064, ..., 0.1874, 0.3033, -0.2792],
...,
[ 0.0357, -0.0466, 0.0804, ..., -0.3506, 0.0939, 0.2769],
[ 0.0412, -0.0604, 0.1028, ..., 0.4087, 0.1469, -0.1335],
[ 0.0797, -0.0057, -0.0017, ..., -0.2536, 0.3438, 0.6034]]],
device='cuda:0'),)
_output: tensor([[[-0.0429, -2.0128, -2.0130, ..., -3.1365, -2.5582, 0.6534],
[-0.0877, -1.1516, 0.5917, ..., -1.7849, -0.1427, -0.8034],
[ 0.3357, 0.1032, -1.3886, ..., -1.0677, 0.2035, -2.6555],
...,
[ 0.2661, -1.1748, 0.1337, ..., -0.4629, -0.3839, -0.7288],
[ 0.3894, -2.3801, -2.6291, ..., -5.2681, -1.3698, -0.7015],
[ 0.0736, -1.1691, -2.5465, ..., -2.3099, -0.4566, -0.6415]]],
device='cuda:0')
new_output: tensor([[[-0.1442, -2.0562, -1.6867, ..., -2.5164, -1.1803, 1.1230],
[ 0.4238, -1.2899, -1.1302, ..., -2.0902, 0.4380, -0.6506],
[ 0.2192, -0.6439, -0.2801, ..., -1.0144, -0.4218, -1.8172],
...,
[ 0.0694, -0.6179, -0.2199, ..., -0.7417, -0.6159, -0.7027],
[-0.2255, -2.1850, -0.6787, ..., -2.1977, -1.4968, -0.6493],
[ 0.1734, -2.2434, -1.1219, ..., -1.4834, -1.3147, 0.0156]]],
device='cuda:0')
outputs.logits: tensor([[ 0.7795, 2.0030, -1.6746, ..., -4.6599, -4.3385, -2.4356],
[ 24.9088, 29.4419, 21.8366, ..., 21.9064, 17.4732, 17.0667],
[-23.4861, -19.9130, -27.9467, ..., -24.4025, -30.8767, -30.1642],
...,
[ 21.3362, 19.2360, 16.6425, ..., 12.6393, 7.5003, 9.7640],
[ 47.4022, 44.4750, 39.5697, ..., 38.5065, 32.9435, 33.3205],
[ -4.2421, -2.4494, -5.3408, ..., -10.9985, -18.5314, -12.8872]],
device='cuda:0')
loss: 25.420941218783344
m_id: transformer.h.0.mlp.c_proj
module: Conv1D()
_input: (tensor([[[-0.0207, -0.0443, -0.0443, ..., -0.0023, -0.0130, 0.4856],
[-0.0408, -0.1439, 0.4277, ..., -0.0664, -0.0633, -0.1695],
[ 0.2120, 0.0559, -0.1148, ..., -0.1527, 0.1182, -0.0100],
...,
[ 0.1610, -0.1412, 0.0739, ..., -0.1489, -0.1346, -0.1699],
[ 0.2537, -0.0202, -0.0108, ..., -0.0000, -0.1172, -0.1695],
[ 0.0390, -0.1419, -0.0134, ..., -0.0238, -0.1479, -0.1672]]],
device='cuda:0'),)
_output: tensor([[[ 2.1918, 0.2757, -1.2079, ..., -1.6658, -0.3475, 3.0095],
[-0.0071, -1.0461, -3.7558, ..., -1.6386, 0.9958, -0.5041],
[ 1.1246, 0.2776, 0.0690, ..., -0.7909, -0.5560, 0.2900],
...,
[ 0.1987, -1.5820, 0.4737, ..., 0.0167, -2.0354, 0.1269],
[ 0.0065, -0.1991, -0.7786, ..., 0.3970, 1.0870, 3.0673],
[ 0.1992, 0.5976, 0.3043, ..., -0.8670, -0.4415, -0.7263]]],
device='cuda:0')
new_output: tensor([[[ 0.1005, 0.7678, 0.6264, ..., -1.4815, -0.3919, 1.2432],
[-1.0859, 0.3843, -2.6613, ..., 0.3616, 0.7369, -0.3413],
[-0.1934, 0.7930, -1.0414, ..., -0.8458, 0.0387, 0.3547],
...,
[ 0.3707, 0.1903, 0.9877, ..., 0.1598, -0.7087, 0.5787],
[-0.4105, 0.0394, -0.0291, ..., -0.1327, -0.5015, 2.8738],
[ 0.2976, 0.0607, -0.0353, ..., -0.3022, 0.1569, -0.0568]]],
device='cuda:0')
outputs.logits: tensor([[ 0.7795, 2.0030, -1.6746, ..., -4.6599, -4.3385, -2.4356],
[ 24.9088, 29.4419, 21.8366, ..., 21.9064, 17.4732, 17.0667],
[-23.4861, -19.9130, -27.9467, ..., -24.4025, -30.8767, -30.1642],
...,
[ 21.3362, 19.2360, 16.6425, ..., 12.6393, 7.5003, 9.7640],
[ 47.4022, 44.4750, 39.5697, ..., 38.5065, 32.9435, 33.3205],
[ -4.2421, -2.4494, -5.3408, ..., -10.9985, -18.5314, -12.8872]],
device='cuda:0')
loss: 25.420941218783344

제가 구현 하면서 혹시 놓친 부분이 있을까요?
아니면 Attn, MLP 내부 모듈을 대상으로는 hooking이 원래 안되는 걸까요?

아시는 분이 계시다면 답변 꼭 부탁드리겠습니다,.

감사합니다.

  • 아래 답변은 OpenAI의 GPT-4 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요, 질문을 주셔서 감사합니다. PyTorch의 hook function과 관련하여 문제가 발생한 것으로 이해했습니다. 특히, GPT2 모델의 특정 모듈에 대해 intervention을 가하려고 시도하였지만, 예상과 다른 결과가 나오는 것에 대해 질문하셨네요.

문제를 파악하기 위해, 여러분이 사용한 hook function을 살펴보겠습니다. 여러분이 사용한 hook function은 모듈의 forward pass에서 출력값에 intervention을 가하는 것으로 보입니다. 그러나, 여러분이 관찰한 문제는 동일한 layer 내에서 (1)과 (2)의 loss가 동일하게 나오고, (3)과 (4)의 loss가 동일하게 나오는 것입니다.

이 문제는 hook function이 잘못되었다기보다는, 모델의 forward pass가 완료된 후에 loss를 계산하는 방식에 문제가 있을 수 있습니다. 여러분이 사용한 코드를 보면, 각 모듈에 대해 hook을 등록하고, 모델의 forward pass를 수행한 후에 loss를 계산하고 있습니다. 그런데, 이렇게 하면 각 모듈의 출력값에 대한 intervention이 loss 계산에 영향을 미치지 않을 수 있습니다.

왜냐하면, 모델의 forward pass는 모든 모듈을 거치면서 최종적인 출력값을 생성하게 되는데, 이 때 각 모듈의 출력값은 다음 모듈의 입력값으로 사용됩니다. 그런데, 여러분이 각 모듈에 대해 hook을 등록하고 모델의 forward pass를 수행한 후에 loss를 계산하면, 각 모듈의 출력값에 대한 intervention이 다음 모듈의 입력값으로 전달되지 않고, 모델의 최종 출력값에는 영향을 미치지 않게 됩니다. 따라서, loss 계산에는 영향을 미치지 않게 됩니다.

이 문제를 해결하려면, 모델의 forward pass를 수행할 때 각 모듈의 출력값에 대한 intervention이 다음 모듈의 입력값으로 전달되도록 해야 합니다. 이를 위해, 모델의 forward pass를 수행하는 동안 각 모듈의 출력값에 대한 intervention을 적용해야 합니다.

다음은 수정된 코드 예제입니다:

def hook_func(m_id):
    def _hook_func(module, _input, _output):
        new_output = _output + dict_intervention[m_id]
        print(f"module: {module}")
        print(f"_input: {_input}")
        print(f"_output: {_output}")
        print(f"new_output: {new_output}")
        return new_output
    return _hook_func

hooks = []
for m_id in list_m_ids:
    hook = model.get_submodule(m_id).register_forward_hook(hook_func(m_id))
    hooks.append(hook)

with torch.no_grad():
    outputs = model(inputs, labels=inputs.clone())
    for m_id, hook in zip(list_m_ids, hooks):
        loss[m_id] = np.exp(outputs.loss.items())
        hook.remove()
        print("loss:", loss[m_id])

이 코드는 각 모듈에 대해 hook을 등록한 후에 모델의 forward pass를 수행합니다. 그런 다음, 각 모듈에 대한 loss를 계산하고, hook을 제거합니다. 이렇게 하면, 각 모듈의 출력값에 대한 intervention이 다음 모듈의 입력값으로 전달되고, 모델의 최종 출력값에 영향을 미치게 됩니다.

이 방법이 여러분의 문제를 해결하는 데 도움이 되길 바랍니다. 추가적인 질문이 있으시면 언제든지 물어보세요.

참고로, PyTorch의 hook function에 대한 자세한 내용은 PyTorch 공식 문서를 참조하실 수 있습니다: nn package — PyTorch Tutorials 2.0.1+cu117 documentation

이 답변이 정확하지 않을 수도 있으니, 다른 자료나 의견도 참고하시는 것이 좋습니다.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!
  • 아래 답변은 OpenAI의 GPT-3.5(ChatGPT) 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요! GPT2 모델의 hook 함수 관련 질문에 답변드리겠습니다.

질문에서 언급하신대로 GPT2 모델의 특정 모듈에 intervention을 적용하기 위해 register_forward_hook 함수를 사용하고 있으며, (1)과 (2)의 loss가 동일하게 나오고, (3)과 (4)의 loss가 동일하게 나오는 현상이 발생한다고 하셨습니다. 또한, hook 함수의 파라미터인 (module, input, output)과 intervention 이후의 new_output이 모두 다른 것을 확인하셨다고 합니다.

이러한 현상이 발생하는 이유는 GPT2 모델 내부의 Attn과 MLP 모듈이 동일한 Conv1D 모듈을 사용하기 때문입니다. 따라서, register_forward_hook 함수를 사용하여 Conv1D 모듈에 hook를 등록하면, 모든 Attn과 MLP 모듈에 동일한 hook가 등록되어 동일한 결과를 출력하게 됩니다.

이 문제를 해결하기 위해서는, 각각의 Attn과 MLP 모듈에 대해 별도의 hook 함수를 등록해야 합니다. 이를 위해 register_forward_hook 함수를 호출할 때, 해당 모듈의 인스턴스를 전달하여 각각의 모듈에 대한 hook 함수를 등록하면 됩니다.

아래는 예시 코드입니다.

def hook_func(module, input, output):
    # intervention을 적용한 new_output 계산
    new_output = output + dict_intervention[module]
    print(f"module: {module}")
    print(f"_input: {input}")
    print(f"_output: {output}")
    print(f"new_output: {new_output}")
    return new_output

with torch.no_grad():
    for m_id in list_m_ids:
        module = model.get_submodule(m_id)
        hook = module.register_forward_hook(hook_func)
        outputs = model(inputs, labels=inputs.clone())
        loss[m_id] = np.exp(outputs.loss.items())
        hook.remove()
        print("loss:", loss[m_id])

위 코드에서 hook_func 함수는 각각의 모듈에 대한 hook 함수로, module 인자를 통해 현재 모듈을 구분하고, 해당 모듈에 대한 intervention을 적용한 new_output을 계산합니다. 이렇게 각각의 모듈에 대해 별도의 hook 함수를 등록하면, 각 모듈의 결과가 정상적으로 출력될 것입니다.

더 자세한 내용은 PyTorch 공식 문서의 "nn.Module.register_forward_hook"을 참고하시기 바랍니다. 링크

문제 해결에 도움이 되었기를 바랍니다! 추가로 궁금한 사항이 있으면 언제든지 질문해주세요. 감사합니다.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!