학습에서 행렬추가시

안녕하세요.
STL10을 사용해서 간단하게 학습중인 모델이 있는데요,
이 모델의 텐서값을 확인해보고 싶어서 중간에 몇개 값을 뽑는 for문을 추가했습니다.
근데 그 텐서 값을 확인하는 for문을 추가했더니 학습시간이 엄청나게 느려지네요;;
원래 1에폭에 30초 걸리던게 27분이 걸립니다.
혹시 이유를 아시는 분 계실까요?

텐서는 similarity_matrix이며 torch.matmul로 생성해주었고 사진은 추가한 코드부분입니다.

스크린샷, 2022-08-02 13-56-59

안녕하세요, @utjune 님!

전체 코드를 못 봐서 자세히는 모르겠는데, 먼저 두 가지 경우를 한 번 확인해보시는건 어떨까 싶습니다.

첫번째는 (아마도) 매 batch마다 len(similarity_matrix) ** 2 만큼 탐색(traverse)하게 되어 기본적으로 len(similarity_matrix)total_rows / batch_size에 비례해서 연산 복잡도가 증가할 것 같습니다. batch_size를 키워보시거나 len(similarity_matrix)을 줄여보시면서 연산 시간이 얼마나 변하는지 측정을 해보시면 어떠실까 싶습니다.

두번째는 혹시나 similarity_matrix가 모델 내부에 선언되어 있고, requires_grad가 켜져있다면 문제가 있을 수 있지 않을까 싶습니다. 한 번 위 코드 블럭을 with torch.no_grad():로 감싸보시고 테스트해보시면 어떠실까 싶습니다.

혹시 위 2가지 모두 아니라면 코드가 조금 더 필요할 것 같습니다. :smiley:

1개의 좋아요

감사합니다 한번 해보겠습니다

1개의 좋아요

안녕하세요.
다름이 아니라 조언해주신대로 코딩을 해봤습니다.
첫번째 방식은 제가 batch size나 len(similarity_matrix)를 조절할 수 없어서 사용하지 못했습니다.(연구상의 이유)
두번째 방식대로 했더니 조금 빨라지긴 했으나 아무래도 for문이 2개나 들어가있다보니 for문이 없었을때보다는
현저히 느리더라구요.

현재 제가 연구하고 있는 문제라 전체 코드를 보여드릴 수는 없지만 간단하게 설명드리면
similarity_matrix는 feature끼리 matmul해준 텐서입니다.
그중 한 행에 있는 어떤 인덱스의 값을 선택하여 그 행에 포함된 모든 값들과 비교 연산을 하는 과정입니다.
아무래도 비교연산을 하는 과정에 for문이 포함되어있고 연산이 cpu에서 진행되어 현저히 느려지는거라고 추측하고있습니다.(물론 아닐수도 있습니다)

혹시 이 연산을 텐서에서 바로 할 수 있는 방법이 있을까요?
며칠동안 구글링을 해도 텐서에서 바로 저 for문처럼 값을 비교하는 방식을 찾지 못해서 다시 질문드립니다.

다시 한번 조언감사합니다.

2개의 좋아요

네, 대략 추측하기로는 2D tensor를 돌면서 cosine similarity를 비교하는 부분은 tensor.where() 함수 같은 것으로 치환하실 수 있지 않을까 싶습니다.

tensor.where() 함수는 numpy.where() 함수와 동일하게 tensor.where(조건, True 시 값, False 시 값) 입니다.

tensor.where() 함수의 자세한 사용법은 아래 문서 링크를 참고해주세요.

https://pytorch.org/docs/stable/generated/torch.where.html

간단히 로컬에서 아래와 같이 시간 비교를 해봤는데요,
[1000, 1000] 크기의 행렬에서 torch.where()는 5.19ms, for-loop x2는 18.9s 정도 걸렸습니다.

image

일전에 말씀드렸던 것처럼 행렬의 크기가 커지거나 batch_size가 커지면 더 많은 시간이 걸릴 것 같습니다.

2개의 좋아요

마지막 댓글로부터 24 시간이 지나 글타래가 자동으로 닫혔습니다. 새로운 댓글을 다실 수 없습니다.

앗, 설정이 잘못되었는지 글이 닫혔네요ㅠ
@utjune 님, 혹시 잘 되시는지요?

아 넵!
알려주신 where메소드 써서 행렬 횡단하지않고 텐서만 조작해서 값을 구하니
전처럼 다시 속도가 빨라졌습니다!

감사합니다 !!

1개의 좋아요

도움이 되셨다니 다행이네요! :tada:

2개의 좋아요