python类reset_default_graph()的实例源码

unet.py 文件源码 项目:lung-cancer-detector 作者: YichenGong 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs=None, **kwargs):
        """Build the network graph: placeholders, logits, cost and metrics.

        The default TensorFlow graph is reset before anything is created.

        :param channels: size of the last axis of the input placeholder ``x``
        :param n_class: size of the last axis of the label placeholder ``y``
        :param cost: cost selector, forwarded to ``self._get_cost``
        :param cost_kwargs: extra options forwarded to ``self._get_cost``;
            ``None`` (the default) is treated as an empty dict
        :param kwargs: ``"summaries"`` is read here (default ``True``); all
            keyword arguments are also forwarded to ``create_conv_net``
        """
        # A fresh dict per call: a ``{}`` default in the signature would be a
        # single object shared by every instance.
        if cost_kwargs is None:
            cost_kwargs = {}

        tf.reset_default_graph()

        self.n_class = n_class
        self.summaries = kwargs.get("summaries", True)

        self.x = tf.placeholder("float", shape=[None, None, None, channels])
        self.y = tf.placeholder("float", shape=[None, None, None, n_class])
        self.keep_prob = tf.placeholder(tf.float32) #dropout (keep probability)

        logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs)

        self.cost = self._get_cost(logits, cost, cost_kwargs)

        self.gradients_node = tf.gradients(self.cost, self.variables)

        # Labels and softmax output are flattened to rows of n_class values
        # before the cross entropy is averaged.
        self.cross_entropy = tf.reduce_mean(cross_entropy(tf.reshape(self.y, [-1, n_class]),
                                                          tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class])))

        self.predicter = pixel_wise_softmax_2(logits)
        # Prediction and label are compared via argmax over axis 3.
        self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
policy_agent.py 文件源码 项目:DeepPath 作者: xwhan 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def retrain():
    """Restore the supervised policy, run REINFORCE on it, and save the result.

    Training pairs are the lines of the file at ``relationPath``. At most 300
    episodes are run. The restored checkpoint is
    ``'models/policy_supervised_' + relation`` and the retrained model is
    saved to ``'models/policy_retrained' + relation``.
    """
    print('Start retraining')
    tf.reset_default_graph()
    policy_network = PolicyNetwork(scope='supervised_policy')

    # Context manager closes the file even if reading fails.
    with open(relationPath) as f:
        training_pairs = f.readlines()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'models/policy_supervised_' + relation)
        print("sl_policy restored")
        # Cap the number of episodes at 300.
        episodes = min(len(training_pairs), 300)
        REINFORCE(training_pairs, policy_network, episodes)
        saver.save(sess, 'models/policy_retrained' + relation)
    print('Retrained model saved')
shrinkage.py 文件源码 项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def show_shrinkage(shrink_func,theta,**kwargs):
    """Plot the output of ``shrink_func`` over a linear ramp of inputs.

    :param shrink_func: called as ``shrink_func(r, rvar, theta)`` with tensors
        produced by ``tfcf``; must return a pair whose first element is the
        estimate that gets plotted
    :param theta: parameters passed (via ``tfcf``) to ``shrink_func``
    :param kwargs: optional settings --
        ``seed`` (default 1): TensorFlow graph-level random seed;
        ``N`` (default 500) and ``L`` (default 4): the ramp has shape (N, L);
        ``sigmas`` (default 10): the ramp spans 0..sigmas*sqrt(rvar);
        ``title``: if present, used as the figure's suptitle
    """
    tf.reset_default_graph()
    tf.set_random_seed(kwargs.get('seed',1) )

    N = kwargs.get('N',500)
    L = kwargs.get('L',4)
    nsigmas = kwargs.get('sigmas',10)
    shape = (N,L)
    rvar = 1e-4
    # Inputs run linearly from 0 to nsigmas standard deviations (sqrt(rvar)).
    r = np.reshape( np.linspace(0,nsigmas,N*L)*math.sqrt(rvar),shape)
    r_ = tfcf(r)
    rvar_ = tfcf(np.ones(L)*rvar)

    # Only the estimate is evaluated; the second returned tensor is unused.
    xhat_,dxdr_ = shrink_func(r_,rvar_ ,tfcf(theta))

    with tf.Session() as sess:
        sess.run( tf.global_variables_initializer() )
        xhat = sess.run(xhat_)
    import matplotlib.pyplot as plt
    plt.figure(1)
    # Yellow: identity line for reference; blue: shrinkage output.
    plt.plot(r.reshape(-1),r.reshape(-1),'y')
    plt.plot(r.reshape(-1),xhat.reshape(-1),'b')
    # ``dict.has_key`` does not exist on Python 3; ``in`` works on both.
    if 'title' in kwargs:
        plt.suptitle(kwargs['title'])
    plt.show()
test_tf_util.py 文件源码 项目:combine-DT-with-NN-in-RL 作者: Burning-Bear 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_multikwargs():
    """Exercise ``function`` with two inputs, one of which has a given default."""
    tf.reset_default_graph()
    first = tf.placeholder(tf.int32, (), name="x")
    with tf.variable_scope("other"):
        # Requested with the same name "x", but inside the "other" scope.
        second = tf.placeholder(tf.int32, (), name="x")
    combined = 3 * first + 2 * second

    lin = function([first, second], combined, givens={second: 0})
    with single_threaded_session():
        initialize()
        # One positional argument: the second input falls back to 0.
        assert lin(2) == 6
        assert lin(2, 2) == 10
        # The keyword call is expected to raise AssertionError.
        raised = False
        try:
            lin(x=2)
        except AssertionError:
            raised = True
        assert raised
test_tf_util.py 文件源码 项目:combine-DT-with-NN-in-RL 作者: Burning-Bear 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def test_multikwargs():
    """Exercise ``function`` with two inputs, one of which has a given default."""
    tf.reset_default_graph()
    first = tf.placeholder(tf.int32, (), name="x")
    with tf.variable_scope("other"):
        # Requested with the same name "x", but inside the "other" scope.
        second = tf.placeholder(tf.int32, (), name="x")
    combined = 3 * first + 2 * second

    lin = function([first, second], combined, givens={second: 0})
    with single_threaded_session():
        initialize()
        # One positional argument: the second input falls back to 0.
        assert lin(2) == 6
        assert lin(2, 2) == 10
        # The keyword call is expected to raise AssertionError.
        raised = False
        try:
            lin(x=2)
        except AssertionError:
            raised = True
        assert raised
inception_v2_test.py 文件源码 项目:tf_classification 作者: visipedia 项目源码 文件源码 阅读 42 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v2 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size, num_classes = 2, 1000
    height = width = 224
    images = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
      expected_logits_shape = [batch_size, num_classes]
      self.assertListEqual(logits.get_shape().as_list(), expected_logits_shape)
      pre_pool = end_points['Mixed_5c']
      tf.global_variables_initializer().run()
      # Feeding concrete images fixes the spatial size of the end point.
      pre_pool_out = sess.run(pre_pool, feed_dict={inputs: images})
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
mobilenet_v1_test.py 文件源码 项目:tf_classification 作者: visipedia 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build mobilenet_v1 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size, num_classes = 2, 1000
    height = width = 224
    images = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
      expected_logits_shape = [batch_size, num_classes]
      self.assertListEqual(logits.get_shape().as_list(), expected_logits_shape)
      pre_pool = end_points['Conv2d_13_pointwise']
      tf.global_variables_initializer().run()
      # Feeding concrete images fixes the spatial size of the end point.
      pre_pool_out = sess.run(pre_pool, feed_dict={inputs: images})
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
inception_v3_test.py 文件源码 项目:tf_classification 作者: visipedia 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v3 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size, num_classes = 2, 1000
    height = width = 299
    images = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v3(inputs, num_classes)
      expected_logits_shape = [batch_size, num_classes]
      self.assertListEqual(logits.get_shape().as_list(), expected_logits_shape)
      pre_pool = end_points['Mixed_7c']
      tf.global_variables_initializer().run()
      # Feeding concrete images fixes the spatial size of the end point.
      pre_pool_out = sess.run(pre_pool, feed_dict={inputs: images})
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
inception_v1_test.py 文件源码 项目:tf_classification 作者: visipedia 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v1 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size, num_classes = 2, 1000
    height = width = 224
    images = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v1(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
      expected_logits_shape = [batch_size, num_classes]
      self.assertListEqual(logits.get_shape().as_list(), expected_logits_shape)
      pre_pool = end_points['Mixed_5c']
      tf.global_variables_initializer().run()
      # Feeding concrete images fixes the spatial size of the end point.
      pre_pool_out = sess.run(pre_pool, feed_dict={inputs: images})
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
plan_test.py 文件源码 项目:fold 作者: tensorflow 项目源码 文件源码 阅读 36 收藏 0 点赞 0 评论 0
def test_run_feed_dict(self):
    """Run a feed-dict plan, then check a second run does not log a new best model.

    The first plan (four examples, no loom input tensor) goes through
    ``check_plan``. The plan is then rebuilt in a fresh default graph with its
    loss total replaced by the constant 42.0 and run for one epoch; the
    captured log must not contain 'new best model saved'.
    """
    p = self.create_plan(loom_input_tensor=None)
    p.examples = [1] * 4
    self.check_plan(p, [])
    # Test that we don't clobber a better checkpoint with a worse one.
    # Start from a fresh graph and drop the cached test session before
    # building the second plan.
    tf.reset_default_graph()
    self._ClearCachedSession()
    p = self.create_plan(loom_input_tensor=None)
    p.examples = [1] * 4
    p.epochs = 1
    p._loss_total = tf.constant(42.0)
    # We aren't using a managed session, so we need to run this ourselves.
    init_op = tf.global_variables_initializer()
    sv = p.create_supervisor()
    with self.test_session() as sess:
      sess.run(init_op)
      p.run(sv, sess)
      log_str = p.print_file.getvalue()
      self.assertNotIn('new best model saved', log_str)
tflearn_seq2seq.py 文件源码 项目:tflearn_seq2seq 作者: ichuang 项目源码 文件源码 阅读 36 收藏 0 点赞 0 评论 0
def test_train_predict2():
    '''
    Test that the embedding_attention model works, with saving and loading of weights

    Trains for one epoch into a temporary data directory, checks that the
    weights file exists, then builds a second model and checks that the
    prediction for a 10-element input has length 10. The temporary directory
    is removed afterwards, whether or not the assertions pass.
    '''
    import shutil
    import tempfile
    sp = SequencePattern()
    tempdir = tempfile.mkdtemp()
    try:
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
        tf.reset_default_graph()
        ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
        assert os.path.exists(ts2s.weights_output_fn)

        tf.reset_default_graph()
        # NOTE(review): prediction uses data_dir="DATA", not tempdir -- confirm
        # this is the directory the weights are meant to be loaded from.
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir="DATA", name="attention", verbose=1)
        prediction, y = ts2s.predict(Xin=range(10), weights_input_fn=1)
        # Compare the length itself; ``len(prediction==10)`` only measured the
        # comparison result and never checked the expected length.
        assert len(prediction) == 10
    finally:
        # Same effect as ``rm -rf`` without going through a shell.
        shutil.rmtree(tempdir, ignore_errors=True)
tflearn_seq2seq.py 文件源码 项目:tflearn_seq2seq 作者: ichuang 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def test_train_predict3():
    '''
    Test that a model trained on sequences of one length can be used for predictions on other sequence lengths

    Trains with in/out sequence lengths 10/10, then predicts with a pattern of
    lengths 20/8 and checks that the prediction has length 8. The temporary
    directory is removed afterwards, whether or not the assertions pass.
    '''
    import shutil
    import tempfile
    sp = SequencePattern("sorted", in_seq_len=10, out_seq_len=10)
    tempdir = tempfile.mkdtemp()
    try:
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
        tf.reset_default_graph()
        ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
        assert os.path.exists(ts2s.weights_output_fn)

        sp = SequencePattern("sorted", in_seq_len=20, out_seq_len=8)
        # One reset right before the second model is built is sufficient.
        tf.reset_default_graph()
        # NOTE(review): prediction uses data_dir="DATA", not tempdir -- confirm
        # this is the directory the weights are meant to be loaded from.
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir="DATA", name="attention", verbose=1)
        x = np.random.randint(0, 9, 20)
        prediction, y = ts2s.predict(x, weights_input_fn=1)
        # Compare the length itself; ``len(prediction==8)`` only measured the
        # comparison result and never checked the expected length.
        assert len(prediction) == 8
    finally:
        # Same effect as ``rm -rf`` without going through a shell.
        shutil.rmtree(tempdir, ignore_errors=True)
tflearn_seq2seq.py 文件源码 项目:tflearn_seq2seq 作者: ichuang 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_main3():
    '''
    Integration test - training then prediction: attention model

    Trains through ``CommandLine`` writing weights to ``tmp_weights.tfl`` in
    the current directory (any existing file of that name is removed first),
    then predicts from those weights and checks the first prediction result
    has 10 elements.
    '''
    wfn = "tmp_weights.tfl"
    if os.path.exists(wfn):
        os.unlink(wfn)
    # The weights filename is defined once and substituted into both commands.
    arglist = ("-e 2 -o %s -v -v -v -v -m embedding_attention train 5000" % wfn).split(' ')
    tf.reset_default_graph()
    ts2s = CommandLine(arglist=arglist)
    assert os.path.exists(wfn)

    arglist = ("-i %s -v -v -v -v -m embedding_attention predict 1 2 3 4 5 6 7 8 9 0" % wfn).split(' ')
    tf.reset_default_graph()
    ts2s = CommandLine(arglist=arglist)
    assert len(ts2s.prediction_results[0][0])==10

#-----------------------------------------------------------------------------
test_tf_util.py 文件源码 项目:rl-attack-detection 作者: yenchenlin 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def test_multikwargs():
    """Exercise ``function`` with two inputs, one of which has a given default."""
    tf.reset_default_graph()
    first = tf.placeholder(tf.int32, (), name="x")
    with tf.variable_scope("other"):
        # Requested with the same name "x", but inside the "other" scope.
        second = tf.placeholder(tf.int32, (), name="x")
    combined = 3 * first + 2 * second

    lin = function([first, second], combined, givens={second: 0})
    with single_threaded_session():
        initialize()
        # One positional argument: the second input falls back to 0.
        assert lin(2) == 6
        assert lin(2, 2) == 10
        # The keyword call is expected to raise AssertionError.
        raised = False
        try:
            lin(x=2)
        except AssertionError:
            raised = True
        assert raised
inception_v2_test.py 文件源码 项目:isbi2017-part3 作者: learningtitans 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v2 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      # tf.initialize_all_variables() is deprecated; this is its replacement.
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
inception_v3_test.py 文件源码 项目:isbi2017-part3 作者: learningtitans 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v3 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size = 2
    height, width = 299, 299
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v3(inputs, num_classes)
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_7c']
      feed_dict = {inputs: input_np}
      # tf.initialize_all_variables() is deprecated; this is its replacement.
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
inception_v1_test.py 文件源码 项目:isbi2017-part3 作者: learningtitans 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def testUnknownImageShape(self):
    """Build inception_v1 on inputs with unknown height/width and check shapes."""
    tf.reset_default_graph()
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      # Spatial dimensions are left unspecified in the placeholder.
      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v1(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      # tf.initialize_all_variables() is deprecated; this is its replacement.
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
test_sgld.py 文件源码 项目:chemblnet 作者: jaak-s 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_sgld_sparse(self):
        """Check that one SGLD step on a gathered row changes only that row.

        A 5x2 zero variable is gathered at a fed index; after one training
        step with index 3, rows 0, 1, 2 and 4 must still be exactly zero and
        element [3, 0] must be positive.
        """
        tf.reset_default_graph()

        z     = tf.Variable(tf.zeros((5, 2)), dtype=tf.float32)
        idx   = tf.placeholder(tf.int32)
        zi    = tf.gather(z, idx)
        zloss = tf.square(zi - [10.0, 5.0])

        sgld = SGLD(learning_rate=0.4)
        train_op_sgld = sgld.minimize(zloss)

        sess = tf.InteractiveSession()
        try:
            sess.run(tf.global_variables_initializer())

            self.assertTrue(np.alltrue(sess.run(z) == 0.0))

            sess.run(train_op_sgld, feed_dict={idx: 3})
            zh = sess.run(z)
            self.assertTrue(np.alltrue(zh[[0, 1, 2, 4], :] == 0.0))
            self.assertTrue(zh[3, 0] > 0)
        finally:
            # An InteractiveSession installs itself as the default session;
            # close it so it does not outlive this test.
            sess.close()
def test_psgld_sparse(self):
        """Check that one pSGLD step on a gathered row changes only that row.

        A 5x2 zero variable is gathered at a fed index; after one training
        step with index 3, rows 0, 1, 2 and 4 must still be exactly zero and
        element [3, 0] must be positive.
        """
        tf.reset_default_graph()

        z     = tf.Variable(tf.zeros((5, 2)), dtype=tf.float32)
        idx   = tf.placeholder(tf.int32)
        zi    = tf.gather(z, idx)
        zloss = tf.square(zi - [10.0, 5.0])

        psgld = pSGLD(learning_rate=0.4)
        train_op_psgld = psgld.minimize(zloss)

        sess = tf.InteractiveSession()
        try:
            sess.run(tf.global_variables_initializer())

            self.assertTrue(np.alltrue(sess.run(z) == 0.0))

            sess.run(train_op_psgld, feed_dict={idx: 3})
            zh = sess.run(z)
            self.assertTrue(np.alltrue(zh[[0, 1, 2, 4], :] == 0.0))
            self.assertTrue(zh[3, 0] > 0)
        finally:
            # An InteractiveSession installs itself as the default session;
            # close it so it does not outlive this test.
            sess.close()
test_tf_util.py 文件源码 项目:baselines 作者: openai 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def test_multikwargs():
    """Exercise ``function`` with two inputs, one of which has a given default."""
    tf.reset_default_graph()
    first = tf.placeholder(tf.int32, (), name="x")
    with tf.variable_scope("other"):
        # Requested with the same name "x", but inside the "other" scope.
        second = tf.placeholder(tf.int32, (), name="x")
    combined = 3 * first + 2 * second

    lin = function([first, second], combined, givens={second: 0})
    with single_threaded_session():
        initialize()
        # One positional argument: the second input falls back to 0.
        assert lin(2) == 6
        assert lin(2, 2) == 10
        # The keyword call is expected to raise AssertionError.
        raised = False
        try:
            lin(x=2)
        except AssertionError:
            raised = True
        assert raised
test_tensorflow.py 文件源码 项目:tensoronspark 作者: liangfengsid 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_save_restore():
    """Restore a saved graph and re-save it, twice, each time in a fresh graph.

    Each round imports the meta graph at ``/tmp/tensor_saved_test2.meta``,
    restores the checkpoint into a new session, and saves it back to the same
    path with a saver from ``_get_saver(100)``. The checkpoint and meta file
    must already exist when the test starts.
    """
    path = '/tmp/tensor_saved_test2'
    meta_path = path + '.meta'

    # Two identical rounds; the second one reads what the first one wrote.
    for _ in range(2):
        tf.reset_default_graph()
        sess = tf.Session()
        try:
            restorer = tf.train.import_meta_graph(meta_path)
            restorer.restore(sess, path)

            saver = _get_saver(100)
            saver.save(sess, path)
        finally:
            # Release the session's resources even if restore/save fails.
            sess.close()
run_dqn_atari.py 文件源码 项目:rl_algorithms 作者: DanielTakeshi 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def get_session():
    """Reset the default graph and return a new session with a GPU memory cap.

    The session is configured with ``per_process_gpu_memory_fraction=0.5``.
    The result of ``get_available_gpus()`` is printed before returning.
    """
    tf.reset_default_graph()
    # NOTE(review): tf_config is built here but never passed to a session --
    # the only use is in the commented-out line below, so the single-thread
    # settings are not applied to the returned session. Confirm this is
    # intended.
    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1)

    # This was the default provided in the starter code.
    #session = tf.Session(config=tf_config)

    # Use this if I want to see what is on the GPU.
    #session = tf.Session(config=tf.ConfigProto(log_device_placement=True))

    # Use this for limiting memory allocated for the GPU.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    print("AVAILABLE GPUS: ", get_available_gpus())
    return session
rcnn_proposal_test.py 文件源码 项目:luminoth 作者: tryolabs 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def setUp(self):
        """Create the fixtures shared by the RCNNProposal tests."""
        super(RCNNProposalTest, self).setUp()

        self._num_classes = 3
        self._image_shape = (900, 1440)
        # Settings passed to RCNNProposal below.
        self._config = EasyDict({
            'class_max_detections': 100,
            'class_nms_threshold': 0.6,
            'total_max_detections': 300,
            'min_prob_threshold': 0.0,
        })

        # Presumably the tolerance for approximate comparisons in the tests --
        # TODO confirm against the test methods.
        self._equality_delta = 1e-03

        # The model object is constructed first; the default graph is reset
        # afterwards.
        self._shared_model = RCNNProposal(self._num_classes, self._config)
        tf.reset_default_graph()
rcnn_target_test.py 文件源码 项目:luminoth 作者: tryolabs 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def setUp(self):
        """Create the fixtures shared by the RCNNTarget tests."""
        super(RCNNTargetTest, self).setUp()

        # We don't care about the class labels or the batch number in most of
        # these tests.
        self._num_classes = 5
        self._placeholder_label = 3.

        # Settings passed to RCNNTarget below.
        self._config = EasyDict({
            'foreground_threshold': 0.5,
            'background_threshold_high': 0.5,
            'background_threshold_low': 0.1,
            'foreground_fraction': 0.5,
            'minibatch_size': 2,
        })
        # We check for a difference smaller than this number in our tests
        # instead of checking for exact equality.
        self._equality_delta = 1e-03

        # The model object is constructed (with a fixed seed of 0) first; the
        # default graph is reset afterwards.
        self._shared_model = RCNNTarget(
            self._num_classes, self._config, seed=0
        )
        tf.reset_default_graph()
object_detection_dataset_test.py 文件源码 项目:luminoth 作者: tryolabs 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def setUp(self):
        """Define the base dataset/train configuration and reset the default graph."""
        dataset_section = {
            'dir': '',
            'split': 'train',
            'image_preprocessing': {
                'min_size': 600,
                'max_size': 1024,
            },
            'data_augmentation': {},
        }
        train_section = {
            'num_epochs': 1,
            'batch_size': 1,
            'random_shuffle': False,
            'seed': None,
        }
        self.base_config = EasyDict({
            'dataset': dataset_section,
            'train': train_section,
        })
        tf.reset_default_graph()
vae.py 文件源码 项目:vampyre 作者: GAMPTeam 项目源码 文件源码 阅读 47 收藏 0 点赞 0 评论 0
def build_graph(self):
        """
        Builds the graph.

        Resets the default TensorFlow graph, creates the input placeholder
        ``self.x`` with ``self.enc_dim[0]`` columns, builds the encoder (only
        when ``self.mode == 'train'``), the decoder and the loss function, and
        finally creates ``self.summary_op`` and ``self.saver``.
        """

        # Clear the graph
        tf.reset_default_graph()

        # Create the placeholder for the input
        nx = self.enc_dim[0]
        self.x = tf.placeholder("float", shape=[None, nx], name='x')


        # Builds the various components; the encoder is built in train mode only
        if self.mode == 'train':
            self.build_enc()
        self.build_dec()
        self.build_loss_fn()

        # Add the summary op
        self.summary_op = tf.summary.merge_all()

        # Create a saver
        self.saver = tf.train.Saver()
weights_loading_scope.py 文件源码 项目:tflearn 作者: tflearn 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def demonstrate_loading_two_instances_of_model1():
    """Load the model1.tfl weights into two Model1 instances under scopeA and scopeB.

    The first load attempt inside scopeA is made without ``scope_for_restore``
    and is expected to fail; the failure is caught and printed. The following
    loads pass the matching ``scope_for_restore``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 40

    print(heavy_rule + " Demonstrate loading weights from model1 into two instances of model1 in scopeA and scopeB")
    tf.reset_default_graph()
    with tf.variable_scope("scopeA"):
        m1a = Model1()
        print(light_rule + " Trying to load model1 weights: should fail")
        try:
            m1a.model.load("model1.tfl", weights_only=True)
        except Exception as load_error:
            print("Loading failed, with error as expected, because variables are in scopeA")
            print("error: %s" % str(load_error))
        print(light_rule)

        print(heavy_rule + " Trying to load model1 weights: should succeed")
        m1a.model.load("model1.tfl", scope_for_restore="scopeA", verbose=True, weights_only=True)

    with tf.variable_scope("scopeB"):
        m1b = Model1()
        m1b.model.load("model1.tfl", scope_for_restore="scopeB", verbose=True, weights_only=True)
    print(heavy_rule + " Successfully restored weights to two instances of model1, in different scopes")
seq2seq_example.py 文件源码 项目:tflearn 作者: tflearn 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_train_predict2():
    '''
    Test that the embedding_attention model works, with saving and loading of weights

    Trains for one epoch into a temporary data directory, checks that the
    weights file exists, then builds a second model and checks that the
    prediction for a 10-element input has length 10. The temporary directory
    is removed afterwards, whether or not the assertions pass.
    '''
    import shutil
    import tempfile
    sp = SequencePattern()
    tempdir = tempfile.mkdtemp()
    try:
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
        tf.reset_default_graph()
        ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
        assert os.path.exists(ts2s.weights_output_fn)

        tf.reset_default_graph()
        # NOTE(review): prediction uses data_dir="DATA", not tempdir -- confirm
        # this is the directory the weights are meant to be loaded from.
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir="DATA", name="attention", verbose=1)
        prediction, y = ts2s.predict(Xin=range(10), weights_input_fn=1)
        # Compare the length itself; ``len(prediction==10)`` only measured the
        # comparison result and never checked the expected length.
        assert len(prediction) == 10
    finally:
        # Same effect as ``rm -rf`` without going through a shell.
        shutil.rmtree(tempdir, ignore_errors=True)
seq2seq_example.py 文件源码 项目:tflearn 作者: tflearn 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_train_predict3():
    '''
    Test that a model trained on sequences of one length can be used for predictions on other sequence lengths

    Trains with in/out sequence lengths 10/10, then predicts with a pattern of
    lengths 20/8 and checks that the prediction has length 8. The temporary
    directory is removed afterwards, whether or not the assertions pass.
    '''
    import shutil
    import tempfile
    sp = SequencePattern("sorted", in_seq_len=10, out_seq_len=10)
    tempdir = tempfile.mkdtemp()
    try:
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir=tempdir, name="attention")
        tf.reset_default_graph()
        ts2s.train(num_epochs=1, num_points=1000, weights_output_fn=1, weights_input_fn=0)
        assert os.path.exists(ts2s.weights_output_fn)

        sp = SequencePattern("sorted", in_seq_len=20, out_seq_len=8)
        # One reset right before the second model is built is sufficient.
        tf.reset_default_graph()
        # NOTE(review): prediction uses data_dir="DATA", not tempdir -- confirm
        # this is the directory the weights are meant to be loaded from.
        ts2s = TFLearnSeq2Seq(sp, seq2seq_model="embedding_attention", data_dir="DATA", name="attention", verbose=1)
        x = np.random.randint(0, 9, 20)
        prediction, y = ts2s.predict(x, weights_input_fn=1)
        # Compare the length itself; ``len(prediction==8)`` only measured the
        # comparison result and never checked the expected length.
        assert len(prediction) == 8
    finally:
        # Same effect as ``rm -rf`` without going through a shell.
        shutil.rmtree(tempdir, ignore_errors=True)
seq2seq_example.py 文件源码 项目:tflearn 作者: tflearn 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def test_main2():
    '''
    Integration test - training then prediction

    Trains through ``CommandLine`` with a temporary data directory, checks the
    reported weights file exists, then predicts from that file and checks the
    first prediction result has 10 elements. The temporary directory is
    removed afterwards, whether or not the assertions pass.
    '''
    import shutil
    import tempfile
    tempdir = tempfile.mkdtemp()
    try:
        # Paths are passed as their own list items so a path containing a
        # space is not split into several arguments.
        arglist = ["--data-dir", tempdir] + "-e 2 --iter-num=1 -v -v --tensorboard-verbose=1 train 5000".split(' ')
        tf.reset_default_graph()
        ts2s = CommandLine(arglist=arglist)
        wfn = ts2s.weights_output_fn
        assert os.path.exists(wfn)

        arglist = ["-i", wfn] + "predict 1 2 3 4 5 6 7 8 9 0".split(' ')
        tf.reset_default_graph()
        ts2s = CommandLine(arglist=arglist)
        assert len(ts2s.prediction_results[0][0])==10
    finally:
        # Same effect as ``rm -rf`` without going through a shell.
        shutil.rmtree(tempdir, ignore_errors=True)


问题


面经


文章

微信
公众号

扫码关注公众号