안녕하세요. Libtorch로 MaskRCNN을 공부하고 있는 사람입니다.
Libtorch버전을 v1.0.1에서 v1.13.1으로 변경하여 COCO2017을 학습하려합니다만,
이전버젼 v1.0.1에서는 num_batches_tracked에 관련 에러가 없었는데
최선버젼 v1.13.1에서는 아래와 같이 에러가 발생합니다.
학습과정에서 중간 val.key를 확인 해보니 데이터셋 학습 후 저장할 때는 num_batches_tracked의 정보가 없는데, Inference하기 위한 Load 시 에러가 발생합니다.
BatchNorm2d에서 track_running_stats 값이 true가 Default이기에 학습시 running_mean,running_var,num_batches_tracked가 모두 생성되는 줄 알았는데, num_batches_tracked 만 생성이 안되었습니다.
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
학습시 num_batches_tracked을 생성하거나, Load시 num_batches_tracked을 무시하는 방법을 알고 계신분의 귀중한 답변부탁드립니다.
--- 프로그램 환경 ---
- OS : Linux Ubuntu 22.04
- Libtorch : 1.13.1
- Cuda : 11.7
--- Source 부분 ---
void LoadStateDict(torch::nn::Module& module,
const std::string& file_name,
const std::string& ignore_name_regex) {
torch::serialize::InputArchive archive;
archive.load_from(file_name);
torch::NoGradGuard no_grad;
std::regex re(ignore_name_regex);
std::smatch m;
auto params = module.named_parameters(true /recurse/);
auto buffers = module.named_buffers(true /recurse/);
...
--- 에러 메세지 ---
val.key : fpn.c1.1.running_mean
val.key : fpn.c1.1.running_var
val.key : fpn.c1.1.num_batches_tracked
No such serialized tensor 'fpn.C1.1.num_btaches_tracked'