【发布时间】:2020-06-25 12:27:18
【问题描述】:
我有这样的问题,而且是在最新版的tensorflow上出现的。我希望有人能给我一些建议。 我的代码如下:
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
import h5py
import t3f
import matplotlib.pyplot as plt
filename = "./video.h5"
np.random.seed(0)
with h5py.File(filename, "r") as f:
print("Keys: %s" % f.keys())
a_group_key = list(f.keys())[0]
data = list(f[a_group_key])
data_np = np.array(data)
data_tensor = tf.convert_to_tensor(data_np)
shape = [107]+[60]+[80]+[3]
# A is large tt-ranks tensor
A = t3f.to_tt_tensor(data_tensor)
# Create an X variable.
init_X = t3f.random_tensor(shape, tt_rank=3)
X = t3f.get_variable('X', initializer=init_X)
def step():
gradF = X - A
riemannian_grad = t3f.riemannian.project(gradF, X)
alpha = 1.0
t3f.assign(X, t3f.round(X - alpha * riemannian_grad, max_tt_rank=2))
return 0.5 * t3f.frobenius_norm_squared(X - A)
log = []
for i in range(1000):
F = step()
if i % 10 == 0:
print(F)
log.append(F.numpy())
退出:
ValueError Traceback(最近一次调用最后一次)
1 日志 = [] 2 for i in range(1000): ----> 3 F = step() 4 如果我 % 10 == 0: 5 打印(F)4 帧 /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py 在 assert_is_compatible_with(self, other) 1115 """ 1116
如果不是 self.is_compatible_with(other): -> 1117 raise ValueError("Shapes %s and %s are incompatible" % (self, other)) 1118 1119 def most_specific_compatible_shape(self, other):ValueError:形状 (1, 107, 3) 和 (1, 107, 2) 不兼容
但是data_tensor的形状是[107]+[60]+[80]+[3],和A或者X一样,我很困惑。
我用 tf2 和 python3.6 在 google colab 上运行了这段代码。 在这个链接中你会重复我的问题link
【问题讨论】:
标签: tensorflow2.0 tensor