Segment Anything 모델 미세조정하기 (How To Fine-Tune Segment Anything)




Segment Anything 모델 미세조정하기 / How To Fine-Tune Segment Anything

Alexandre Bonnet, Published April 13, 2023•Edited April 21, 2023•10 min read


컴퓨터 비전은 지난 주 메타(Meta)의 세그먼트 애니씽 모델(SAM)의 출시와 함께 ChatGPT의 순간을 맞이하고 있습니다. 110억 개 이상의 세그먼테이션 마스크를 학습한 SAM은 생성형 AI가 아닌 예측형 AI 사용 사례를 위한 기반 모델입니다. 광범위한 이미지 양식과 문제 공간을 세분화하는 능력에서 놀라운 유연성을 보여줬지만, '미세 조정' 기능 없이 출시되었습니다.

Computer vision is having its ChatGPT moment with the release of the Segment Anything Model (SAM) by Meta last week. Trained over 11 billion segmentation masks, SAM is a foundation model for predictive AI use cases rather than generative AI. While it has shown an incredible amount of flexibility in its ability to segment over wide-ranging image modalities and problem spaces, it was released without “fine-tuning” functionality.

이 튜토리얼에서는 마스크 디코더(mask decoder)를 사용하여 SAM을 미세 조정하는 몇 가지 주요 단계를 설명하며, 특히 미세 조정에 적합한 상태가 되도록 데이터를 사전/사후 처리하는 데 SAM의 어떤 기능을 사용해야 하는지에 대해 설명합니다.

This tutorial will outline some of the key steps to fine-tune SAM using the mask decoder, particularly describing which functions from SAM to use to pre/post process the data so that it's in a good shape for fine tuning.

업데이트: 독자들의 요청에 따라 SAM을 미세 조정하는 데 필요한 모든 코드가 포함된 전체 Colab Notebook을 포함했습니다. 링크는 아래에서 확인할 수 있습니다 :point_down:

Update: By popular demand - we've included a full Colab Notebook with all the code you need to fine-tune SAM. The link can be found reading on :point_down:


Segment Anything 모델(SAM)이란 무엇인가요? / What is the Segment Anything Model (SAM)?


Segment Anything 모델(SAM)은 메타 AI에서 개발한 세분화 모델입니다. 컴퓨터 비전을 위한 최초의 파운데이션 모델(foundation model)로 간주됩니다. SAM은 수백만 개의 이미지와 수십억 개의 마스크가 포함된 방대한 데이터 코퍼스를 학습하여 매우 강력합니다. 이름에서 알 수 있듯이 SAM은 다양한 이미지에 대해 정확한 분할 마스크를 생성할 수 있습니다. SAM은 사람의 프롬프트를 고려할 수 있도록 설계되어, 사람이 개입한 어노테이션(Human in the Loop annotation) 작업 시 특히 강력하게 동작합니다. 이러한 프롬프트는 분할할 영역의 점, 분할할 오브젝트 주변의 경계 상자 또는 분할해야 할 내용에 대한 텍스트 프롬프트 등 다양한 모드로 표시될 수 있습니다. > The Segment Anything Model (SAM) is a segmentation model developed by Meta AI. It is considered the first foundational model for Computer Vision. SAM was trained on a huge corpus of data containing millions of images and billions of masks, making it extremely powerful. As its name suggests, SAM is able to produce accurate segmentation masks for a wide variety of images. Sam’s design allows it to take human prompts into account, making it particularly powerful for Human In The Loop annotation. These prompts can be multi-modal: they can be points on the area to be segmented, a bounding box around the object to be segmented or a text prompt about what should be segmented.

이 모델은 이미지 인코더, 프롬프트 인코더, 마스크 디코더의 세 가지 구성 요소로 구성되어 있습니다.

The model is structured into 3 components: an image encoder, a prompt encoder and a mask decoder.

세그먼트 애니씽(SA; Segment Anything (SA)) 모델의 기초 모델 아키텍처를 표시하는 이미지

[출처/Source]

이미지 인코더는 세그먼트되는 이미지에 대한 임베딩을 생성하고, 프롬프트 인코더는 프롬프트에 대한 임베딩을 생성합니다. 이미지 인코더는 모델에서 특히 큰 부분을 차지하는 구성 요소입니다. 이는 임베딩을 기반으로 세그멘테이션 마스크를 예측하는 경량 마스크 디코더와 대조되는 부분입니다. 메타 AI는 10억 개의 세그먼트 마스크(SA-1B) 데이터 세트에서 학습된 모델의 가중치와 편향을 모델 체크포인트로 사용할 수 있도록 했습니다. Segment Anything의 작동 방식에 대한 자세한 내용은 설명 블로그 게시물 여기를 참고하세요.

The image encoder generates an embedding for the image being segmented, whilst the prompt encoder generates an embedding for the prompts. The image encoder is a particularly large component in the model. This is in contrast to the lightweight mask decoder, which predicts segmentation masks based on the embeddings. Meta AI has made the weights and biases of the model trained on the Segment Anything 1 Billion Mask (SA-1B) dataset available as a model checkpoint. Learn more about how Segment Anything works in our explainer blog post here.


모델 미세 조정이란 무엇인가요? / What is Model Fine-tuning?


공개적으로 사용 가능한 최신 모델들(State-of-the-Art)은 각자의 아키텍처를 가지고 있으며 일반적으로 사전 학습된 모델 가중치와 함께 제공됩니다. 이러한 아키텍처가 가중치 없이 제공된다면 사용자가 모델을 처음부터 다시 학습시켜야 하며, 최신 성능을 얻기 위해 방대한 데이터 세트를 사용해야 합니다. > Publicly available state of the art models have a custom architecture and are typically supplied with pre-trained model weights. If these architectures were supplied without weights then the models would need to be trained from scratch by the users, who would need to use massive datasets to obtain state of the art performance.

모델 미세 조정은 사전 학습된 모델(아키텍처+가중치)을 가져와 특정 사용 사례에 대한 데이터를 학습시키는 프로세스입니다. 이러한 데이터는 일반적으로 모델이 이전에 보지 못했거나 원래 학습 데이터 세트에서 제대로 표현되지 않은 데이터입니다.

Model fine tuning is the process of taking a pre-trained model (architecture+weights) and showing it data for a particular use case. This will typically be data that the model hasn’t seen before, or that is underrepresented in its original training dataset.

모델을 미세 조정하는 것과 처음부터 시작하는 것의 차이점은 가중치(weight)와 편향(bias)의 시작 값입니다. 처음부터 학습하는 경우, 이러한 값은 특정 전략에 따라 무작위로 초기화됩니다. 이러한 시작 구성에서는 모델이 당면한 작업에 대해 '아무것도 모르기 때문에' 성능이 저하될 수 있습니다. 기존의 가중치와 편향을 시작점으로 사용하면 가중치와 편향을 '미세 조정'하여 모델이 사용자 지정 데이터 세트에서 더 잘 작동하도록 할 수 있습니다. 예를 들어, 고양이를 인식하기 위해 학습한 정보(가장자리 감지, 발 개수 세기)는 개를 인식하는 데에도 유용할 수 있습니다.

The difference between fine tuning the model and starting from scratch is the starting value of the weights and biases. If we were training from scratch, these would be randomly initialised according to some strategy. In such a starting configuration, the model would ‘know nothing’ of the task at hand and perform poorly. By using pre existing weights and biases as a starting point we can ‘fine tune’ the weights and biases so that our model works better on our custom dataset. For example: the information learnt to recognise cats (edge detection, counting paws) will be useful for recognising dogs.


모델을 미세 조정하는 이유는 무엇인가요? / Why Would I Fine-tune a Model?


모델을 미세 조정하는 목적은 사전 학습된 모델이 이전에 보지 못했던 데이터에서 더 높은 성능을 얻기 위한 것입니다. 예를 들어, 휴대폰 카메라에서 수집한 방대한 데이터로 학습된 이미지 분할 모델은 대부분 수평 관점에서 이미지를 보았을 것입니다. > The purpose of fine tuning a model is to obtain higher performance on data which the pre-trained model has not seen before. For example, an image segmentation model trained on a broad corpus of data gathered from phone cameras will have mostly seen images from a horizontal perspective.

이 모델을 수직 관점에서 촬영한 위성 이미지에 사용하려고 하면 성능이 좋지 않을 수 있습니다. 지붕(rooftop)을 세그먼트(segment)하려고 할 때 이 모델을 사용하면 최상의 결과를 얻지 못할 수도 있습니다. 사전 학습은 모델이 일반적으로 물체를 세그먼트하는 방법을 학습했기 때문에 유용하며, 이 시작점을 활용하여 지붕을 정확하게 세분화할 수 있는 모델을 구축하고자 합니다. 또한 사용자 지정 데이터 세트에는 수백만 개의 예가 없을 가능성이 높으므로 모델을 처음부터 학습시키는 대신 미세 조정을 하고자 합니다.

If we tried to use this model for satellite imagery taken from a vertical perspective, it may not perform as well. If we were trying to segment rooftops, the model may not yield the best results. The pre-training is useful because the model will have learnt how to segment objects in general, so we want to take advantage of this starting point to build a model which can accurately segment rooftops. Furthermore, it is likely that our custom dataset would not have millions of examples, so we want to fine tune instead of training the model from scratch.

모델을 처음부터 학습시키는 연산 비용을 들이지 않고도 특정 사용 사례에서 더 나은 성능을 얻을 수 있도록 미세 조정을 하는 것이 바람직합니다.

Fine tuning is desirable so that we can obtain better performance on our specific use case, without having to incur the computational cost of training a model from scratch.


Segment Anything 모델 미세조정하기 [코드 포함] / How to Fine-tune Segment Anything Model [With Code]

배경 및 아키텍처 / Background & Architecture

소개 섹션에서는 SAM 아키텍처에 대한 개요를 설명했습니다. 이미지 인코더는 많은 매개 변수가 있는 복잡한 아키텍처를 가지고 있습니다. 모델을 미세 조정하려면 가볍고 미세 조정이 더 쉽고 빠르며 메모리 효율이 높은 마스크 디코더에 집중하는 것이 좋습니다.

We gave an overview of the SAM architecture in the introduction section. The image encoder has a complex architecture with many parameters. In order to fine tune the model, it makes sense for us to focus on the mask decoder which is lightweight and therefore easier, faster and more memory efficient to fine tune.

SAM을 미세 조정하려면 아키텍처의 기본 요소(이미지 및 프롬프트 인코더, 마스크 디코더)를 추출해야 합니다. 두 가지 이유로 SamPredictor.predict(링크)를 사용할 수 없습니다:

In order to fine tune SAM, we need to extract the underlying pieces of its architecture (image and prompt encoders, mask decoder). We cannot use SamPredictor.predict (link) for two reasons:

  • 마스크 디코더만 미세 조정하고 싶습니다.
  • 이 함수는 @torch.no_grad() 데코레이터(링크)가 있는 SamPredictor.predict_torch를 호출하여 변화도(gradient)를 계산하지 못하도록 합니다.
  • We want to fine tune only the mask decoder
  • This function calls SamPredictor.predict_torch which has the @torch.no_grad() decorator (link), which prevents us from computing gradients

따라서 SamPredictor.predict 함수를 살펴보고 미세 조정하려는 부분(마스크 디코더)에서 변화도(gradient) 계산을 활성화한 상태로 적절한 함수를 호출해야 합니다. 이렇게 하면 SAM의 작동 방식에 대해 자세히 알아볼 수 있습니다.

Thus, we need to examine the SamPredictor.predict function and call the appropriate functions with gradient calculation enabled on the part we want to fine tune (the mask decoder). Doing this is also a good way to learn more about how SAM works.

사용자 정의 데이터셋 만들기 / Creating a Custom Dataset

모델을 미세 조정하려면 세 가지가 필요합니다:

We need three things to fine tune our model:

  • 세그먼트를 할 대상 이미지
  • 세그먼트 정답(ground truth) 마스크(mask)
  • 모델에 입력할 프롬프트 - 여기서는 바운딩 박스를 사용합니다
  • Images on which to draw segmentations
  • Segmentation ground truth masks
  • Prompts to feed into the model, I am using bounding boxes

스탬프(stamp) 검증 데이터셋(링크를 선택한 이유는 SAM이 훈련에서 보지 못한 데이터(예: 문서에 찍힌 도장)를 포함하고 있기 때문입니다.) 이 데이터셋에 대해 사전 학습된 가중치로 추론을 실행하면 완벽하지는 않지만 잘 작동하는 것을 확인할 수 있습니다. 정밀한 정답(ground truth) 마스크도 있기 때문에 손실(loss)을 정확히 계산할 수 있습니다. 마지막으로, 이 데이터셋에는 세그먼트 마스크 주위에 바운딩 박스가 포함되어 있어 SAM에 대한 프롬프트로 사용할 수 있습니다. 예시 이미지를 아래에 표시해봤습니다. 이러한 바운딩 박스는 사람이 세그먼트를 생성할 때 거치는 워크플로우와 잘 일치합니다.

I chose the stamp verification dataset (link) since it has data which SAM may not have seen in its training (i.e., stamps on documents). I can verify that it performs well, but not perfectly, on this dataset by running inference with the pre-trained weights. The ground truth masks are also extremely precise, which will allow us to calculate accurate losses. Finally, this dataset contains bounding boxes around the segmentation masks, which we can use as prompts to SAM. An example image is shown below. These bounding boxes align well with the workflow that a human annotator would go through when looking to generate segmentations.

입력 데이터 전처리 / Input Data Preprocessing

스캔 이미지를을 NumPy 배열에서 파이토치 텐서(PyTorch Tensor)로 전처리해야 합니다. 이를 위해 이미지를 전처리하는 SamPredictor.set_image (링크) 및 SamPredictor.set_torch_image (링크) 내부에서 일어나는 일을 따라가면 됩니다. 먼저, 예측기 내부에서 이미지 변환기(transformer, 링크)로 사용하고 있는 utils.transform.ResizeLongestSide를 써서 이미지의 크기를 조정합니다. 그런 다음 이미지를 파이토치 텐서로 변환하고, SAM 전처리 메서드(link를 사용하여 전처리를 완료할 수 있습니다.

We need to preprocess the scans from numpy arrays to pytorch tensors. To do this, we can follow what happens inside SamPredictor.set_image (link) and SamPredictor.set_torch_image (link) which preprocesses the image. First, we can use utils.transform.ResizeLongestSide to resize the image, as this is the transformer used inside the predictor (link). We can then convert the image to a pytorch tensor and use the SAM preprocess method (link) to finish preprocessing.

학습 설정 / Training Setup

vit_b 모델에 대한 모델 체크포인트를 다운로드하여 로드합니다:

We download the model checkpoint for the vit_b model and load them in:

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

기본값으로 Adam 옵티마이저를 설정하고 조정할 파라미터를 마스크 디코더의 파라미터로 지정합니다:

We can set up an Adam optimizer with defaults and specify that the parameters to tune are those of the mask decoder:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())

또한, MSE(Mean Sequared Error)와 같은 손실 함수를 설정합니다.

At the same time, we can set up our loss function, for example Mean Squared Error

loss_fn = torch.nn.MSELoss()

학습 루프 / Training Loop

메인 학습 루프에서는 데이터 항목을 반복하며 마스크를 생성하고, 이를 정답(ground truth) 데이터 마스크와 비교하여 손실 함수를 기반으로 모델 파라미터를 최적화할 수 있습니다.

In the main training loop, we will be iterating through our data items, generating masks and comparing them to our ground truth masks so that we can optimise the model parameters based on the loss function.

이 예제에서는 훨씬 빠른 속도를 위해 CPU보다 GPU를 사용했습니다. GPU에서 사용할 대상 텐서들에 .to(device) 를 사용하여 CPU가 아닌 GPU에 두는 것이 중요합니다.

In this example we used a GPU for training since it is much faster than using a CPU. It is important to use **.to(device)**on the appropriate tensors to make sure that we don’t have certain tensors on the CPU and others on the GPU.

또한 이미지 인코더를 미세 조정하려는 것이 아니기 때문에, torch.no_grad() 컨텍스트 관리자로 인코더 부분을 감싸는 식으로 이미지를 임베드하여 메모리 문제를 피합니다.

We want to embed images by wrapping the encoder in the torch.no_grad() context manager, since otherwise we will have memory issues, along with the fact that we are not looking to fine tune the image encoder.

with torch.no_grad():
    image_embedding = sam_model.image_encoder(input_image)

no_grad 컨텍스트 관리자 내에서 프롬프트 임베딩을 생성할 수도 있습니다. 여기서는 파이토치 텐서로 변환된 바운딩 박스 좌표를 사용합니다.

We can also generate the prompt embeddings within the no_grad context manager. We use our bounding box coordinates, converted to pytorch tensors.

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

마지막으로 마스크를 생성합니다. 여기서는 (일반적으로 출력되는 3개의 마스크가 아닌) 단일 마스크 생성 모드임을 참고하세요.

Finally, we can generate the masks. Note that here we are in single mask generation mode (in contrast to the 3 masks that are normally output).

low_res_masks, iou_predictions = sam_model.mask_decoder(
  image_embeddings=image_embedding,
  image_pe=sam_model.prompt_encoder.get_dense_pe(),
  sparse_prompt_embeddings=sparse_embeddings,
  dense_prompt_embeddings=dense_embeddings,
  multimask_output=False,
)

마지막 단계는 마스크의 해상도가 낮기 때문에 원래 이미지 크기로 다시 업스케일링하는 것입니다. 이를 위해 Sam.postprocess_masks를 사용합니다. 또한 예측된 마스크에서 바이너리 마스크를 생성하여 정답과 비교할 수 있도록 해야 합니다. 역전파 시 문제가 생기지 않도록 하기 위해 torch functional의 함수들을 사용하는 것이 중요합니다.

The final step here is to upscale the masks back to the original image size, since they are low resolution. We can use Sam.postprocess_masks to achieve this. We will also want to generate binary masks from the predicted masks so that we can compare these to our ground truths. It is important to use torch functionals in order to not break backpropagation.

upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

from torch.nn.functional import threshold, normalize

binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

마지막으로 손실을 계산하고 최적화 단계를 실행합니다:

Finally we can calculate the loss and run an optimisation step:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

이 작업을 여러 에포크(epoch)와 배치(batch)에 걸쳐 반복하면서 SAM 디코더를 미세 조정합니다.

By repeating this over a number of epochs and batches we can fine tune the SAM decoder.

체크포인트를 저장하고 모델 불러오기 / Saving Checkpoints and Starting a Model from it

학습 완료 후 성능 향상이 만족스러우면:

Once we are done with training and satisfied by the performance uplift, we can use:

torch.save(model.state_dict(), PATH)

위 코드를 사용하여 조정된 모델을 상태 사전(state dict)를 저장합니다. 모델을 미세 조정하는 데 사용한 데이터와 유사한 데이터에 대해 추론을 수행할 때 이 상태 사전(모델)을 불러오면 됩니다.

to save the state dict of the tuned model. We can then load this state dict when we want to perform inference on data that is similar to the data we used to fine tune the model.

**SAM을 미세 조정하는 데 필요한 모든 코드가 포함된 Colab Notebook은 여기에서 찾을 수 있습니다. 바로 사용할 수 있는 솔루션을 원하신다면 계속 읽어보세요!
You can find the Colab Notebook with all the code you need to fine-tune SAM here. Keep reading if you want a fully working solution out of the box!


(역자 주: 아래는 encord 플랫폼을 사용하는 부분을 다루고 있습니다.)


다운스트림 애플리케이션을 위한 미세 조정 / Fine-tuning for Downstream Applications


SAM은 현재 즉시 사용 가능한 미세 조정 기능을 제공하지는 않지만, Encord 플랫폼과 통합된 맞춤형 미세 조정기를 구축하고 있습니다. 이 게시물에서 볼 수 있듯이 이를 위해 디코더를 미세 튜닝하고 있습니다. 이 기능은 웹 앱에서 원클릭으로 바로 사용할 수 있으며, 하이퍼파라미터가 자동으로 설정됩니다. > While SAM does not currently offer fine-tuning out of the box, we are building a custom fine tuner integrated with the Encord platform. As shown in this post, we fine tune the decoder in order to achieve this. This is available as an out of the box one click procedure in the web app, where the hyperparameters are automatically set.

Image displaying training the Segment Anything Model (SAM) in the Encord platform

오리지널 바닐라 SAM 마스크:

Original vanilla SAM mask:

Image of the original vanilla SAM mask

모델을 미세 조정하여 생성된 마스크:

Mask generated by fine tuned version of the model:

Image of the mask generated by the fine tuned version of the model

이 마스크가 원래 마스크보다 더 타이트한 것을 볼 수 있습니다. 이는 스탬프 검증 데이터 세트에서 이미지의 작은 하위 집합을 미세 조정한 다음 이전에 볼 수 없었던 예제에서 조정된 모델을 실행한 결과입니다. 추가 학습과 더 많은 예제를 통해 더 나은 결과를 얻을 수 있습니다.

We can see that this mask is tighter than the original mask. This was the result of fine tuning on a small subset of images from the stamp verification dataset, and then running the tuned model on a previously unseen example. With further training and more examples we could obtain even better results.

기초 모델을 나만의 것으로 만들기

Make foundation models your own.

Label with SAM in Encord

결론 / Conclusion

여기까지입니다!

That's all, folks!

이제 세그먼트 애니씽 모델(SAM)을 미세 조정하는 방법을 배웠습니다. SAM을 즉시 미세 조정하고 싶다면, 코드를 작성하지 않고도 모델을 미세 조정할 수 있는 최근 출시된 세그먼트 애니씽 모델이 있다는 사실도 알아두면 좋을 것 같습니다. 여기를 클릭하여 encord 무료 평가판을 사용해 보세요.

You have now learned how to fine-tune the Segment Anything Model (SAM). If you're looking to fine-tune SAM out of the box, you might also be interested to learn that we have recently released the Segment Anything Model in Encord, allowing you to fine-tune the model without writing any code. Click here for a free trial.

2개의 좋아요