simple_sentence_breaker.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:learning_rnn 作者: qiangsiwei 项目源码 文件源码
def train_breaker(datafilename, sentence_num=1000, puncs=u',?.?!???',
                  RNN=recurrent.GRU, HIDDEN_SIZE=128, EPOCH_SIZE=10, validate=True):
    """Train a character-level sentence-break tagger and save it to disk.

    Every kept line of ``datafilename`` becomes one training sequence: the
    characters left after removing punctuation are the inputs, and each such
    character is labelled 1 if a run of punctuation followed it in the
    original line (a "break") and 0 otherwise.  A single ``RNN`` layer that
    returns the whole sequence feeds a per-timestep 2-way softmax.

    Args:
        datafilename: Path to a UTF-8 text file with one sample per line.
        sentence_num: Only the first ``sentence_num`` lines are read.  The
            value is also passed to ``WordTable.parse``.
        puncs: Characters treated as break punctuation.  NOTE(review): the
            default looks mis-encoded (non-ASCII punctuation seems to have
            been replaced by '?'); confirm the intended characters upstream.
        RNN: Recurrent layer class, called as
            ``RNN(HIDDEN_SIZE, return_sequences=True)``.
        HIDDEN_SIZE: Number of units in the recurrent layer.
        EPOCH_SIZE: Number of passes over the training sequences.
        validate: If true, hold out the last 10% of the kept sequences and
            print break / no-break confusion counts after training.

    Side effects:
        Prints progress to stdout and writes ``wordtable_json.txt``,
        ``model_json.txt`` and ``model_weights.h5`` into the current working
        directory, overwriting any existing files.
    """
    wordtable = WordTable()
    wordtable.parse(datafilename, sentence_num)

    X, Y = _load_breaker_samples(datafilename, sentence_num, puncs, wordtable)
    # A single parenthesised expression prints identically under Python 2 and
    # 3; the two spaces reproduce the old ``print 'label: ', value`` output.
    print('total sentence:  {0}'.format(len(X)))

    if validate:
        # Set apart 10% for validation (floor division on both Python 2 and 3).
        split_at = len(X) - len(X) // 10
        X_train, X_val = X[:split_at], X[split_at:]
        y_train, y_val = Y[:split_at], Y[split_at:]
    else:
        X_train, y_train = X, Y

    model = _build_breaker_model(RNN, HIDDEN_SIZE, wordtable.capacity)

    for epoch in range(EPOCH_SIZE):
        print('epoch:  {0}'.format(epoch))
        # Sequences have different lengths, so each one is its own batch.
        for idx, (seq, label) in enumerate(zip(X_train, y_train)):
            loss, accuracy = model.train_on_batch(
                {'input': np.array([seq]), 'output': np.array([label])},
                accuracy=True)
            if idx % 20 == 0:
                print("\tidx={0}, loss={1}, accuracy={2}".format(idx, loss, accuracy))

    if validate:
        _report_break_confusion(model, X_val, y_val)

    with open('wordtable_json.txt', 'w') as wordtable_file:
        wordtable_file.write(wordtable.to_json())
    with open('model_json.txt', 'w') as model_file:
        model_file.write(model.to_json())
    model.save_weights('model_weights.h5', overwrite=True)


def _load_breaker_samples(datafilename, sentence_num, puncs, wordtable):
    """Build one-hot ``(X, Y)`` training lists from the first lines of the data file.

    Only lines with 30..50 encoded characters and at least 4 breaks are kept.
    ``X[i]`` is a bool array of shape ``(len(words), wordtable.capacity)`` and
    ``Y[i]`` a bool array of shape ``(len(breaks), 2)``.
    """
    # re.escape stops characters such as ']', '^', '-' or a backslash in
    # ``puncs`` from corrupting the character classes; compile once, not per line.
    escaped = re.escape(puncs)
    edge_re = re.compile(u'(^[{0}]+)|([{0}]+$)'.format(escaped))
    punc_re = re.compile(u'[{0}]'.format(escaped))
    non_punc_re = re.compile(u'[^{0}]'.format(escaped))
    break_re = re.compile(u'0[{0}]+'.format(escaped))

    # io.open yields unicode lines on Python 2 and 3, and the ``with`` block
    # closes the handle (a bare open(...).readlines() leaks it).
    with io.open(datafilename, encoding='utf-8') as data_file:
        lines = data_file.readlines()[:sentence_num]

    X, Y = [], []
    for raw in lines:
        # Drop punctuation at either end so every remaining punctuation run is
        # preceded by a character that can carry the break label.
        line = edge_re.sub(u'', raw.strip())
        words = wordtable.encode(punc_re.sub(u'', line))
        # Mark each non-punctuation char '0', then collapse "char + punctuation
        # run" into '1': one label per non-punctuation character.
        breaks = break_re.sub(u'1', non_punc_re.sub(u'0', line))
        if not (30 <= len(words) <= 50 and breaks.count(u'1') >= 4):
            continue
        # NOTE(review): assumes wordtable.encode yields one integer index per
        # character, so words[i] and breaks[i] describe the same character.
        x = np.zeros((len(words), wordtable.capacity), dtype=bool)
        y = np.zeros((len(breaks), 2), dtype=bool)
        for idx, word in enumerate(words):
            x[idx][word] = True
            y[idx][int(breaks[idx])] = True
        X.append(x)
        Y.append(y)
    return X, Y


def _build_breaker_model(RNN, hidden_size, capacity):
    """Compile the RNN -> time-distributed 2-way softmax Graph model."""
    model = Graph()
    model.add_input(name='input', input_shape=(None, capacity))
    model.add_node(RNN(hidden_size, return_sequences=True), name='forward', input='input')
    model.add_node(TimeDistributedDense(2, activation='softmax'), name='softmax', input='forward')
    model.add_output(name='output', input='softmax')
    model.compile('adam', {'output': 'categorical_crossentropy'})
    return model


def _report_break_confusion(model, X_val, y_val):
    """Print per-character confusion counts for break (1) vs no-break (0) labels."""
    truth, pred = [], []
    for seq, label in zip(X_val, y_val):
        truth.extend(label.argmax(axis=-1))
        output = model.predict({'input': np.array([seq])})['output']
        pred.extend(output[0].argmax(axis=-1))
    truth, pred = np.array(truth), np.array(pred)
    print('should break right:  {0}'.format(((pred == 1) & (truth == 1)).sum()))
    print('should break wrong:  {0}'.format(((pred == 0) & (truth == 1)).sum()))
    print('should not break right:  {0}'.format(((pred == 0) & (truth == 0)).sum()))
    print('should not break wrong:  {0}'.format(((pred == 1) & (truth == 0)).sum()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号