Hello, I'm trying to fine-tune wav2vec2.
I picked kresnik/wav2vec2-large-xlsr-korean on Hugging Face.
I only want to fine-tune the final lm_head, and I extended the vocab so the head went from 1405 to roughly 2100 entries.
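The head resize itself looks roughly like this (a sketch of what I did; I swapped lm_head for a wider nn.Linear and copied the pretrained rows for the original 1405 tokens, exact details may differ):

import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("kresnik/wav2vec2-large-xlsr-korean")

new_vocab_size = 2100
old_head = model.lm_head                       # Linear(in_features, 1405)
new_head = nn.Linear(old_head.in_features, new_vocab_size)
with torch.no_grad():
    # keep the pretrained weights for the original 1405 tokens
    new_head.weight[: old_head.out_features] = old_head.weight
    new_head.bias[: old_head.out_features] = old_head.bias
model.lm_head = new_head
model.config.vocab_size = new_vocab_size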
Everything except lm_head is set to requires_grad = False. Since the model computes the loss directly when I pass in labels and attention_mask, I just call loss.backward() on that.
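Concretely, the freezing step is roughly:

# freeze everything except the new head
for name, param in model.named_parameters():
    if not name.startswith("lm_head"):
        param.requires_grad = False

# sanity check: only lm_head.weight / lm_head.bias should print
print([n for n, p in model.named_parameters() if p.requires_grad])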
But I'm running into a CUDA out of memory problem.
The training audio clips are about 0~30 seconds long, and my GPU has 24 GB of memory.
batch_size = 1 works fine, but anything from 2 upward keeps raising CUDA out of memory.
for epoch in tqdm(range(epochs)):
    model.train()
    loss_value = 0
    valid_loss_value = 0
    total_batch_size = 0
    valid_total_batch_size = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        audio, text_label, attn_mask = batch
        batch_size = audio.shape[0]
        # forward pass; passing labels makes the model return the CTC loss directly
        out = model(audio.to(device), labels=text_label.to(device), attention_mask=attn_mask.to(device))
        # greedy decode for a running WER estimate
        prediction_ids = torch.argmax(out.logits, dim=-1)
        pred = processor.batch_decode(prediction_ids)
        true_y = processor.batch_decode(text_label)
        wer_prob = wer(pred, true_y).item()
        wer_record.append(wer_prob)
        loss = out.loss
        loss.backward()
        optimizer.step()
        # accumulate loss weighted by batch size for the epoch average
        loss_value += loss.item() * batch_size
        total_batch_size += batch_size
That's my training code.
I can't tell whether the problem is that my audio inputs are long, or what I should do about it.
With batch_size = 8, if I just take one batch with next(iter(train_dataloader)) and simply put it through the model, only about 10 GB of memory is used, but when I run the code above, even batch_size = 1 nearly fills the whole 24 GB.
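For reference, the comparison I made looks roughly like this (a sketch; peak_gb is just a helper I'm writing here for illustration, and the two measurements were of course taken with different batch sizes):

import torch

def peak_gb(step):
    """Reset the peak-memory counter, run step(), and return the peak in GB."""
    torch.cuda.reset_peak_memory_stats()
    step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3

audio, text_label, attn_mask = next(iter(train_dataloader))

# "simply putting the batch through the model": about 10 GB at batch_size = 8
print(peak_gb(lambda: model(audio.to(device), attention_mask=attn_mask.to(device))))

# a full training-style step keeps the autograd graph alive until backward();
# this is what nearly fills 24 GB even at batch_size = 1
def train_step():
    out = model(audio.to(device), labels=text_label.to(device),
                attention_mask=attn_mask.to(device))
    out.loss.backward()

print(peak_gb(train_step))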
The traceback (trimmed here; the repeated torch/nn/modules/module.py:1501 Module._call_impl frames are omitted) ends inside the convolutional feature extractor:

--->  out = model(audio.to(device), labels=text_label.to(device), attention_mask=attn_mask.to(device))
      # allocated_memory = torch.cuda.memory_allocated()
      # print(f"Allocated CUDA memory at model : {allocated_memory / 1024**3:.2f} GB")
      # max_allocated_memory = torch.cuda.max_memory_allocated()
      # print(f"Max allocated CUDA memory: {max_allocated_memory / 1024**3:.2f} GB")
      prediction_ids = torch.argmax(out.logits, dim=-1)

File ~/.local/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1716, in Wav2Vec2ForCTC.forward
-> 1716     outputs = self.wav2vec2(
   1717         input_values,
   1718         attention_mask=attention_mask,
   1719         output_attentions=output_attentions,
   1720         output_hidden_states=output_hidden_states,
   1721         return_dict=return_dict,
   1722     )

File ~/.local/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1347, in Wav2Vec2Model.forward
-> 1347     extract_features = self.feature_extractor(input_values)

File ~/.local/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:515, in Wav2Vec2FeatureEncoder.forward
-->  515     hidden_states = conv_layer(hidden_states)

File ~/.local/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:387, in Wav2Vec2LayerNormConvLayer.forward
-->  387     hidden_states = self.conv(hidden_states)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:313, in Conv1d.forward
-->  313     return self._conv_forward(input, self.weight, self.bias)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:309, in Conv1d._conv_forward
-->  309     return F.conv1d(input, weight, bias, self.stride,
     310                     self.padding, self.dilation, self.groups)

So the out of memory error is raised inside F.conv1d, in the convolutional feature extractor.