Here is a possible solution I put together, though I'm not sure whether it really fits your case. The idea is to recursively subdivide the required grid into blocks of the given sizes. Interestingly, this does not work in TensorFlow 2.x, because it tries to turn the recursive function into a graph, which would result in an infinite graph - I'm not sure whether there is a workaround. The solution is obviously much slower than computing directly over the whole grid, but the result is technically the same. The catch is the error: if the grid is really large, reducing it as a whole is likely to accumulate significant floating-point error. That, however, happens in the first place without any subdivision; in fact, subdividing actually seems to reduce the error, if anything (at least in some experiments I ran).
Anyway, the code is a bit long, and while it is not too complex conceptually, hopefully the comments make it clear.
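To see why blockwise reduction can help with the error before diving into the TensorFlow code, here is a small NumPy sketch (separate from the solution itself; all names are mine). It contrasts a strictly sequential float32 sum with a sum done in blocks whose partial results are then combined, the same structural idea as the recursive subdivision below:

```python
import numpy as np

# Sequential float32 accumulation of many small values loses precision
# once the running sum grows large; summing in blocks first keeps each
# partial sum small, so less precision is lost overall.
n = 1_000_000
x = np.full(n, 0.1, dtype=np.float32)
exact = 0.1 * n  # 100000.0

# Sequential accumulation (cumsum is a strictly left-to-right sum)
seq = np.cumsum(x)[-1]

# Blockwise: reduce 1000 blocks independently, then combine the partials
partials = np.array([np.cumsum(c)[-1] for c in np.split(x, 1000)],
                    dtype=np.float32)
blocked = np.cumsum(partials)[-1]

print(abs(float(seq) - exact), abs(float(blocked) - exact))
```

In runs of this kind the blocked sum lands much closer to the true value, which matches the behaviour of `r1` versus `r2` in the test at the end.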
import tensorflow as tf
def block_reduction_func(block):
    # This function computes the reduction of a block
    return tf.math.reduce_sum(block)
def intermediate_reduction_func(values):
    # This function computes the reduction of an
    # array of intermediate reduction results
    # (in this example it is the same reduction)
    return block_reduction_func(values)
def make_block(aa):
    # Makes an actual block from some space slices
    return tf.stack(tf.meshgrid(*aa), axis=-1)
def get_block_slices(aa, i, size):
    # Selects the space slices corresponding to a particular block
    aa2 = []
    for dim, a in enumerate(aa):
        # Number of slices in this level for this dimension
        s = tf.size(a)
        n = s // size
        n += tf.dtypes.cast(s % size > 0, n.dtype)
        # Select dimension slice
        j = i % n
        aa2.append(aa[dim][j * size:(j + 1) * size])
        i //= n
    return aa2
def by_blocks(aa, blocks):
    # Reduces a space by blocks
    if not blocks:
        # When there are no more subdivisions to do
        # just reduce the current block
        res = block_reduction_func(make_block(aa))
        with tf.control_dependencies([]):  # ([tf.print(res, aa)]):
            return res + 0
    else:
        # Get current division size
        size, *blocks = blocks
        # Get number of blocks in this recursion level
        num_blocks = 1
        for a in aa:
            s = tf.size(a)
            n = s // size
            n += tf.dtypes.cast(s % size > 0, n.dtype)
            num_blocks *= n
        # Array for intermediate results
        ta = tf.TensorArray(aa[0].dtype, num_blocks, element_shape=())
        # Loop through blocks
        _, ta = tf.while_loop(
            lambda i, ta: i < num_blocks,
            lambda i, ta: (i + 1,
                           ta.write(i, by_blocks(get_block_slices(aa, i, size), blocks))),
            [0, ta], parallel_iterations=1)
        # Reduce intermediate results
        values = ta.stack()
        return intermediate_reduction_func(values)
# Test
b = 1.0
n = 100
d = 3
# Recursive divisions of n (can have arbitrary size)
# Divide in blocks of 60, then blocks of 12
blocks = [60, 12]
with tf.Graph().as_default(), tf.Session():
    # Using positive values only in this example
    # so the errors do not overtake the result
    a = tf.linspace(0., b, n)
    aa = [a] * d
    r1 = block_reduction_func(make_block(aa))
    r2 = by_blocks(aa, blocks)
    # Check results (should be 1500000)
    print(r1.eval())
    # 1499943.6
    print(r2.eval())
    # 1499998.8
    # CPU timings
    %timeit r1.eval()
    # 99.3 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    %timeit r2.eval()
    # 96.8 ms ± 170 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    # GPU timings
    %timeit r1.eval()
    # 195 µs ± 615 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    %timeit r2.eval()
    # 316 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
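One part that may look cryptic is how `get_block_slices` turns the single flat block index `i` into one slice per dimension: it decomposes `i` in mixed radix, taking `i % n` for the current dimension's block index and then dividing `i` by `n`. Here is a plain-Python sketch of just that decomposition (the function name and signature are mine, for illustration only):

```python
import math

# With n_d blocks along dimension d, a flat index i is decomposed in
# mixed radix: repeatedly take i % n_d as the block index along that
# dimension, then integer-divide i by n_d and move to the next one.
def flat_to_block_indices(i, dim_sizes, block_size):
    idxs = []
    for s in dim_sizes:
        n = math.ceil(s / block_size)  # number of blocks along this dimension
        idxs.append(i % n)
        i //= n
    return idxs

# A 100x100x100 grid with blocks of 60 has ceil(100/60) = 2 blocks per
# dimension, 8 blocks total; flat index 5 maps to block (1, 0, 1).
print(flat_to_block_indices(5, [100, 100, 100], 60))
# [1, 0, 1]
```

This is exactly what the `j = i % n` / `i //= n` pair does inside the loop, except there the block counts are computed with TensorFlow ops so they work on dynamic sizes.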