AI Model 찍먹해보기: torch hub 소개

별도 package 설치 없이 pytorch 만으로 간단하게 찍먹해볼 수 있는 방법인 torch hub를 소개드릴까 합니다.

Pytorch Hub란

공식 문서에서는 연구결과 재현을 쉽게 하기 위해 사전 훈련된 모델 저장소로 소개하고있습니다.

아무나 github 저장소를 통해 모델을 쉽게 배포할 수 있고, 몇 줄의 코드로 쉽게 가져다 쓸 수 있습니다.

공식 릴리즈 기준으로 v1.0.0에 추가되었습니다.

Torch hub 구조

Torch Hub 는 깃허브 저장소(혹은 로컬 디렉터리) 단위로 배포할 수 있고,
사용자는 torch.hub.load(repo,model_entry_point) 형태로 불러와서 사용할 수 있게 되어있습니다.

Torch hub 돌려보기

개요

시험삼아 torch vision 에 있는 고전적인 vision classification 모델인 resnet 을 실행해보도록 하겠습니다.

실행 대상 모델 vision/hubconf.py에 정의되어있으며 PyTorch에서 배포하는 모델입니다.

test run은 아래 환경에서 진행하였으며 실습 코드는 GitHub - zhoonit/pytorch-hub-practice 에 올려두었습니다.

  1. apple silicon (macbook m1 air ’13)
  2. python 3.8.9

환경 설정

  1. 프로젝트 생성 및 환경 설정
$ mkdir -p torch-hub-resnet
$ python3 -m venv venv # 가상환경 설정
$ source venv/bin/activate # 가상환경 실행
$ python3 -m pip install --upgrade pip 
$ python3 -m pip install torch torchvision
$ curl -OL https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt # 레이블 매핑 준비
$ curl -OL https://github.com/pytorch/hub/raw/master/images/dog.jpg  # 샘플 이미지 준비
  1. main.py 파일 생성
# 모델 준비하기
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()

# 이미지 열고 전처리
from PIL import Image
from torchvision import transforms
input_image = Image.open('dog.jpg')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# 가속 환경 설정
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# mps 는 7월 17일 현재 사용시 오류 발생
# elif torch.backends.mps.is_available():
#    device = 'mps'

input_batch.to(device)
model.to(device)

with torch.no_grad():
    output = model(input_batch)
# label 1000 개에 대한 confidence score
# print(output[0])
# softmax를 취해 확률 값으로 변환
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# print(probabilities)

# 사람이 읽을 수 있는 이미지 레이블로 변환
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

결과

코드를 실행하면 앞서 샘플로 받아두었던 이미지를 대상으로 추론을 합니다.

$ python3 ./main.py
pytorch_vision_v0.10.0
Samoyed 0.8846230506896973
Arctic fox 0.04580485075712204
white wolf 0.04427614063024521
Pomeranian 0.0056213438510894775
Great Pyrenees 0.004651993978768587

해당 이미지는 사모예드 견일 확률이 88%로 잘 분류가 되고있음을 확인할 수 있습니다 :slight_smile:

4개의 좋아요