【发布时间】:2020-01-23 02:09:52
【问题描述】:
我有这段代码可以使用 TensorFlow RNN 进行文本分类,但是如何将其更改为进行文本生成呢?
以下文本分类有 3D 输入,但有 2D 输出。是否应该将其更改为 3D 输入和 3D 输出以生成文本?以及如何?
示例数据为:
t0 t1 t2
british gray is => cat (y=0)
0 1 2
white samoyed is => dog (y=1)
3 4 2
对于分类喂养“英国灰色是”导致“猫”。我希望得到的是喂“英国人”应该导致下一个词“灰色”。
import tensorflow as tf;
tf.reset_default_graph();
#data
'''
t0 t1 t2
british gray is => cat (y=0)
0 1 2
white samoyed is => dog (y=1)
3 4 2
'''
Bsize = 2;
Times = 3;
Max_X = 4;
Max_Y = 1;
X = [[[0],[1],[2]], [[3],[4],[2]]];
Y = [[0], [1] ];
#normalise
for I in range(len(X)):
for J in range(len(X[I])):
X[I][J][0] /= Max_X;
for I in range(len(Y)):
Y[I][0] /= Max_Y;
#model
Inputs = tf.placeholder(tf.float32, [Bsize,Times,1]);
Expected = tf.placeholder(tf.float32, [Bsize, 1]);
#single LSTM layer
#'''
Layer1 = tf.keras.layers.LSTM(20);
Hidden1 = Layer1(Inputs);
#'''
#multi LSTM layers
'''
Layers = tf.keras.layers.RNN([
tf.keras.layers.LSTMCell(30), #hidden 1
tf.keras.layers.LSTMCell(20) #hidden 2
]);
Hidden2 = Layers(Inputs);
'''
Weight3 = tf.Variable(tf.random_uniform([20,1], -1,1));
Bias3 = tf.Variable(tf.random_uniform([ 1], -1,1));
Output = tf.sigmoid(tf.matmul(Hidden1,Weight3) + Bias3);
Loss = tf.reduce_sum(tf.square(Expected-Output));
Optim = tf.train.GradientDescentOptimizer(1e-1);
Training = Optim.minimize(Loss);
#train
Sess = tf.Session();
Init = tf.global_variables_initializer();
Sess.run(Init);
Feed = {Inputs:X, Expected:Y};
for I in range(1000): #number of feeds, 1 feed = 1 batch
if I%100==0:
Lossvalue = Sess.run(Loss,Feed);
print("Loss:",Lossvalue);
#end if
Sess.run(Training,Feed);
#end for
Lastloss = Sess.run(Loss,Feed);
print("Loss:",Lastloss,"(Last)");
#eval
Results = Sess.run(Output,Feed);
print("\nEval:");
print(Results);
print("\nDone.");
#eof
【问题讨论】:
-
你的意思是它的当前状态?或者你可以重新训练它吗?
-
@Recessive 我的意思是如何得到下一个词而不是类,例如,喂“英国”,我应该能够得到“灰色”而不是喂“英国灰色是”来得到“猫”
-
示例数据令人困惑,但看起来不兼容。由于您没有回答,我假设您可以重新训练网络,在这种情况下,最好的做法是相同的输入和输出尺寸,可能是 1d。为此,您可以获取训练数据中的所有单词并将它们用作输入和输出的非常大的 1 热向量。例如,假设您有单词
['hello', 'hi','is','that','yes'],那么您的输入将是长度为 5 的 1d,而要输入'hello',您将在索引 0 处输入 1
标签: tensorflow machine-learning nlp recurrent-neural-network text-classification