How can I view a summary of a PyTorch model?

This is a translation of questions and answers frequently seen on the official website, StackOverflow, and other sources.

You can also find the original post at the following link.


Question

  • Is there a method that shows a summary of a model's structure, like model.summary() in Keras?
    Model Summary:
    ____________________________________________________________________________________________________
    Layer (type)                     Output Shape          Param #     Connected to                     
    ====================================================================================================
    input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
    ____________________________________________________________________________________________________
    convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
    ____________________________________________________________________________________________________
    maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
    ____________________________________________________________________________________________________
    flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
    ____________________________________________________________________________________________________
    dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
    ====================================================================================================
    Total params: 2,385
    Trainable params: 2,385
    Non-trainable params: 0
    

Answer

  • You can get a similar result with the torchinfo package (there is also the torch-summary package). To also see per-layer output shapes, see the note after the output below.
    # You can use it after installing the package with `pip install torchinfo`.
    >>> from torchvision import models
    >>> from torchinfo import summary
    >>> model = models.vgg16()
    >>> summary(model)
    =================================================================
    Layer (type:depth-idx)                   Param #
    =================================================================
    VGG                                      --
    ├─Sequential: 1-1                        --
    │    └─Conv2d: 2-1                       1,792
    │    └─ReLU: 2-2                         --
    │    └─Conv2d: 2-3                       36,928
    │    └─ReLU: 2-4                         --
    │    └─MaxPool2d: 2-5                    --
    │    └─Conv2d: 2-6                       73,856
    │    └─ReLU: 2-7                         --
    │    └─Conv2d: 2-8                       147,584
    │    └─ReLU: 2-9                         --
    │    └─MaxPool2d: 2-10                   --
    │    └─Conv2d: 2-11                      295,168
    │    └─ReLU: 2-12                        --
    │    └─Conv2d: 2-13                      590,080
    │    └─ReLU: 2-14                        --
    │    └─Conv2d: 2-15                      590,080
    │    └─ReLU: 2-16                        --
    │    └─MaxPool2d: 2-17                   --
    │    └─Conv2d: 2-18                      1,180,160
    │    └─ReLU: 2-19                        --
    │    └─Conv2d: 2-20                      2,359,808
    │    └─ReLU: 2-21                        --
    │    └─Conv2d: 2-22                      2,359,808
    │    └─ReLU: 2-23                        --
    │    └─MaxPool2d: 2-24                   --
    │    └─Conv2d: 2-25                      2,359,808
    │    └─ReLU: 2-26                        --
    │    └─Conv2d: 2-27                      2,359,808
    │    └─ReLU: 2-28                        --
    │    └─Conv2d: 2-29                      2,359,808
    │    └─ReLU: 2-30                        --
    │    └─MaxPool2d: 2-31                   --
    ├─AdaptiveAvgPool2d: 1-2                 --
    ├─Sequential: 1-3                        --
    │    └─Linear: 2-32                      102,764,544
    │    └─ReLU: 2-33                        --
    │    └─Dropout: 2-34                     --
    │    └─Linear: 2-35                      16,781,312
    │    └─ReLU: 2-36                        --
    │    └─Dropout: 2-37                     --
    │    └─Linear: 2-38                      4,097,000
    =================================================================
    Total params: 138,357,544
    Trainable params: 138,357,544
    Non-trainable params: 0
    =================================================================
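
    If you also pass an input size, torchinfo can additionally report each layer's output shape, similar to the Keras summary in the question. A minimal sketch continuing the session above; the (1, 3, 224, 224) shape is just an example input for VGG16, not something from the original answer:
    >>> # input_size describes one batch: (batch_size, channels, height, width).
    >>> summary(model, input_size=(1, 3, 224, 224))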
    
  • However, you can also get a fair amount of information just by printing the model, as shown below (a note on counting parameters follows the output).
    >>> from torchvision import models
    >>> model = models.vgg16()
    >>> print(model)
    VGG(
     (features): Sequential(
       (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (1): ReLU(inplace=True)
       (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (3): ReLU(inplace=True)
       (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
       (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (6): ReLU(inplace=True)
       (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (8): ReLU(inplace=True)
       (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
       (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (11): ReLU(inplace=True)
       (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (13): ReLU(inplace=True)
       (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (15): ReLU(inplace=True)
       (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
       (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (18): ReLU(inplace=True)
       (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (20): ReLU(inplace=True)
       (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (22): ReLU(inplace=True)
       (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
       (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (25): ReLU(inplace=True)
       (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (27): ReLU(inplace=True)
       (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (29): ReLU(inplace=True)
       (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     )
     (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
     (classifier): Sequential(
       (0): Linear(in_features=25088, out_features=4096, bias=True)
       (1): ReLU(inplace=True)
       (2): Dropout(p=0.5, inplace=False)
       (3): Linear(in_features=4096, out_features=4096, bias=True)
       (4): ReLU(inplace=True)
       (5): Dropout(p=0.5, inplace=False)
       (6): Linear(in_features=4096, out_features=1000, bias=True)
     )
    )
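
    Note that print(model) shows the layer configuration but not parameter counts or output shapes. If you only need the total number of parameters, you can compute it directly from model.parameters(); a small sketch using standard PyTorch calls (the totals match the torchinfo output above):
    >>> from torchvision import models
    >>> model = models.vgg16()
    >>> # numel() returns the number of elements in each parameter tensor.
    >>> sum(p.numel() for p in model.parameters())
    138357544
    >>> # Count only the parameters that will receive gradients (trainable).
    >>> sum(p.numel() for p in model.parameters() if p.requires_grad)
    138357544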
    

Learn more