Huggingface BertForSequenceClassification를 이용한 문장 분류

pretrained bert를 이용해서 문장 분류를 해보고 있는데, 학습이 잘 안 되네요
데이터는 torchtext에서 제공하는 SST2 데이터를 쓰고 있습니다.
Huggingface의 BertForSequenceClassification class를 써서 하고 있는데,
뭔가 output을 잘못된 걸 가져오는 게 아닌가 하는데 해결법을 모르겠습니다.
혹시 뭐가 잘못됐는지 보이시는 분 계신가요?

import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torchtext.datasets import SST2
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification

LR = 0.0005
EPOCHS = 5
BATCH_SIZE = 128

device = torch.device("cuda:" + "0" if torch.cuda.is_available() else "cpu")


model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = BertForSequenceClassification.from_pretrained(model_name)
max_input_length = 128


model = base_model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=LR)

train_datapipe = SST2(split="train")
valid_datapipe = SST2(split="dev")
# Transform the raw dataset using non-batched API (i.e apply transformation line by line)
def collate_batch(batch):
    ids, types, masks, label_list = [], [], [], []
    for text, label in batch:
        tokenized = tokenizer(text,
                              padding="max_length", max_length=max_input_length,
                              truncation=True, return_tensors="pt")
        ids.append(tokenized['input_ids'])
        types.append(tokenized['token_type_ids'])
        masks.append(tokenized['attention_mask'])
        label_list.append(label)

    input_data = {
        "input_ids": torch.squeeze(torch.stack(ids)).to(device),
        "token_type_ids": torch.squeeze(torch.stack(types)).to(device),
        "attention_mask": torch.squeeze(torch.stack(masks)).to(device)
    }
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return input_data, label_list


train_dataloader = DataLoader(train_datapipe, shuffle=True, batch_size=BATCH_SIZE, collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_datapipe, batch_size=BATCH_SIZE, collate_fn=collate_batch)

# print("total instances: ", len(train_dataloader))
for epoch in range(EPOCHS):
    model.train()
    train_loss = []
    all_labels = []
    all_outs = []

    for i, (input_data, label) in enumerate(tqdm(train_dataloader)):
        model.zero_grad()
        output = model(**input_data).logits
        label = label.to(device)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_loss.append(loss_value)

        label = label.cpu().numpy()
        all_labels.extend(label)
        predicted_labels = torch.argmax(output, dim=-1).cpu().numpy()
        all_outs.extend(predicted_labels)

    # calc metric
    train_loss = np.mean(train_loss)
    print(train_loss)
    print(precision_recall_fscore_support(all_labels, all_outs))
    print("-" * 30)

    model.eval()
    val_loss = []
    val_accuracy = []
    with torch.no_grad():
        for i, (input_data, label) in enumerate(tqdm(valid_dataloader)):
            output = model(**input_data).logits
            label = label.to(device)
            loss = criterion(output, label)

            loss_value = loss.item()
            val_loss.append(loss_value)

            predicted_labels = torch.argmax(output, dim=-1)
            accuracy = (predicted_labels == label).cpu().numpy().mean() * 100
            val_accuracy.append(accuracy)

        print(np.mean(val_loss))
        print("accuracy:", np.mean(val_accuracy))
        print("=" * 30)

    # save