【发布时间】:2019-05-09 20:46:19
【问题描述】:
我正在尝试通过将我的 tensorflow1.13 脚本(如下)转换为 tensorflow2.0 脚本来了解 tensorflow2.0 的动态。但是我正在努力做到这一点。
我认为我挣扎的主要原因是因为 tensorflow2.0 的例子我见过训练神经网络,所以他们有一个model,他们有compile 和fit。但是,在下面的简单示例中,我没有使用神经网络,因此我看不到如何将此代码调整为 tensorflow2.0(例如,如何替换会话?)。非常感谢您的帮助,并在此先感谢您。
data = tf.placeholder(tf.int32)
theta = tf.Variable(np.zeros(100))
p_s = tf.nn.softmax(theta)
loss = tf.reduce_mean(-tf.log(tf.gather(p_s, data)))
train_step = tf.train.AdamOptimizer().minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(10):
for datum in sample_data(): #sample_data() is a list of integer datapoints
_ = sess.run([train_step], feed_dict={data:datum})
print(sess.run(p_s))
我查看了this(这是最相关的),到目前为止我想出了以下内容:
#data = tf.placeholder(tf.int32)
theta = tf.Variable(np.zeros(100))
p_s = tf.nn.softmax(theta)
loss = tf.reduce_mean(-tf.math.log(tf.gather(p_s, **data**)))
optimizer = tf.keras.optimizers.Adam()
for epoch in range(10):
for datum in sample_data():
optimizer.apply_gradients(loss)
print(p_s)
但是上面显然没有运行,因为损失函数中的占位符 data 不再存在 - 但是我不知道如何替换它。 :S
有人吗? 请注意,我没有 def forward(x) 因为我的输入 datum 没有转换 - 它直接用于计算损失。 em>
【问题讨论】:
标签: python tensorflow tensorflow2.0