%load_ext autoreload
%autoreload 2
%matplotlib inline
# Used for testing only.
from collections import defaultdict, Counter
from itertools import chain
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from htools import assert_raises, InvalidArgumentError, smap
from incendio.data import probabilistic_hash_item
import pandas_htools
def plot_activations(z, a, mode='scatter', **kwargs):
    """Plot an input tensor against its corresponding activations. Both
    tensors are flattened before plotting.

    Parameters
    ----------
    z: torch.Tensor
        Tensor containing values to plot on the x axis (we can often think of
        this as the output of a linear layer, where z=f(x) and a=mish(z)).
    a: torch.Tensor
        Tensor containing values to plot on the y axis. Should contain the
        same number of elements as `z`.
    mode: str
        Name of the matplotlib.pyplot function used to draw the data, e.g.
        'scatter' for a scatter plot or 'plot' for a line plot.
    kwargs: Values to be passed to the matplotlib plotting function, such as
        's' when in 'scatter' mode or 'lw' in 'plot' mode. In 'scatter' mode
        with no kwargs, a small marker size (s=.75) is used.

    Returns
    -------
    None
    """
    plt_func = getattr(plt, mode)
    # `**kwargs` is always a dict, so only the scatter default needs filling.
    if mode == 'scatter' and not kwargs:
        kwargs = {'s': .75}
    # Detach and move to CPU before converting: `Tensor.numpy()` raises for
    # tensors that require grad or live on a GPU.
    plt_func(z.detach().cpu().numpy().flatten(),
             a.detach().cpu().numpy().flatten(), **kwargs)
    # Light reference lines through the origin.
    plt.axvline(0, lw=.5, alpha=.5)
    plt.axhline(0, lw=.5, alpha=.5)
    plt.show()
# Visualize the mish activation over [-5, 5) with step .05.
x = torch.arange(-5, 5, .05)
a = mish(x)
plot_activations(x, a, 'plot')
# ConvBlock without normalization: display the module, then check the output
# shape on a small random batch.
conv = ConvBlock(3, 5, norm=False)
conv
x = torch.rand(2, 3, 4, 4)
conv(x).shape
# Display ResBlock with and without normalization.
ResBlock(4)
ResBlock(4, norm=False)
def show_img(img):
    """Display an image tensor with matplotlib.

    The tensor is moved from channels-first to channels-last layout
    (dim 0 becomes the last dim) and divided by 255 before being shown. This
    presumably expects pixel values on a 0-255 scale, as in the
    `torch.randint(255, ...)` example that calls it.
    """
    channels_last = img.permute(1, 2, 0)
    plt.imshow(channels_last / 255)
    plt.show()
# Reflection-padded 1x1 conv with padding=2. Show a random 3x3 image before
# and after the layer's `reflect` step.
rconv = ReflectionPaddedConv2d(3, 3, kernel_size=1, padding=2)
rconv
x = torch.randint(255, (1, 3, 3, 3)).float()
show_img(x[0])
x2 = rconv.reflect(x)
show_img(x2[0])
# Tests
# ReflectionPaddedConv2d's docstring should contain nn.Conv2d's docstring.
assert nn.Conv2d.__doc__ in ReflectionPaddedConv2d.__doc__
# padding_mode='zeros' must raise InvalidArgumentError (presumably because the
# layer always pads by reflection -- confirm).
with assert_raises(InvalidArgumentError):
    ReflectionPaddedConv2d(3, 3, padding_mode='zeros')
class Net(nn.Module):
    """Minimal module wrapping one Dropin layer, used to check Dropin's
    behavior in train vs. eval mode."""

    def __init__(self):
        super(Net, self).__init__()
        # Kept as `drop` so callers can inspect its `training` flag.
        self.drop = Dropin()

    def forward(self, x):
        """Return the Dropin layer's output for `x`."""
        out = self.drop(x)
        return out
# In training mode (the nn.Module default), Dropin's output should stay highly
# correlated with its input.
net = Net()
x = torch.randn(8, 128, 128, 3)
assert np.corrcoef(net(x).flatten(), x.flatten())[0][1] > .9
# In eval mode, Dropin should return its input unchanged.
net.eval()
assert torch.eq(net(x), x).all()
assert not net.drop.training
def simulate_activation_stats(scale=1.0, trials=10_000, shape=(3, 4)):
    """Average summary statistics of a Dropin layer's outputs and noise over
    many random inputs.

    A single `Dropin(scale)` layer is built, and it is used in its default
    (training) mode. For each trial, a standard normal input `x` of shape
    `shape` is passed through it.

    Parameters
    ----------
    scale: float
        Passed to the Dropin constructor.
    trials: int
        Number of random inputs to average over.
    shape: tuple[int]
        Shape of each random input. Defaults to the original (3, 4); larger
        shapes give less noisy per-trial std and correlation estimates.

    Returns
    -------
    pd.DataFrame
        The 'act' column has the mean, std and correlation with `x` of the
        layer's output. The 'noise' column has the mean, std and correlation
        with the output of `drop.noise`. Values are averaged over trials and
        rounded to 4 places. Cells that don't apply are NaN.
    """
    act_stats = defaultdict(list)
    noise_stats = defaultdict(list)
    drop = Dropin(scale)
    # Gradients are never needed here, and grad-free tensors convert cleanly
    # to numpy inside np.corrcoef.
    with torch.no_grad():
        for _ in range(trials):
            x = torch.randn(shape, dtype=torch.float)
            z = drop(x)
            # Assumes Dropin stores the noise it applied on `.noise` during
            # forward -- confirm against the Dropin implementation.
            noise = drop.noise
            # Store Python floats rather than 0-d tensors so np.mean works on
            # a plain list of numbers.
            noise_stats['mean'].append(noise.mean().item())
            noise_stats['std'].append(noise.std().item())
            noise_stats['act_corr'].append(
                np.corrcoef(z.flatten(), noise.flatten())[0][1]
            )
            act_stats['mean'].append(z.mean().item())
            act_stats['std'].append(z.std().item())
            act_stats['x_corr'].append(
                np.corrcoef(z.flatten(), x.flatten())[0][1]
            )
    return pd.DataFrame(dict(
        act={k: np.mean(v).round(4) for k, v in act_stats.items()},
        noise={k: np.mean(v).round(4) for k, v in noise_stats.items()}
    ))
# Dropin statistics across a range of noise scales, with 1,000 trials each.
# `.pprint()` is presumably added to DataFrame by pandas_htools.
for scale in [10, 1, .75, .5, .25, .1]:
    print('\n', scale)
    simulate_activation_stats(scale, 1_000).pprint()
# Inspect InitializedEmbedding weights for different values of the third
# argument (presumably a padding index -- confirm), and with it omitted.
InitializedEmbedding(4, 3, 0).weight
InitializedEmbedding(4, 3, 3).weight
InitializedEmbedding(4, 3).weight
class Data(Dataset):
    """Toy text-classification dataset for exercising embedding layers.

    Sentences are split on whitespace, and each token is mapped to its
    1-based frequency rank. Index 0 is reserved for padding and unknown
    tokens. Every row is truncated or right-padded with 0 to `seq_len`.

    Parameters
    ----------
    sentences: list[str]
        Raw sentences. Punctuation stays attached to words.
    labels: list[int]
        One label per sentence.
    seq_len: int
        Length of each encoded row.
    """

    def __init__(self, sentences, labels, seq_len):
        # str.split() with no separator collapses runs of whitespace, so
        # repeated spaces or empty sentences don't yield '' "words".
        x = [s.split() for s in sentences]
        self.w2i = self.make_w2i(x)
        self.seq_len = seq_len
        self.x = self.encode(x)
        self.y = torch.tensor(labels)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def __len__(self):
        return len(self.y)

    def make_w2i(self, tok_rows):
        """Map each token to its 1-based frequency rank. Ties keep
        first-seen order, following Counter.most_common."""
        counts = Counter(chain.from_iterable(tok_rows))
        return {tok: i for i, (tok, _) in enumerate(counts.most_common(), 1)}

    def encode(self, tok_rows):
        """Encode tokenized rows as a (n_rows, seq_len) integer tensor,
        truncating long rows and right-padding short ones with 0."""
        enc = np.zeros((len(tok_rows), self.seq_len), dtype=int)
        for i, row in enumerate(tok_rows):
            trunc = [self.w2i.get(w, 0) for w in row[:self.seq_len]]
            enc[i, :len(trunc)] = trunc
        return torch.tensor(enc)
# Tiny labeled corpus for exercising the embedding layers.
sents = [
    'I walked to the store so I hope it is not closed.',
    'The theater is closed today and the sky is grey.',
    'His dog is brown while hers is grey.'
]
labels = [0, 1, 1]
ds = Data(sents, labels, 10)
ds[1]
# All 3 examples fit in one batch. Each call builds a fresh iterator over an
# unshuffled DataLoader (shuffle defaults to False), so both fetches return
# the same batch.
dl = DataLoader(ds, batch_size=3)
x, y = next(iter(dl))
x, y
x, y = next(iter(dl))
x, y
# Build a BloomEmbedding and display its weight matrix and the input batch.
be = BloomEmbedding(11, 4)
be.emb.weight
x
# (bs x seq_len) -> (bs x seq_len x emb_size). Note: this rebinds y, which
# previously held the batch labels.
y = be(x)
y.shape
y[0]
Below, we show step by step how to get from `x` to `y`. This is meant to demonstrate the basic mechanism, not to show how PyTorch actually implements this under the hood. Let's look at a single row of `x`, corresponding to one sentence where each word is mapped to its index in the vocabulary.
# First row of x: one sentence's word indices.
x[0]
Next, we hash each item.
# Hash each word index of the first sentence, using the same 11 and 4 values
# that were passed to BloomEmbedding(11, 4).
hashed = [probabilistic_hash_item(i.item(), 11, int, 4) for i in x[0]]
hashed
Then we use each row of hashed integers to index into the embedding weight matrix.
# For each word, gather the embedding rows selected by its hashed indices,
# then stack the per-word results into a single tensor.
output = torch.stack([be.emb.weight[idx] for idx in hashed])
print(output.shape)
output[:2]
Finally, we sum up the embedding rows. Above, each word is represented by four rows of the embedding matrix. After summing, we get a single vector for each word.
# Sum over the hashed-rows axis (-2) so each word gets a single vector.
output = output.sum(-2)
output
Notice that the values now match the output of our embedding layer.
# The manual computation should match the layer's output for the first row.
assert torch.isclose(output, y[0]).all()
Axial encodings are intended to work as positional embeddings for transformer-like architectures. They might also work for word embeddings, similar to our use of Bloom embeddings. However, the standard version of axial encodings produces similar vectors for adjacent indices. This makes some sense for positional indices, but for word indices it might require some additional preprocessing. For example, we could compress word embeddings down to one dimension and sort them. Alternatively, we could simply sort words by their number of occurrences in our corpus, which could be considered to be doing the same thing. Large chunks of the output vectors will be shared among different inputs, whereas Bloom embeddings seem like they would have a greater capacity to avoid this issue.
def reduction_ratio(ax, vocab_size, emb_dim):
    """Print how the weight count of an axial encoding compares to that of a
    traditional (vocab_size x emb_dim) embedding matrix. For testing purposes.

    Parameters
    ----------
    ax: object whose `emb` attribute is an iterable of layers that each have
        a `weight` tensor (e.g. AxialEncoding or MultiAxialEncoding).
    vocab_size: int
    emb_dim: int

    Returns
    -------
    None. Prints the two weight counts, their difference, and their ratio.
    """
    dense_count = vocab_size * emb_dim
    axial_count = 0
    for layer in ax.emb:
        axial_count += layer.weight.numel()
    summary = (
        ('Normal embedding weights:', dense_count),
        ('Axial encoding weights:', axial_count),
        ('Difference:', dense_count - axial_count),
    )
    for label, value in summary:
        print(label, value)
    # Computed last so the counts above are printed before any division.
    print('Ratio:', dense_count / axial_count)
# Baseline: one AxialEncoding for a 30k-word vocab with 100-dim outputs,
# applied to a (bs, 2) batch of random indices.
vocab_size = 30_000
emb_dim = 100
bs = 12
ax = AxialEncoding(vocab_size, emb_dim)
x = torch.randint(0, vocab_size, (bs, 2))
print(x.shape)
ax
res = ax(x)
print(res.shape)
reduction_ratio(ax, vocab_size, emb_dim)
# MultiAxialEncoding with the same sizes. The third argument (4) is presumably
# the number of blocks -- confirm. The output is kept as res1 so it can be
# compared with the pre-hashed variant.
vocab_size = 30_000
emb_dim = 100
bs = 12
ax = MultiAxialEncoding(vocab_size, emb_dim, 4)
x = torch.randint(0, vocab_size, (bs, 2))
print(x.shape)
ax
res1 = ax(x)
res1.shape
# Pre-hashed variant. pre_hashed=True presumably means the layer expects
# already-hashed indices instead of raw vocab indices -- confirm.
vocab_size = 30_000
emb_dim = 100
bs = 12
ax_pre = MultiAxialEncoding(vocab_size, emb_dim, 4, pre_hashed=True)
ax_pre
By setting the weights of our pre-hashed embedding to the weights of our hashing embedding, we can check that the outputs are ultimately the same.
# Point each pre-hashed embedding at the hashing encoder's weight tensor.
# Assigning `.data` shares the tensor rather than copying it.
for e, e_pre in zip(ax.emb, ax_pre.emb):
    e_pre.weight.data = e.weight.data
# Hash the raw indices manually, then feed them to the pre-hashed encoder.
# NOTE(review): 14 and 4 presumably must match the hashing settings that
# MultiAxialEncoding uses internally -- confirm.
xhash = probabilistic_hash_tensor(x, 14, 4)
res2 = ax_pre(xhash)
res2.shape
# Should be all True if both paths compute the same outputs.
(res1 == res2).all()
reduction_ratio(ax_pre, vocab_size, emb_dim)
I imagine that as we increase `n_blocks`, there's likely a point where we simply won't have enough weights to encode the amount of information that's present in the data. It would take some experimentation to find where that line is, however.
# Same pre-hashed encoding with 8 as the third argument instead of 4.
ax_large = MultiAxialEncoding(vocab_size, emb_dim, 8, pre_hashed=True)
ax_large
reduction_ratio(ax_large, vocab_size, emb_dim)
# n random image batches, each of shape (bs, c, h, w), to feed to a
# Siamese-style network. `smap` (htools) presumably maps over them to show
# their shapes.
bs, c, h, w = 4, 3, 8, 8
n = 3
xb = [torch.randn(bs, c, h, w) for _ in range(n)]
smap(*xb)
class TripletNet(SiameseBase):
    """Toy encoder used to demonstrate SiameseBase. It applies a strided
    conv followed by global average pooling and prints the tensor shape
    after each step."""

    def __init__(self, c_in=3):
        super().__init__()
        self.conv = nn.Conv2d(c_in, 16, kernel_size=3, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward(self, xb):
        """Encode `xb` into 16-dim feature vectors, printing shapes along
        the way."""
        print(xb.shape)
        for layer in (self.conv, self.pool):
            xb = layer(xb)
            print(xb.shape)
        # Pooling to (1, 1) leaves two trailing singleton dims; drop both.
        xb = xb[..., 0, 0]
        print(xb.shape)
        return xb
In this example, each image is encoded as a 16D vector. We have 3 images per row and 4 rows per batch, so we end up with a tensor of shape (4, 3, 16). Notice we only perform one forward pass. We could simply define a separate encoder and pass each image through it separately (e.g. `[self.encoder(x) for x in xb]`), but this becomes rather slow if n is large or if our encoder is enormous.
# Pass all n=3 image batches at once; per the note above, the expected output
# shape is (4, 3, 16).
tnet = TripletNet()
yh = tnet(*xb)
yh.shape
The name `TripletNet` is slightly misleading here: the network can actually handle any choice of n. For instance, here we use it as a Siamese network with n=2.
# Use only the first two image batches (n=2, a Siamese setup).
yh = tnet(*xb[:2])
yh.shape