【发布时间】:2021-09-22 09:13:35
【问题描述】:
我正在研究文本生成示例https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java。 lstm网络的输出是一个概率分布,据我理解,这是一个双精度数组,其中每个值表示字符对应于数组中索引的概率。所以我无法理解我们从分布中获取字符索引的以下代码:
/** Given a probability distribution over discrete classes, sample from the distribution
* and return the generated class index.
* @param distribution Probability distribution over classes. Must sum to 1.0
*/
static int sampleFromDistribution(double[] distribution, Random rng){
double d = 0.0;
double sum = 0.0;
for( int t=0; t<10; t++ ) {
d = rng.nextDouble();
sum = 0.0;
for( int i=0; i<distribution.length; i++ ){
sum += distribution[i];
if( d <= sum ) return i;
}
//If we haven't found the right index yet, maybe the sum is slightly
//lower than 1 due to rounding error, so try again.
}
//Should be extremely unlikely to happen if distribution is a valid probability distribution
throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}
似乎我们得到了一个随机值。为什么我们不直接选择价值最高的索引呢?如果我想选择的不是一个,而是两个或三个最有可能的下一个字符,我应该怎么做?
【问题讨论】: