def _build(self):
    """Build the embedding + bidirectional-RNN classifier graph.

    Creates two attributes:
      * ``self.inputs`` -- int32 placeholder of shape [B, S] (batch and
        sequence length both dynamic) holding token ids.
      * ``self.logit``  -- float32 tensor of per-example logits, [B, C].

    Pipeline:
      1. Under the "CBOW" variable scope, look up embeddings for
         ``self.inputs`` via ``self._get_embedding`` -> [B, S, M].
      2. Under the "RNN" variable scope, build forward/backward cells via
         ``self._get_rnn_cell`` and run ``tf.nn.bidirectional_dynamic_rnn``.
      3. Add the forward and backward outputs and average them over the
         time axis (axis=1) to get one logit vector per example.
    """
    V = self.V                          # presumably vocabulary size; passed to _get_embedding
    M = self.flags.embedding_size       # embedding dimension (original note: 64)
    H = self.flags.num_units            # passed to each cell as "num_units"
    C = self.flags.classes              # number of classes; passed as "num_proj"

    netname = "CBOW"
    with tf.variable_scope(netname):
        # [B, S] token ids.
        self.inputs = tf.placeholder(dtype=tf.int32, shape=[None, None])
        layer_name = "{}/embedding".format(netname)
        x = self._get_embedding(layer_name, self.inputs, V, M, reuse=False)  # [B, S, M]

    netname = "RNN"
    cell_name = self.flags.cell
    with tf.variable_scope(netname):
        # Presumably num_proj=C makes each cell emit C units per step, so each
        # direction's output is [B, S, C] -- TODO confirm in _get_rnn_cell.
        args = {"num_units": H, "num_proj": C}
        cell_f = self._get_rnn_cell(cell_name=cell_name, args=args)
        cell_b = self._get_rnn_cell(cell_name=cell_name, args=args)
        (out_f, out_b), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_f, cell_b, x, dtype=tf.float32)
        # NOTE(review): no sequence_length is passed, so any padded time steps
        # are included in the mean below -- confirm inputs are unpadded or
        # that masking is handled elsewhere.
        # reduce_mean over axis 1 already yields [B, C]. Do NOT tf.squeeze()
        # here without an axis: it would drop every size-1 dimension and
        # collapse the result to rank 1 when B == 1 (single-example inference)
        # or C == 1, breaking consumers that expect per-row logits.
        logit = tf.reduce_mean(out_f + out_b, axis=1)  # [B, C]
    self.logit = logit
评论列表
文章目录