tf.train.SessionRunHook
这个方法是为了打印运行时的数据,方便debug,初学者理解
import tensorflow as tf
import numpy as np
x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
w = tf.Variable(initial_value=[[10.], [10.]])
w0 = [[1], [1.]]
y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)
class _Hook(tf.train.SessionRunHook):
def __init__(self, loss, w):
self.loss = loss
self.w = w
def begin(self):
pass
def before_run(self, run_context):
return tf.train.SessionRunArgs([self.loss, self.w]) #这里输入的[]对应run_values.results
def after_run(self, run_context, run_values):
print(type(run_values))
print(run_values)
loss_value = run_values.results[0]
w_value = run_values.results[1]
print("loss value:", loss_value)
print("w:",w_value)
#sess = tf.train.MonitoredSession(hooks=[_Hook(loss,w)])
sess = tf.train.MonitoredTrainingSession(hooks=[_Hook(loss,w)])
for _ in range(10):
x_ = np.random.random((10, 2))
sess.run(optimizer, {x: x_})
tf.nn.embedding_lookup_sparse
加载meta文件,查看模型结构
import tensorflow as tf
from tensorflow.python.platform import gfile
#这是从文件格式的meta文件加载模型
graph = tf.get_default_graph()
graphdef = graph.as_graph_def()
# graphdef.ParseFromString(gfile.FastGFile("/data/TensorFlowAndroidMNIST/app/src/main/expert-graph.pb", "rb").read())
# _ = tf.import_graph_def(graphdef, name="")
_ = tf.train.import_meta_graph("./dsp_graph.meta")
summary_write = tf.summary.FileWriter("./log" , graph)
#然后tensorboard --logdir log即可
tf.add_to_collection
在一个计算图中,可以通过集合collection来管理不同类别的资源。可以通过tf.add_to_collection函数将资源计入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源。tensorflow也自动管理来一些常用的集合如下
add:0的表示方式
下面脚本里面add:0表示result这个张量是计算节点”add”输出的第一个结果,编号从0开始
会话和计算图的关系
tf会自动生成一个默认图,不会自动生成默认会话
import tensorflow as tf
//使用默认的计算图
a = tf.constant([1,2],name="a")
b = tf.constant([3,4],name="b")
result = a + b
//自己创建计算图
g1 = tf.Graph()
with g1.as_default():
a = tf.constant([1,2],name="a")
b = tf.constant([3,4],name="b")
result = a + b
//使用默认计算图
with tf.Session() as sess:
print(sess.run(result))
//使用自己创建的计算图
with tf.Session(graph=g1) as sess:
print(sess.run(result))