LightningModule을 통해 GCN 모델을 구축하였고 이를 fit하여 학습 시키는 중인데 학습중 0에포크 학습 후 1에포크로 넘어가질 않고 무한정 멈추거나 0에포크의 1 혹은 2 사이클에서 멈추는 현상이 지속됩니다. 문제는 같은 코드를 계속 똑같은 상황에서 실행할때마다 멈추는 구간이 다르며 가끔은 멈추지않고 진행되기도 합니다 (이러다가 다시 에포크 끝에서 멈출때가 있습니다). 이런 현상에 대해 겪은 적이 있으신 분들이 있으시다면 답변 부탁드립니다...
이 부분은 model fit 코드이며:
train_loader = DataLoader(train_dataset, batch_size = 64, collate_fn = collate_fn, shuffle = True, drop_last = True, num_workers = 3)
val_loader = DataLoader(val_dataset, batch_size = 64, collate_fn = collate_fn, shuffle = False, num_workers = 3)
model = GraphModel(emb_dim = 128, hidden_dim = 128, output_dim = [64, 32, 1], dropout_rate = 0.1)
trainer = pl.Trainer(
strategy = 'ddp',
logger = logger,
# callbacks = [early_stop_callback], # [LoggingCallback(file_name = version)],
enable_checkpointing = True,
enable_progress_bar = True,
accelerator = "auto",
devices = [1, 2],
max_epochs = 200,
precision = "16-mixed"
)
trainer.fit(model, train_dataloaders = train_loader, val_dataloaders = val_loader)
이 부분이 training과 validation step 코드입니다.
def training_step(self,
batch: TrainingBatch):
labels = batch.labels
outputs = self.forward(batch)
batch_size = outputs.size(0)
loss = self.criterion(outputs, labels)
self.log('loss_train', loss, prog_bar = True, on_epoch = True, batch_size = batch_size)
self.r2_score_train.update(outputs, labels)
self.mae_train.update(outputs, labels)
return loss
def on_train_epoch_end(self):
self.log('r2_score_train', self.r2_score_train.compute(), on_step = False, on_epoch = True, sync_dist=True)
self.log('MAE_train', self.mae_train.compute(), on_step = False, on_epoch = True, sync_dist=True)
self.r2_score_train.reset()
self.mae_train.reset()
def validation_step(self,
batch: TrainingBatch):
labels = batch.labels
outputs = self.forward(batch)
batch_size = outputs.size(0)
loss = self.criterion(outputs, labels)
self.log('loss_val', loss, prog_bar = True, on_epoch = True, batch_size = batch_size)
self.r2_score_val.update(outputs, labels)
self.mae_val.update(outputs, labels)
return loss
def on_validation_epoch_end(self):
self.log('r2_score_val', self.r2_score_val.compute(), on_step = False, on_epoch = True, sync_dist=True)
self.log('MAE_val', self.mae_val.compute(), on_step = False, on_epoch = True, sync_dist=True)
self.r2_score_val.reset()
self.mae_val.reset()
def configure_optimizers(self):
opt = torch.optim.Adam(self.parameters(), lr = self.hparams.lr)
return opt