【发布时间】:2023-06-17 09:11:01
【问题描述】:
我正在做 CNN 模型压缩,并试图减少权重的位以获得位的长度和准确性之间的关系。但是当我使用Tensorflow网站的方法改变CNN的权重类型时,出现了错误:
“类型错误:传递给参数‘a’的值的 DataType int8 不在允许值列表中:float16、float32、float64、int32、complex64、complex128”。
似乎重量不能是其他Dtype。但我读了一些类似https://arxiv.org/pdf/1502.02551.pdf 的论文。可以将权重的位数减少到 6bits , 4bits ,甚至更低的位。
我的代码在这里(忽略导入的东西):
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
x = tf.placeholder(tf.int8,[None,784])
W = tf.Variable(tf.zeros([784,10]),tf.int8)
b = tf.Variable(tf.zeros([10]),tf.int8)
y = tf.nn.softmax(tf.matmul(x,W)+b)
#the error come out with "y = tf.nn.softmax(tf.matmul(x,W)+b)"
这只是一个标准的tensorflow官方代码,只是改变了变量的Dtype。我也尝试过 tf.cast ,但它仍然出现错误。
tf.cast(W,tf.int8)
tf.cast(b,tf.int8)
谁能告诉我如何克服这种情况?非常感谢!!
【问题讨论】:
-
请添加您的代码示例以及您的预期输出。
标签: python tensorflow training-data mnist 8-bit