a2_encoder.py source code

python

Project: text_classification Author: brightmart
# Uses the TensorFlow 1.x graph API. Encoder and get_mask are defined elsewhere in the project.
import tensorflow as tf

def init():
    # 1. Assign values to the hyperparameters.
    vocab_size = 1000
    d_model = 512
    d_k = 64
    d_v = 64
    sequence_length = 5 * 10
    h = 8  # number of attention heads (d_model = h * d_k)
    batch_size = 4 * 32
    initializer = tf.random_normal_initializer(stddev=0.1)
    # 2. Set values for Q and K_s.
    embed_size = d_model
    Embedding = tf.get_variable("Embedding_E", shape=[vocab_size, embed_size], initializer=initializer)
    input_x = tf.placeholder(tf.int32, [batch_size, sequence_length], name="input_x")  # [batch_size, sequence_length] = [128, 50]
    print("input_x:", input_x)
    embedded_words = tf.nn.embedding_lookup(Embedding, input_x)  # [batch_size, sequence_length, embed_size]
    Q = embedded_words  # [batch_size, sequence_length, embed_size]
    K_s = embedded_words  # [batch_size, sequence_length, embed_size]
    num_layer = 6
    mask = get_mask(batch_size, sequence_length)
    # 3. Build the encoder object.
    encoder_class = Encoder(d_model, d_k, d_v, sequence_length, h, batch_size, num_layer, Q, K_s, mask=mask)
    return encoder_class, Q, K_s
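
The embedding lookup in init() takes a 2-D batch of token ids and returns one embedding row per token, so Q and K_s end up with rank 3: [batch_size, sequence_length, embed_size]. The sketch below reproduces only that shape logic in plain Python, without TensorFlow. The helper names embedding_lookup and shape are hypothetical and used here for illustration only; they are not part of the project.

```python
import random


def embedding_lookup(table, ids):
    """Gather the row of `table` for every id in a 2-D list of ids."""
    return [[table[token_id] for token_id in row] for row in ids]


def shape(nested):
    """Return the dimensions of a regular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# Same sizes as in init().
vocab_size, embed_size = 1000, 512
batch_size, sequence_length = 4 * 32, 5 * 10

rng = random.Random(0)
# Embedding table: one row of embed_size floats per vocabulary entry.
table = [[rng.gauss(0.0, 0.1) for _ in range(embed_size)]
         for _ in range(vocab_size)]
# A batch of token ids, standing in for the input_x placeholder.
ids = [[rng.randrange(vocab_size) for _ in range(sequence_length)]
       for _ in range(batch_size)]

embedded = embedding_lookup(table, ids)
print(shape(ids))       # [128, 50]
print(shape(embedded))  # [128, 50, 512]
```

Each token id is replaced by its embedding row, which adds a trailing embed_size dimension to the input instead of flattening the batch and sequence dimensions together.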