from collections import abc, Counter
from functools import lru_cache
from itertools import chain
import matplotlib.pyplot as plt
import mmh3
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from htools import hdir, eprint, assert_raises, flatten
sents = [
'I walked to the store so I hope it is not closed.',
'The theater is closed today and the sky is grey.',
'His dog is brown while hers is grey.'
]
labels = [0, 1, 1]
class Data(Dataset):
def __init__(self, sentences, labels, seq_len):
x = [s.split(' ') for s in sentences]
self.w2i = self.make_w2i(x)
self.seq_len = seq_len
self.x = self.encode(x)
self.y = torch.tensor(labels)
def __getitem__(self, i):
return self.x[i], self.y[i]
def __len__(self):
return len(self.y)
def make_w2i(self, tok_rows):
return {k: i for i, (k, v) in
enumerate(Counter(chain(*tok_rows)).most_common(), 1)}
def encode(self, tok_rows):
enc = np.zeros((len(tok_rows), self.seq_len), dtype=int)
for i, row in enumerate(tok_rows):
trunc = [self.w2i.get(w, 0) for w in row[:self.seq_len]]
enc[i, :len(trunc)] = trunc
return torch.tensor(enc)
ds = Data(sents, labels, 10)
ds[1]
dl = DataLoader(ds, batch_size=3)
x, y = next(iter(dl))
x, y
x.shape
ds.x
ds.w2i
For now, just convert the int to a str and take the hash of that. Another option that is designed for ints is Knuth's multiplicative method:
hash(i) = (i * 2654435761) mod 2^32
But we'd need to make this dependent on a random seed so that we can take several different hashes of the same word.
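A rough sketch of what a seeded variant might look like (hypothetical helper, not used below; it just derives a different odd multiplier from the seed, which is a crude stand-in for a properly seeded hash):
def knuth_hash(i, n_buckets, seed=0):
    # 2654435761 is Knuth's multiplicative constant. Adding 2*seed keeps the
    # multiplier odd while making the result depend on the seed.
    multiplier = (2654435761 + 2 * seed) & 0xFFFFFFFF
    # Multiply, truncate to 32 bits (the "mod 2^32" step), then bucket.
    return ((i * multiplier) & 0xFFFFFFFF) % n_buckets
[knuth_hash(23, 11, seed=s) for s in range(3)]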
def probabilistic_hash_item(x, n_buckets, mode=int, n_hashes=3):
"""Slightly hacky way to probabilistically hash an integer by
first converting it to a string.
Parameters
----------
    x: int or str
        The integer or string to hash.
    n_buckets: int
        The number of buckets that items will be mapped to. Typically the
        modulo step would occur outside the hashing function, but the
        intended use case here is narrow enough that it makes sense to me
        to include it in this function.
mode: type
The type of input you want to hash. This is user-provided to prevent
accidents where we pass in a different item than intended and hash
the wrong thing. One of (int, str). When using this inside a
BloomEmbedding layer, this must be `int` because there are no
string tensors. When used inside a dataset or as a one-time
pre-processing step, you can choose either as long as you
pass in the appropriate inputs.
n_hashes: int
The number of times to hash x, each time with a different seed.
Returns
-------
list[int]: A list of integers with length `n_hashes`, where each integer
is in [0, n_buckets).
"""
# Check type to ensure we don't accidentally hash Tensor(5) instead of 5.
assert isinstance(x, mode), f'Input `x` must have type {mode}.'
return [mmh3.hash(str(x), i, signed=False) % n_buckets
for i in range(n_hashes)]
def probabilistic_hash_tensor(x_r2, n_buckets, n_hashes=3, pad_idx=0):
"""Hash a rank 2 LongTensor.
Parameters
----------
x_r2: torch.LongTensor
Rank 2 tensor of integers. Shape: (bs, seq_len)
n_buckets: int
Number of buckets to hash items into (i.e. the number of
rows in the embedding matrix). Typically a moderately large
prime number, like 251 or 997.
n_hashes: int
Number of hashes to take for each input index. This determines
the number of rows of the embedding matrix that will be summed
to get the representation for each word. Typically 2-5.
pad_idx: int or None
If you want to pad sequences with vectors of zeros, pass in an
integer (same as the `padding_idx` argument to nn.Embedding).
If None, no padding index will be used. The sequences must be
padded before passing them into this function.
Returns
-------
torch.LongTensor: Tensor of indices where each row corresponds
to one of the input indices. Shape: (bs, seq_len, n_hashes)
"""
return torch.tensor(
[[probabilistic_hash_item(x.item(), n_buckets, int, n_hashes)
if x != pad_idx else [pad_idx]*n_hashes for x in row]
for row in x_r2]
)
for i, row in enumerate(x):
print(i)
print(row)
print([x.item() for x in row], end='\n\n')
probabilistic_hash_tensor(x, 11)
probabilistic_hash_tensor(x[0, None], 11)
probabilistic_hash_tensor(x[2, None], 11)
x
[probabilistic_hash_item(n.item(), 11) for n in x[0]]
for i in range(0, 200, 17):
print(probabilistic_hash_item(i, 11))
for row in [s.split(' ') for s in sents]:
eprint(list(zip(row, (probabilistic_hash_item(word, 11, str) for word in row))))
print()
assert isinstance(np.array(sents), (list, np.ndarray)), 'np array'
assert isinstance(sents, (list, np.ndarray)), 'list'
with assert_raises(AssertionError):
assert isinstance(x, (list, np.ndarray)), 'torch tensor'
for t in (abc.Iterable, abc.Collection, abc.Container, abc.Sequence, abc.MutableSequence):
tname = t.__name__
print(tname, 'tensor', isinstance(x, t))
print(tname, 'list', isinstance(sents, t))
print(tname, 'array', isinstance(np.array(sents), t))
print()
emb = nn.EmbeddingBag(5, 4, mode='sum')
emb.weight
x = torch.Tensor([[1, 2, 1],
[3, 2, 0]]).long()
emb(x)
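Quick sanity check that an EmbeddingBag with mode 'sum' matches a plain Embedding (sharing the same weights) followed by a sum over the sequence dimension; this is essentially what the BloomEmbedding below does with its hashed indices. The `plain` layer here is just for illustration:
plain = nn.Embedding(5, 4)
plain.weight.data = emb.weight.data.clone()
# Summing the per-token embeddings over dim 1 should reproduce emb(x).
torch.allclose(emb(x), plain(x).sum(1))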
class BloomEmbedding(nn.Module):
"""Bloom Embedding layer for memory-efficient word representations.
Each word is encoded by a combination of rows of the embedding
matrix. The number of rows can therefore be far lower than the number
of words in our vocabulary while still providing unique representations.
The reduction in rows allows us to use memory in other ways:
a larger embedding dimension, more or larger layers after the embedding,
or larger batch sizes.
"""
def __init__(self, n_emb=251, emb_dim=100, n_hashes=4, padding_idx=0,
pre_hashed=False):
"""
Parameters
----------
n_emb: int
Number of rows to create in the embedding matrix. A prime
number is recommended. Lower numbers will be more
memory-efficient but increase the chances of collisions.
emb_dim: int
Size of each embedding. If emb_dim=100, each word will
be represented by a 100-dimensional vector.
n_hashes: int
This determines the number of hashes that will be taken
for each word index, and as a result, the number of rows
that will be summed to create each unique representation.
The higher the number, the lower the chances of a collision.
padding_idx: int or None
If an integer is provided, this will set aside the corresponding
row in the embedding matrix as a vector of zeros. If None, no
padding vector will be allocated.
pre_hashed: bool
Pass in True if the input tensor will already be hashed by the time
            it enters this layer (you may prefer to pre-compute the hashes in the
Dataset to save computation time during training). In this
scenario, the layer is a simple embedding bag with mode "sum".
Pass in False if the inputs will be word indices that have not yet
been hashed. In this case, hashing will be done inside the `forward`
call.
Suggested values for a vocab size of ~30,000:
| n_emb | n_hashes | unique combos |
|-------|----------|---------------|
| 127 | 5 | 29,998 |
| 251 | 4 | 29,996 |
| 997 | 3 | 29,997 |
| 5,003 | 2 | 29,969 |
"""
super().__init__()
self.n_emb = n_emb
self.emb = nn.Embedding(n_emb, emb_dim, padding_idx=padding_idx)
self.n_hashes = n_hashes
self.pad_idx = padding_idx
self.pre_hashed = pre_hashed
def forward(self, x):
"""
Parameters
----------
x: torch.LongTensor
            Input tensor of word indices (bs x seq_len) if pre_hashed is False,
            or hashed indices (bs x seq_len x n_hashes) if pre_hashed is True.
Returns
-------
torch.FloatTensor: Words encoded with combination of embeddings.
(bs x seq_len x emb_dim)
"""
        if self.pre_hashed:
            hashed = x
        else:
            # (bs, seq_len) -> hash -> (bs, seq_len, n_hashes)
            hashed = probabilistic_hash_tensor(x,
                                               self.n_emb,
                                               self.n_hashes,
                                               self.pad_idx)
        # (bs, seq_len, n_hashes, emb_dim) -> sum -> (bs, seq_len, emb_dim)
        return self.emb(hashed).sum(-2)
x, y = next(iter(dl))
x, y
be = BloomEmbedding(11, 4)
be.emb.weight
x
for i in range(24):
print(probabilistic_hash_item(i, 11))
# (bs x seq_len) -> (bs x seq_len x emb_size)
y = be(x)
y.shape
y[0]
y[1]
y[2]
for w, i in ds.w2i.items():
print(w, i, be(torch.tensor([[i]])).detach().numpy().squeeze())
# .emb.weight[hash_int(i, be.n_emb)])
hashed = probabilistic_hash_tensor(torch.tensor([23]).unsqueeze(0), 11, 4)
hashed
be.emb.weight[hashed].sum(2)
def unique_combos(tups):
return len(set(tuple(sorted(x)) for x in tups))
def hash_all_idx(vocab_size, n_buckets, n_hashes):
return [probabilistic_hash_item(i, n_buckets, int, n_hashes)
for i in range(vocab_size)]
buckets2hashes = {127: 5,
251: 4,
997: 3,
5_003: 2}
for b, h in buckets2hashes.items():
tups = hash_all_idx(30_000, b, h)
unique = unique_combos(tups)
print('Buckets:', b, 'Hashes:', h, 'Unique combos:', unique,
'% unique:', round(unique/30_000, 5))
def eval_n_buckets(vocab_size, hash_sizes, bucket_sizes):
for bs in bucket_sizes:
for hs in hash_sizes:
tups = hash_all_idx(vocab_size, bs, hs)
unique = unique_combos(tups)
print('buckets:', bs,
'hashes:', hs,
                  '% unique:', round(unique/vocab_size, 4))
eval_n_buckets(80, range(2, 6), [5, 11, 13, 19, 29, 37])
x = torch.randint(0, 30_000, (64, 500))
x.shape
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 127, 5)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 251, 5)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 997, 5)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 5_003, 5)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 251, 4)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 997, 3)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 5_003, 2)
plt.scatter(range(2, 6), [284, 303, 318, 339], s=[127, 251, 997, 5_003])
plt.xlabel('# of Hashes')
plt.ylabel('Time (ms) \nto encode \na batch with \n32,000 words',
rotation=0, labelpad=40)
plt.show()
Try memoizing the hash function.
Caching results does save computation, but it increases memory usage, and low memory usage is one of the main benefits of Bloom Embeddings. However, this is probably still more memory efficient than a full vocabulary embedding, since each word in the cache is represented by only a few integer indices rather than a large embedding vector. A large embedding matrix also forces us to devote a lot of memory to largely unused gradients during training (if the layer is not frozen). Still, the fact that probabilistic_hash_tensor is not vectorized means it may be best to pre-compute the hashes and load the indices in the Dataset. Rather than caching, it may be easiest to create a w2hash dict similar to a standard w2index dict (sketched below). Leaving the implementation as-is lets the user decide what tradeoff they want between speed and memory.
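For instance, a rough sketch of that w2hash idea using the vocab from the Data class above (the bucket and hash counts are just taken from the table earlier; this dict isn't used elsewhere in the notebook):
w2hash = {w: probabilistic_hash_item(i, 251, int, 4)
          for w, i in ds.w2i.items()}
list(w2hash.items())[:3]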
@lru_cache(maxsize=30_000)
def probabilistic_hash_item(x, n_buckets, mode=int, n_hashes=3):
"""Slightly hacky way to probabilistically hash an integer by
first converting it to a string.
Parameters
----------
    x: int or str
        The integer or string to hash.
    n_buckets: int
        The number of buckets that items will be mapped to. Typically the
        modulo step would occur outside the hashing function, but the
        intended use case here is narrow enough that it makes sense to me
        to include it in this function.
mode: type
The type of input you want to hash. This is user-provided to prevent
accidents where we pass in a different item than intended and hash
the wrong thing. One of (int, str). When using this inside a
BloomEmbedding layer, this must be `int` because there are no
string tensors. When used inside a dataset or as a one-time
pre-processing step, you can choose either as long as you
pass in the appropriate inputs.
n_hashes: int
The number of times to hash x, each time with a different seed.
Returns
-------
list[int]: A list of integers with length `n_hashes`, where each integer
is in [0, n_buckets).
"""
# Check type to ensure we don't accidentally hash Tensor(5) instead of 5.
assert isinstance(x, mode), f'Input `x` must have type {mode}.'
return [mmh3.hash(str(x), i, signed=False) % n_buckets
for i in range(n_hashes)]
def probabilistic_hash_tensor(x_r2, n_buckets, n_hashes=3, pad_idx=0):
"""Hash a rank 2 LongTensor.
Parameters
----------
x_r2: torch.LongTensor
Rank 2 tensor of integers. Shape: (bs, seq_len)
n_buckets: int
Number of buckets to hash items into (i.e. the number of
rows in the embedding matrix). Typically a moderately large
prime number, like 251 or 997.
n_hashes: int
Number of hashes to take for each input index. This determines
the number of rows of the embedding matrix that will be summed
to get the representation for each word. Typically 2-5.
pad_idx: int or None
If you want to pad sequences with vectors of zeros, pass in an
integer (same as the `padding_idx` argument to nn.Embedding).
If None, no padding index will be used. The sequences must be
padded before passing them into this function.
Returns
-------
torch.LongTensor: Tensor of indices where each row corresponds
to one of the input indices. Shape: (bs, seq_len, n_hashes)
"""
return torch.tensor(
[[probabilistic_hash_item(x.item(), n_buckets, int, n_hashes)
if x != pad_idx else [pad_idx]*n_hashes for x in row]
for row in x_r2]
)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 5_003, 5)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 5_003, 2)
%%timeit -n 5 -r 5
hashed = probabilistic_hash_tensor(x, 5_003, 2)
probabilistic_hash_item.cache_info()
len(set(flatten(x.numpy())))
probabilistic_hash_item.cache_clear()
probabilistic_hash_item.cache_info()