'num_batches_tracked' Error 문의

안녕하세요. Libtorch로 MaskRCNN을 공부하고 있는 사람입니다.
Libtorch버전을 v1.0.1에서 v1.13.1으로 변경하여 COCO2017을 학습하려합니다만,
이전버젼 v1.0.1에서는 num_batches_tracked에 관련 에러가 없었는데
최선버젼 v1.13.1에서는 아래와 같이 에러가 발생합니다.

학습과정에서 중간 val.key를 확인 해보니 데이터셋 학습 후 저장할 때는 num_batches_tracked의 정보가 없는데, Inference하기 위한 Load 시 에러가 발생합니다.

BatchNorm2d에서 track_running_stats 값이 true가 Default이기에 학습시 running_mean,running_var,num_batches_tracked가 모두 생성되는 줄 알았는데, num_batches_tracked 만 생성이 안되었습니다.
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

학습시 num_batches_tracked을 생성하거나, Load시 num_batches_tracked을 무시하는 방법을 알고 계신분의 귀중한 답변부탁드립니다.

--- 프로그램 환경 ---

  1. OS : Linux Ubuntu 22.04
  2. Libtorch : 1.13.1
  3. Cuda : 11.7

--- Source 부분 ---
void LoadStateDict(torch::nn::Module& module,
const std::string& file_name,
const std::string& ignore_name_regex) {

torch::serialize::InputArchive archive;
archive.load_from(file_name);
torch::NoGradGuard no_grad;
std::regex re(ignore_name_regex);
std::smatch m;
auto params = module.named_parameters(true /recurse/);
auto buffers = module.named_buffers(true /recurse/);
...

--- 에러 메세지 ---
val.key : fpn.c1.1.running_mean
val.key : fpn.c1.1.running_var
val.key : fpn.c1.1.num_batches_tracked
No such serialized tensor 'fpn.C1.1.num_btaches_tracked'

1개의 좋아요

얘기하신대로 학습시 default로 track_running_stats=True 여서 있어야 할 것 같은데 이상하네요.

모델 load시 레이어 이름에 num_batches_tracked 가 포함되는 레이어는 무시되도록 수정을 하셔서 읽어오게 하셔야 할 것 같습니다.

1개의 좋아요

답변 주심 감사합니다. 혹시나 계산중에 'num_batches_tracked'을 참조하거나 사용하지는 않겠죠?


재분석해보니 Null로 저장되어 검색이 안되었습니다.
Source 대로 true 이면 자동생성이 됩니다.
검토해주심 감사드립니다.

1개의 좋아요

추론 중에는 사용되지 않을것 같네요. 감사합니다.