Deep Learning with PyTorch
Module 4 of 12
4. Transfer Learning
Standing on Giants
Don't train from scratch. Start from a model that has already learned from over a million labeled images (ResNet and EfficientNet, pretrained on ImageNet) and adapt it to your problem.
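```python
import torch.nn as nn
import torchvision.models as models

# Load ResNet-18 with ImageNet-pretrained weights
# (torchvision >= 0.13; older versions used pretrained=True)
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze the brain: no gradient updates for any pretrained weight
for param in resnet.parameters():
    param.requires_grad = False

# Replace the last layer for YOUR problem (in_features is 512 for ResNet-18)
resnet.fc = nn.Linear(resnet.fc.in_features, 2)  # Dog vs Cat
```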