Stateful LSTM?

파이토치 LSTM에서는 Batch가 진행됨에 따라, hidden state가 초기화가 되나요?

케라스로 치면, stateful=False일때가 그런 상황이라 알고 있습니다.

pytorch에서는 설정해주는 방법이 있을까요?

자문자답 합니다.
model forward시 , hidden state를 담지 않으면
default로 hidden state는 매번 초기화되네요.

3개의 좋아요

설명하신 내용이 맞지만 다른 분들을 위해서 조금 부가 설명을 하겠습니다.

예를 들어서 forward([[1,2,3]],hidden_state)에서 간략하게 표현하면 아래 같은 형태로 동작합니다.

for i in range(3):
  out[i],hidden_state = forward(input[i] ,hidden_state)
return out, hidden_state

위를 보시면 입력 시퀀스에서 각 입력 심볼마다 hidden_state가 계속 업데이트되는 stateful하게 동작합니다. 위 동작이 기본으로 설정되어 있기 때문에 위와 같은 루프 없이도 forward([[1,2,3]],hidden_state)로 동일한 결과를 얻을 수 있습니다.

또 다른 예시로 forward([[1,2,3]])의 경우 즉 초기 hidden_state를 주지 않은 경우에는 아래와 같이 동작합니다.

hidden_state=zeros()
for i in range(3):
  out[i],hidden_state = forward(input[i] ,hidden_state)
return out, hidden_state

별도의 설정이 없는 경우에는 hidden_state를 초기화해서 이전 예제와 동일하게 수행합니다.

기본 설정과 달리 stateful이 아닌 연산을 하고 싶다면 별도의 loop를 만들어서 진행해야합니다.
예를 들어서 [1,2,3]에 같은 hidden_state를 쓰고 싶다면 아래와 같이 할 수도 있습니다.

for i in range(3):
  out[i], new_hidden[i]= forward(input[i] ,hidden_state)
return out, new_hidden

또는


forward([[1],[2],[3]], hidden_state.repeat(3)..)

감사합니다

4개의 좋아요