如果我理解正确,您想计算从体积中的每个位置到给定类的最近位置的距离。为简单起见,我假设有趣的类标有1,但如果它不同,希望您可以根据您的情况调整它。该代码适用于 TensorFlow 2.0,但应该适用于 1.x。
执行此操作的最简单方法是使用1 计算体积中所有坐标与每个坐标之间的距离,然后从中选择最小的距离。你可以这样做:
import tensorflow as tf
# Make input data
w, h, d = 10, 20, 30
w, h, d = 2, 3, 4
t = tf.random.stateless_uniform([w, h, d], (0, 0), 0, 2, tf.int32)
print(t.numpy())
# [[[0 1 0 0]
# [0 0 0 0]
# [1 1 0 1]]
#
# [[1 0 0 0]
# [0 0 0 0]
# [1 1 0 0]]]
# Make coordinates
coords = tf.meshgrid(tf.range(w), tf.range(h), tf.range(d), indexing='ij')
coords = tf.stack(coords, axis=-1)
# Find coordinates that are positive
m = t > 0
coords_pos = tf.boolean_mask(coords, m)
# Find every pairwise distance
vec_d = tf.reshape(coords, [-1, 1, 3]) - coords_pos
# You may choose a difference precision type here
dists = tf.linalg.norm(tf.dtypes.cast(vec_d, tf.float32), axis=-1)
# Find minimum distances
min_dists = tf.reduce_min(dists, axis=-1)
# Reshape
out = tf.reshape(min_dists, [w, h, d])
print(out.numpy().round(3))
# [[[1. 0. 1. 2. ]
# [1. 1. 1.414 1. ]
# [0. 0. 1. 0. ]]
#
# [[0. 1. 1.414 2.236]
# [1. 1. 1.414 1.414]
# [0. 0. 1. 1. ]]]
这可能对您来说足够好,尽管它可能不是最有效的解决方案。最聪明的做法是在每个位置的相邻区域中搜索最接近的正位置,但要有效地做到这一点很复杂,一般来说,在 TensorFlow 中以矢量化方式更是如此。然而,我们可以通过几种方法来改进上面的代码。一方面,我们知道1 的位置总是零距离,因此不需要计算这些位置。另一方面,如果 3D 体积中的 1 类表示某种密集形状,那么如果我们只计算与该形状表面的距离,我们就可以节省一些时间。所有其他正位置必然与形状外的位置具有更大的距离。所以我们可以做同样的事情,但只计算从非正位置到正表面位置的距离。你可以这样做:
import tensorflow as tf
# Make input data
w, h, d = 10, 20, 30
w, h, d = 2, 3, 4
t = tf.dtypes.cast(tf.random.stateless_uniform([w, h, d], (0, 0)) > .15, tf.int32)
print(t.numpy())
# [[[1 1 1 1]
# [1 1 1 1]
# [1 1 0 0]]
#
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]]
# Find coordinates that are positive and on the surface
# (surrounded but at least one 0)
t_pad_z = tf.pad(t, [(1, 1), (1, 1), (1, 1)]) <= 0
m_pos = t > 0
m_surround_z = tf.zeros_like(m_pos)
# Go through the 6 surrounding positions
for i in range(3):
for s in [slice(None, -2), slice(2, None)]:
slices = tuple(slice(1, -1) if i != j else s for j in range(3))
m_surround_z |= t_pad_z.__getitem__(slices)
# Surface points are positive points surrounded by some zero
m_surf = m_pos & m_surround_z
coords_surf = tf.where(m_surf)
# Find coordinates that are zero
coords_z = tf.where(~m_pos)
# Find every pairwise distance
vec_d = tf.reshape(coords_z, [-1, 1, 3]) - coords_surf
dists = tf.linalg.norm(tf.dtypes.cast(vec_d, tf.float32), axis=-1)
# Find minimum distances
min_dists = tf.reduce_min(dists, axis=-1)
# Put minimum distances in output array
out = tf.scatter_nd(coords_z, min_dists, [w, h, d])
print(out.numpy().round(3))
# [[[0. 0. 0. 0.]
# [0. 0. 0. 0.]
# [0. 0. 1. 1.]]
#
# [[0. 0. 0. 0.]
# [0. 0. 0. 0.]
# [0. 0. 0. 0.]]]
编辑:这是一种使用 TensorFlow 循环将距离计算分成块的方法:
# Following from before
coords_surf = ...
coords_z = ...
CHUNK_SIZE = 1_000 # Choose chunk size
dtype = tf.float32
# If using TF 2.x you can know in advance the size of the tensor array
# (although the element shape will not be constant due to the last chunk)
num_z = tf.shape(coords_z)[0]
arr = tf.TensorArray(dtype, size=(num_z - 1) // CHUNK_SIZE + 1, element_shape=[None], infer_shape=False)
_, arr = tf.while_loop(lambda i, arr: i < num_z,
lambda i, arr: (i + CHUNK_SIZE, arr.write(i // CHUNK_SIZE,
tf.reduce_min(tf.linalg.norm(tf.dtypes.cast(
tf.reshape(coords_z[i:i + CHUNK_SIZE], [-1, 1, 3]) - coords_surf,
dtype), axis=-1), axis=-1))),
[tf.constant(0, tf.int32), arr])
min_dists = arr.concat()
out = tf.scatter_nd(coords_z, min_dists, [w, h, d])