【问题标题】:Tensorflow, conv2d and filtersTensorFlow、conv2d 和过滤器
【发布时间】:2017-03-16 18:52:57
【问题描述】:

我是深度学习的初学者,并试图了解算法的工作原理,并使用 JavaScript 编写它们。现在我正在像Tensorflow一样研究conv2d的JavaScript实现,并且误解了如何处理不同数量的过滤器,我已经成功处理了一个输出过滤器和多个输出,但是我很困惑如何使用多个过滤器输入进行操作,例如32 -> 64

这是使用ndarray 的代码示例 :

const outCount = 32 // count of inputs filters
const inCount = 1 // count of output features
const filterSize = 3
const stride = 1
const inShape = [1, 10, 10, outCount]
const outShape = [
  1,
  Math.ceil((inShape[1] - filterSize + 1) / stride),
  Math.ceil((inShape[2] - filterSize + 1) / stride),
  outCount
];
const filters = ndarray([], [filterSize, filterSize, inCount, outCount])

const conv2d = (input) => {
  const result = ndarray(outShape)
   // for each output feature

  for (let fo = 0; fo < outCount; fo += 1) { 
    for (let x = 0; x < outShape[1]; x += 1) {
      for (let y = 0; y < outShape[2]; y += 1) {
      const fragment = ndarray([], [filterSize, filterSize]);
      const filter = ndarray([], [filterSize, filterSize]);

      // agregate fragment of image and filter
      for (let fx = 0; fx < filterSize; fx += 1) {
        for (let fy = 0; fy < filterSize; fy += 1) {
          const dx = (x * stride) + fx;
          const dy = (y * stride) + fy;

          fragment.data.push(input.get(0, dx, dy, 0));
          filter.data.push(filters.get(fx, fy, 0, fo));
        }
      }

      // calc dot product of filter and image fragment
      result.set(0, x, y, fo, dot(filter, fragment));
      }
    }
  }

  return result
}

为了测试,我使用 Tenforflow 作为 true 的来源,它的算法工作正常,但使用 1 -&gt; N。但我的问题是如何在输入值中添加对多个过滤器的支持,例如N -&gt; M

有人能解释一下如何修改这个算法以使其更兼容 Tensorflow tf.nn.conv2d 非常感谢。

【问题讨论】:

    标签: javascript multidimensional-array tensorflow convolution


    【解决方案1】:

    您需要添加另一个 for 循环。您没有指定所有输入的形状和尺寸,所以实际上很难准确地写出来,但它看起来像这样。

      // agregate fragment of image and filter
      for (let fx = 0; fx < filterSize; fx += 1) {
        for (let fy = 0; fy < filterSize; fy += 1) {
          //addition
          for (let ch = 0; ch < input.get_channels) {
            const dx = (x * stride) + fx;
            const dy = (y * stride) + fy;
    
            fragment.data.push(input.get(0, dx, dy, ch));
            filter.data.push(filters.get(fx, fy, ch, fo));
          }
        }
      }
    

    【讨论】:

    • 看来你说的很对,非常感谢!
    猜你喜欢
    • 1970-01-01
    • 2023-01-27
    • 2018-12-13
    • 1970-01-01
    • 1970-01-01
    • 2017-04-01
    • 2018-01-12
    • 2021-12-31
    • 2018-01-16
    相关资源
    最近更新 更多