Darwin ✕ Torchvision
This tutorial shows how to train an instance segmentation model on a Darwin dataset using PyTorch's Torchvision and darwin-py. If you don't have PyTorch or Torchvision installed yet, please first follow the official PyTorch installation instructions.
Now, using darwin-py's CLI, we will pull the dataset from Darwin and create train, validation, and test partitions.
```shell
darwin dataset pull v7-demo/bird-species
darwin dataset split v7-demo/bird-species --val-percentage 10 --test-percentage 20
```
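As a rough illustration of what these flags mean: with `--val-percentage 10` and `--test-percentage 20`, about 10% of the items go to validation, 20% to test, and the remaining 70% to train. The sketch below computes those sizes for a hypothetical dataset of 1,000 items. The helper `split_sizes` is hypothetical, and its rounding (round val and test down, give the remainder to train) is an assumption, not necessarily what darwin-py does internally.

```python
def split_sizes(n_items, val_percentage, test_percentage):
    """Hypothetical helper: approximate partition sizes for a train/val/test split.

    Assumes val and test sizes are rounded down and train gets the remainder;
    darwin-py's own rounding may differ.
    """
    if val_percentage + test_percentage >= 100:
        raise ValueError("val + test percentages must leave room for a train partition")
    n_val = n_items * val_percentage // 100
    n_test = n_items * test_percentage // 100
    n_train = n_items - n_val - n_test
    return {"train": n_train, "val": n_val, "test": n_test}


print(split_sizes(1000, val_percentage=10, test_percentage=20))
# {'train': 700, 'val': 100, 'test': 200}
```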
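Next, in Python, we start by importing some torchvision and darwin-py functions and defining get_instance_segmentation_model, which instantiates a Mask R-CNN model through Torchvision's API and replaces its box and mask predictors so that they output the number of classes in our dataset. We also define a small collate_fn helper for the DataLoader.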
```python
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from darwin.torch import get_dataset
import darwin.torch.transforms as T


def collate_fn(batch):
    # Samples contain different numbers of objects, so keep the batch as a tuple
    # of images and a tuple of targets instead of stacking them into one tensor.
    return tuple(zip(*batch))


def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    # (on torchvision >= 0.13, weights="DEFAULT" replaces the deprecated pretrained=True)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # replace the box predictor with one that outputs num_classes classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # replace the mask predictor as well
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model
```
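Why collate_fn is needed: PyTorch's default collation tries to stack the samples of a batch into single tensors, which fails when images have different sizes or targets hold different numbers of objects. `tuple(zip(*batch))` instead regroups a list of `(image, target)` pairs into one tuple of images and one tuple of targets. The sketch below runs the same function on placeholder strings and dicts (stand-ins for image tensors and annotation targets, chosen so it runs without PyTorch installed).

```python
def collate_fn(batch):
    # Regroup [(img0, tgt0), (img1, tgt1), ...] into ((img0, img1, ...), (tgt0, tgt1, ...))
    return tuple(zip(*batch))


# Placeholder samples: strings stand in for image tensors, and the targets
# deliberately contain a different number of boxes per image.
batch = [
    ("image_0", {"boxes": [[0, 0, 10, 10]], "labels": [1]}),
    ("image_1", {"boxes": [[5, 5, 20, 20], [1, 2, 3, 4]], "labels": [2, 1]}),
]

images, targets = collate_fn(batch)
print(images)        # ('image_0', 'image_1')
print(len(targets))  # 2
```

Each target keeps its own length, which is the form torchvision's detection models expect: a list of images plus a list of per-image target dicts.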
Then, we load the dataset using darwin-py's get_dataset function, specifying the dataset slug, the dataset type (in this case we need an instance-segmentation dataset), the train partition, the split type, and the transforms to apply. The dataset we get back can be passed directly to PyTorch's standard DataLoader.
```python
# augment with random horizontal flips, then convert images to tensors
trfs_train = T.Compose([T.RandomHorizontalFlip(), T.ToTensor()])
dataset = get_dataset("v7-demo/bird-species", dataset_type="instance-segmentation",
                      partition="train", split_type="stratified", transform=trfs_train)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
                                          num_workers=4, collate_fn=collate_fn)
```
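A random horizontal flip on detection data must move the annotations along with the pixels; otherwise boxes and masks no longer line up with the flipped image. The sketch below shows the box arithmetic for an image of width W, assuming boxes in `[x1, y1, x2, y2]` absolute pixel coordinates: the new left edge is `W - x2` and the new right edge is `W - x1`. The helper `hflip_boxes` is hypothetical and is not darwin-py's implementation.

```python
def hflip_boxes(boxes, image_width):
    """Hypothetical helper: mirror [x1, y1, x2, y2] boxes across the vertical axis.

    Assumes absolute pixel coordinates with x measured from the left edge.
    """
    flipped = []
    for x1, y1, x2, y2 in boxes:
        # the old right edge becomes the new left edge, and vice versa;
        # y coordinates are unchanged by a horizontal flip
        flipped.append([image_width - x2, y1, image_width - x1, y2])
    return flipped


boxes = [[10, 20, 50, 80]]
print(hflip_boxes(boxes, image_width=200))  # [[150, 20, 190, 80]]
```

Flipping twice returns the original boxes, which is a quick sanity check for any custom augmentation you add.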
Next, we instantiate the instance segmentation model and define the optimizer and the learning rate scheduler.
```python
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# get the model using our helper function
num_classes = dataset.num_classes + 1  # number of classes in the dataset + background
model = get_instance_segmentation_model(num_classes)
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.0001, momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
```
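With `StepLR(step_size=3, gamma=0.1)` and one `lr_scheduler.step()` call per epoch, the learning rate is multiplied by 0.1 every 3 epochs, i.e. `lr = 0.0001 * 0.1 ** (epoch // 3)`. The stdlib-only sketch below evaluates that closed form for the 10 training epochs, so you can check which learning rate each epoch uses. The function `step_lr` is a hypothetical helper that mirrors StepLR's documented behaviour; it is not PyTorch's implementation.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    # StepLR closed form: decay by gamma once every step_size epochs
    return base_lr * gamma ** (epoch // step_size)


for epoch in range(10):
    print(epoch, f"{step_lr(0.0001, 0.1, 3, epoch):.0e}")
```

Epochs 0-2 train at 1e-4, epochs 3-5 at 1e-5, epochs 6-8 at 1e-6, and epoch 9 at 1e-7.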
And finally, we write our training loop and train the model for 10 full epochs.
```python
# in training mode, torchvision detection models return a dict of losses
model.train()

# let's train it for 10 epochs
for epoch in range(10):
    print(f"Starting epoch {epoch}...")
    accum_loss = 0.0
    for i, (images, targets) in enumerate(data_loader):
        # move images and the tensor-valued target fields to the training device
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items() if isinstance(v, torch.Tensor)} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # print the average loss over the last 10 iterations
        accum_loss += losses.item()
        if (i + 1) % 10 == 0:
            print(f"({i + 1}/{len(data_loader)}) Loss: {accum_loss / 10}")
            accum_loss = 0.0
    lr_scheduler.step()
```
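The training loop prints a running loss every 10 iterations. If you also want to see which part of the objective dominates, you can average each entry of `loss_dict` separately; for torchvision's Mask R-CNN these entries include `loss_classifier`, `loss_box_reg`, `loss_mask`, `loss_objectness` and `loss_rpn_box_reg`. Below is a stdlib-only sketch of such a windowed meter. The class `LossMeter` is hypothetical; in the real loop you would feed it `{k: v.item() for k, v in loss_dict.items()}`.

```python
from collections import defaultdict


class LossMeter:
    """Hypothetical helper: average each named loss over a fixed window of iterations."""

    def __init__(self, window=10):
        self.window = window
        self.reset()

    def reset(self):
        self.sums = defaultdict(float)
        self.count = 0

    def update(self, losses):
        # losses: mapping of loss name -> Python float (e.g. from loss.item())
        for name, value in losses.items():
            self.sums[name] += value
        self.count += 1
        if self.count < self.window:
            return None
        # window is full: return the per-loss means and start a new window
        means = {name: total / self.count for name, total in self.sums.items()}
        self.reset()
        return means


meter = LossMeter(window=2)
print(meter.update({"loss_classifier": 0.5, "loss_mask": 1.0}))  # None
print(meter.update({"loss_classifier": 1.5, "loss_mask": 3.0}))  # {'loss_classifier': 1.0, 'loss_mask': 2.0}
```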
Now that your model is trained, you can evaluate its performance, for example on the val or test partition, using your own evaluation functions.
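A common building block for such evaluation functions is intersection over union (IoU), which detection metrics such as COCO-style mAP use to decide whether a predicted box matches a ground-truth box. After `model.eval()`, torchvision's detection models return one dict per image with `boxes` (in `[x1, y1, x2, y2]` format), `labels`, `scores` and `masks`. Here is a minimal stdlib-only sketch for two boxes; `box_iou` is a hypothetical helper, and `torchvision.ops.box_iou` provides a batched tensor version.

```python
def box_iou(box_a, box_b):
    """Hypothetical helper: IoU of two [x1, y1, x2, y2] boxes with x2 > x1 and y2 > y1."""
    # overlap rectangle (zero width or height if the boxes do not intersect)
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


print(box_iou([0, 0, 10, 10], [5, 0, 15, 10]))  # 0.3333333333333333
```

A prediction is then typically counted as a match when its IoU with a ground-truth box of the same class exceeds a threshold such as 0.5.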
Are you familiar with Detectron2? Then you'll be glad to know that darwin-py also lets you integrate your Darwin datasets into Facebook's library. If you're not familiar with it yet, you may want to see how, by combining Detectron2's modularity and abstractions with the APIs provided in darwin-py, you can train and evaluate models on your Darwin datasets in just a few lines of code.