This module contains the basic building blocks for constructing and training models.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Used in notebook but not needed in package.
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from htools import assert_raises
# Model

`BaseModel` allows models to freeze/unfreeze layers and provides several methods for weight diagnostics. It should not be instantiated directly, but used as a parent class for a model. Like all PyTorch models, its children will still need to call `super().__init__()` and implement a `forward()` method.
class SimpleModel(BaseModel):
    """Minimal two-layer feedforward net used to exercise BaseModel's
    layer-wise freeze/unfreeze behavior on a model with no layer groups.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        # Hidden layer with leaky ReLU activation, then linear output.
        return self.fc2(F.leaky_relu(self.fc1(x)))
class GroupedModel(BaseModel):
    """Feedforward net whose layers are organized into two groups via an
    nn.ModuleList named `groups`, allowing BaseModel to unfreeze whole
    groups at a time (the `n_groups` argument).
    """

    def __init__(self, dim):
        super().__init__()
        body = nn.Sequential(
            nn.Linear(dim, 8),
            nn.LeakyReLU(),
            nn.Linear(8, 4),
            nn.LeakyReLU()
        )
        head = nn.Linear(4, 1)
        self.groups = nn.ModuleList([body, head])

    def forward(self, x):
        # Pass the input through each group in order.
        out = x
        for block in self.groups:
            out = block(out)
        return out
# Sanity-check layer-wise unfreezing on a model with no layer groups.
# NOTE(review): `trainable()` appears to yield (name, flag) pairs, one per
# parameter tensor — confirm against BaseModel.
snet = SimpleModel(2)
snet.freeze()
for n in range(5):
    # Unfreezing the last n parameter tensors should leave exactly n
    # trainable and every earlier one frozen.
    snet.unfreeze(n_layers=n)
    unfrozen = [x[1] for x in snet.trainable()]
    print('Unfrozen', unfrozen)
    assert sum(unfrozen) == n
    assert not any(unfrozen[:-n])
    # Reset so each iteration starts from a fully frozen model.
    snet.freeze()

# SimpleModel defines no layer groups, so group-wise unfreezing should
# fail with an AttributeError.
with assert_raises(AttributeError) as ar:
    for n in range(3):
        snet.unfreeze(n_groups=n)
# Sanity-check group-wise and layer-wise unfreezing on a grouped model.
gnet = GroupedModel(2)
gnet.freeze()
# Expected trainable-tensor counts after unfreezing the last n groups:
# n=0 -> 0, n=1 -> 2 (final Linear's weight+bias), n=2 -> 6 (all params).
n_unfrozen = [0, 2, 6]
for n, nu in zip(range(3), n_unfrozen):
    gnet.unfreeze(n_groups=n)
    unfrozen = [x[1] for x in gnet.trainable()]
    print('Unfrozen', unfrozen)
    assert sum(unfrozen) == nu
# Re-freeze everything before the layer-wise checks below.
gnet.freeze()
for n in range(7):
    # As with SimpleModel: exactly the last n parameter tensors should be
    # trainable after unfreeze(n_layers=n).
    gnet.unfreeze(n_layers=n)
    unfrozen = [x[1] for x in gnet.trainable()]
    print('Unfrozen', unfrozen)
    assert sum(unfrozen) == n
    assert not any(unfrozen[:-n])
# A single learning rate on an ungrouped model -> one param group.
optim = variable_lr_optimizer(snet, 2e-3)
print(optim)

# Passing multiple LRs to a model without layer groups should be rejected.
with assert_raises(ValueError) as ar:
    optim = variable_lr_optimizer(snet, [3e-3, 1e-1])
optim

# update_optimizer should set the (single) group's LR in place.
update_optimizer(optim, 1e-3, 0.5)
assert len(optim.param_groups) == 1
assert optim.param_groups[0]['lr'] == 1e-3

# One LR per layer group, assigned in order.
lrs = [1e-3, 3e-3]
optim = variable_lr_optimizer(gnet, lrs)
print(optim)
assert [group['lr'] for group in optim.param_groups] == lrs

# lr_mult appears to scale earlier groups' LRs down relative to the last:
# with lr_mult=1/3, group 1's LR ends up 3x group 0's (asserted below).
update_optimizer(optim, 2e-3, lr_mult=1/3)
print([group['lr'] for group in optim.param_groups])
assert np.isclose(optim.param_groups[1]['lr'], optim.param_groups[0]['lr'] * 3)

# Same multiplier behavior when building the optimizer directly.
optim = variable_lr_optimizer(gnet, 1e-3, lr_mult=0.5)
print([group['lr'] for group in optim.param_groups])
assert np.isclose(optim.param_groups[1]['lr'], optim.param_groups[0]['lr'] * 2)
# NOTE(review): this section is a byte-for-byte repeat of the optimizer
# checks immediately above (likely a duplicated notebook cell). Kept as-is.
optim = variable_lr_optimizer(snet, 2e-3)
print(optim)

# Multiple LRs on an ungrouped model should raise ValueError.
with assert_raises(ValueError) as ar:
    optim = variable_lr_optimizer(snet, [3e-3, 1e-1])
optim

# In-place LR update on the single param group.
update_optimizer(optim, 1e-3, 0.5)
assert len(optim.param_groups) == 1
assert optim.param_groups[0]['lr'] == 1e-3

# One LR per layer group, assigned in order.
lrs = [1e-3, 3e-3]
optim = variable_lr_optimizer(gnet, lrs)
print(optim)
assert [group['lr'] for group in optim.param_groups] == lrs

# lr_mult=1/3 -> group 1's LR is 3x group 0's after the update.
update_optimizer(optim, 2e-3, lr_mult=1/3)
print([group['lr'] for group in optim.param_groups])
assert np.isclose(optim.param_groups[1]['lr'], optim.param_groups[0]['lr'] * 3)

# lr_mult=0.5 at construction time -> group 1's LR is 2x group 0's.
optim = variable_lr_optimizer(gnet, 1e-3, lr_mult=0.5)
print([group['lr'] for group in optim.param_groups])
assert np.isclose(optim.param_groups[1]['lr'], optim.param_groups[0]['lr'] * 2)