import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
from segment_anything import sam_model_registry
Define dataset with masks
class CustomDataset(torch.utils.data.Dataset):
def init(self, root_dir, mask_dir, transform=None):
self.root_dir = root_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(self.root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.images[idx])
mask_name = os.path.join(self.mask_dir, self.images[idx])
image = Image.open(img_name)
mask = Image.open(mask_name)
# Convert to RGB
# image = image.convert('RGB')
# mask = mask.convert('L')
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
# Convert mask to binary
mask = torch.where(mask > 0, torch.tensor(1), torch.tensor(0))
return image, mask
Define transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
Load custom dataset
dataset = CustomDataset(root_dir='3/folder/train/org', mask_dir='3/folder/train/mask', transform=transform)
Create a DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Load the model
sam = sam_model_registry"vit_b"
Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam.to(device)
sam.train()
Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(sam.parameters(), lr=0.001, momentum=0.9)
Training loop
for epoch in range(2):
running_loss = 0.0
for i, (inputs, labels) in enumerate(dataloader, 0):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = sam(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
ORG IMAGE
MASK IMAGE
현재 이러한 버그가 있는데
어떤 것이 원인인지 혹시 경험해보신 분 계실까요?
TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'