def generate_batch(batch_size, num_skips, skip_window):
    """Build one CBOW-style batch of (surrounding words -> center word) examples.

    Reads the module-level ``data`` sequence starting at the module-level
    cursor ``data_index`` and advances that cursor (wrapping modulo
    ``len(data)``), so consecutive calls continue where the previous call
    stopped.

    Args:
        batch_size: Number of examples (rows) to produce.
        num_skips: Number of context words per example. Every example uses
            the full window around the center word, so this must equal
            ``2 * skip_window``.
        skip_window: Number of words taken on each side of the center word.

    Returns:
        A ``(batch, labels)`` tuple:

        - ``batch``: int32 array of shape ``(batch_size, num_skips)`` holding
          the surrounding words in window order.
        - ``labels``: int32 array of shape ``(batch_size, 1)`` holding the
          corresponding center words.

    Raises:
        ValueError: If ``num_skips != 2 * skip_window``. The check runs
            before ``data_index`` is touched, so a rejected call consumes no
            data.
    """
    global data_index
    context_size = 2 * skip_window
    # Each row receives *all* surrounding words (see the compress() below),
    # so any other column count cannot be filled.
    if num_skips != context_size:
        raise ValueError(
            f"num_skips ({num_skips}) must equal 2 * skip_window "
            f"({context_size}): each example uses the whole window"
        )
    span = context_size + 1  # [ skip_window  center  skip_window ]
    n_data = len(data)

    batch = np.empty((batch_size, num_skips), dtype=np.int32)
    labels = np.empty((batch_size, 1), dtype=np.int32)

    # Selector that keeps every window position except the center one,
    # e.g. [1, 0, 1] for skip_window=1. It never changes, so build it once.
    context_mask = [1] * span
    context_mask[skip_window] = 0

    # Prime the sliding window with the next `span` words.
    buffer = collections.deque(maxlen=span)
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % n_data

    for i in range(batch_size):
        batch[i, :] = list(compress(buffer, context_mask))  # surrounding words
        labels[i, 0] = buffer[skip_window]  # the word at the center
        # Slide the window one word right; maxlen drops the oldest word.
        buffer.append(data[data_index])
        data_index = (data_index + 1) % n_data
    return batch, labels
# Residue scraped from the source web page (not code):
# 评论列表 (Comment list)
# 文章目录 (Table of contents)