我认为“不创建临时对象”是不可能的,尤其是因为 Python 中的“一切都是对象”。
如果你自己实现排序算法,可以做到 O(1) 的额外空间/对象数量,但如果还想要 O(n log n) 时间和稳定性,就很困难了。如果你不关心稳定性(看起来很可能,因为你说要按 a 排序,但实际上是按 a、b 和 c 排序),堆排序相当简单:
def sort_together_heapsort(a, b, c):
    """Sort ``a`` ascending in place, applying the same reordering to ``b`` and ``c``.

    A hand-written heapsort: only the values of ``a`` are compared, and every
    swap is mirrored on ``b`` and ``c`` so the three lists stay aligned.
    No auxiliary lists are built. The sort is not stable.

    Args:
        a: List providing the sort keys; reordered in place.
        b: List of the same length as ``a``; reordered in place.
        c: List of the same length as ``a``; reordered in place.

    Returns:
        None. Empty lists are left untouched.
    """
    n = len(a)  # size of the heap prefix; shrinks during the extraction phase

    def swap(i, j):
        # Exchange positions i and j in all three lists.
        a[i], a[j] = a[j], a[i]
        b[i], b[j] = b[j], b[i]
        c[i], c[j] = c[j], c[i]

    def siftdown(i):
        # Restore the max-heap property (judged by ``a``) below index i.
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                return
            swap(i, imax)
            i = imax

    # Heapify: sift down every internal node, last one first.
    for i in range(n // 2)[::-1]:
        siftdown(i)

    # Repeatedly move the current maximum behind the shrinking heap.
    # The explicit ``> 0`` matters: for empty input ``n`` becomes -1, which is
    # truthy and would otherwise lead to an IndexError in ``swap(0, -1)``.
    while (n := n - 1) > 0:
        swap(0, n)
        siftdown(0)
无论如何,如果有人想节省一些内存,可以通过就地装饰(构建元组并将它们存储在 a 中)来实现:
def sort_together_decorate_in_a(a, b, c):
    """Sort three parallel lists in place by decorating inside ``a``.

    Each slot of ``a`` is temporarily replaced by the tuple
    ``(a[i], b[i], c[i])``, ``a`` is sorted, and the tuples are then unpacked
    back into the three lists. Ordering therefore follows the whole tuple
    (``a`` first, then ``b``, then ``c``).
    """
    # Decorate: zip reads a[pos] before that slot is overwritten.
    for pos, triple in enumerate(zip(a, b, c)):
        a[pos] = triple

    a.sort()

    # Undecorate: the tuple is fetched before a[pos] is overwritten.
    for pos, (first, second, third) in enumerate(a):
        a[pos] = first
        b[pos] = second
        c[pos] = third
或者,如果您相信 list.sort 会按顺序请求各元素的键(至少在 CPython 中确实如此,自 18 年前引入 key 参数时起就是这样,我猜想今后也会保持不变):
def sort_together_iter_key(a, b, c):
    """Sort ``b`` and ``c`` in place by the values of ``a``, then sort ``a``.

    The key function ignores the element it is given and instead returns the
    next value of ``a``. This relies on ``list.sort`` requesting keys for the
    elements in list order.
    """
    def positional_key():
        # Fresh pass over ``a`` for each list that gets sorted.
        feed = iter(a)
        return lambda _item: next(feed)

    b.sort(key=positional_key())
    c.sort(key=positional_key())
    # ``a`` must be sorted last: it supplies the keys for the other lists.
    a.sort()
用三个包含 100,000 个元素的列表测试内存和时间:
15,072,520 bytes 152 ms sort_together_sorted_zip
15,072,320 bytes 166 ms sort_together_sorted_zip_2
14,272,576 bytes 152 ms sort_together_sorted_zip_X
6,670,708 bytes 126 ms sort_together_decorate_in_a
6,670,772 bytes 177 ms sort_together_decorate_in_first_X
5,190,212 bytes 342 ms sort_multi_by_a_guest_X
1,597,400 bytes 100 ms sort_together_iter_key
1,597,448 bytes 102 ms sort_together_iter_key_X
744 bytes 1584 ms sort_together_heapsort
704 bytes 1663 ms sort_together_heapsort_X
168 bytes 1326 ms sort_together_heapsort_opti
188 bytes 1512 ms sort_together_heapsort_opti_X
注意:
- 第二种解决方案是您的缩短/改进版本,不需要临时变量和转换为列表。
- 带有 `_X` 后缀的解决方案是接受任意多个列表作为参数的版本。
- 名称中带 a_guest 的解决方案来自 @a_guest 的回答。就运行时间而言,它目前得益于我的数据是随机的,因此不会暴露该解决方案的最坏情况复杂度 O(m * n²),其中 m 是列表的数量,n 是每个列表的长度。
用十个包含 100,000 个元素的列表测试内存和时间:
19,760,808 bytes 388 ms sort_together_sorted_zip_X
12,159,100 bytes 425 ms sort_together_decorate_in_first_X
5,190,292 bytes 1249 ms sort_multi_by_a_guest_X
1,597,528 bytes 393 ms sort_together_iter_key_X
704 bytes 4186 ms sort_together_heapsort_X
188 bytes 4032 ms sort_together_heapsort_opti_X
整个代码(Try it online!):
import tracemalloc as tm
from random import random
from timeit import timeit
def sort_together_sorted_zip(a, b, c):
    """Sort three parallel lists in place via ``sorted(zip(...))``.

    The lists are zipped into tuples, the tuples are sorted (so ordering
    follows ``a`` first, then ``b``, then ``c``), and the result is
    transposed back and written into the original list objects.

    Args:
        a, b, c: Lists of equal length; each is overwritten in place.

    Returns:
        None.
    """
    rows = sorted(zip(a, b, c))
    if not rows:
        # ``zip(*[])`` yields no columns, so the three-way unpacking below
        # would raise ValueError. Clearing matches what the non-empty path
        # does for unequal lengths (truncate to the shortest list).
        a.clear()
        b.clear()
        c.clear()
        return
    a_sorted, b_sorted, c_sorted = map(list, zip(*rows))
    a[:] = a_sorted
    b[:] = b_sorted
    c[:] = c_sorted
def sort_together_sorted_zip_2(a, b, c):
    """Sort three parallel lists in place via ``sorted(zip(...))``, shortened.

    Same approach as zipping, sorting and transposing, but the transposed
    columns are slice-assigned directly without temporary names or
    conversion to lists.

    Args:
        a, b, c: Lists of equal length; each is overwritten in place.

    Returns:
        None.
    """
    rows = sorted(zip(a, b, c))
    if rows:
        a[:], b[:], c[:] = zip(*rows)
    else:
        # ``zip(*[])`` yields nothing to unpack (ValueError). Clearing matches
        # the truncate-to-shortest behaviour of the non-empty path.
        a.clear()
        b.clear()
        c.clear()
def sort_together_sorted_zip_X(*lists):
    """Sort any number of parallel lists in place via ``sorted(zip(...))``.

    Rows are formed by zipping the lists, sorted as tuples, transposed back
    into columns, and each column is slice-assigned into its original list.
    """
    columns = zip(*sorted(zip(*lists)))
    for target, column in zip(lists, columns):
        target[:] = column
def sort_together_decorate_in_a(a, b, c):
    """Sort three parallel lists in place by decorating inside ``a``.

    Each slot of ``a`` is temporarily replaced by the tuple
    ``(a[i], b[i], c[i])``, ``a`` is sorted, and the tuples are then unpacked
    back into the three lists. Ordering therefore follows the whole tuple
    (``a`` first, then ``b``, then ``c``).
    """
    # Decorate: zip reads a[pos] before that slot is overwritten.
    for pos, triple in enumerate(zip(a, b, c)):
        a[pos] = triple

    a.sort()

    # Undecorate: the tuple is fetched before a[pos] is overwritten.
    for pos, (first, second, third) in enumerate(a):
        a[pos] = first
        b[pos] = second
        c[pos] = third
def sort_together_decorate_in_first_X(*lists):
    """Sort any number of parallel lists in place, decorating in the first.

    The first list temporarily holds one tuple per position (the values of
    all lists at that position), is sorted, and the tuples are then written
    back column by column into every list.
    """
    head = lists[0]

    # Decorate: zip reads head[pos] before that slot is overwritten.
    for pos, row in enumerate(zip(*lists)):
        head[pos] = row

    head.sort()

    # Undecorate: ``row`` is already fetched when head[pos] is overwritten.
    for pos, row in enumerate(head):
        for column, value in zip(lists, row):
            column[pos] = value
def sort_together_iter_key(a, b, c):
    """Sort ``b`` and ``c`` in place by the values of ``a``, then sort ``a``.

    The key function ignores the element it is given and instead returns the
    next value of ``a``. This relies on ``list.sort`` requesting keys for the
    elements in list order.
    """
    def positional_key():
        # Fresh pass over ``a`` for each list that gets sorted.
        feed = iter(a)
        return lambda _item: next(feed)

    b.sort(key=positional_key())
    c.sort(key=positional_key())
    # ``a`` must be sorted last: it supplies the keys for the other lists.
    a.sort()
def sort_together_iter_key_X(*lists):
    """Sort any number of parallel lists in place by the first list's values.

    Every list after the first is sorted with a key function that ignores its
    argument and returns the next value of the first list. This relies on
    ``list.sort`` requesting keys for the elements in list order. The first
    list is sorted last.
    """
    def leader_key():
        # Fresh pass over the first list for each list that gets sorted.
        feed = iter(lists[0])
        return lambda _item: next(feed)

    for follower in lists[1:]:
        follower.sort(key=leader_key())
    lists[0].sort()
def sort_together_heapsort(a, b, c):
    """Sort ``a`` ascending in place, applying the same reordering to ``b`` and ``c``.

    A hand-written heapsort: only the values of ``a`` are compared, and every
    swap is mirrored on ``b`` and ``c`` so the three lists stay aligned.
    No auxiliary lists are built. The sort is not stable.

    Args:
        a: List providing the sort keys; reordered in place.
        b: List of the same length as ``a``; reordered in place.
        c: List of the same length as ``a``; reordered in place.

    Returns:
        None. Empty lists are left untouched.
    """
    n = len(a)  # size of the heap prefix; shrinks during the extraction phase

    def swap(i, j):
        # Exchange positions i and j in all three lists.
        a[i], a[j] = a[j], a[i]
        b[i], b[j] = b[j], b[i]
        c[i], c[j] = c[j], c[i]

    def siftdown(i):
        # Restore the max-heap property (judged by ``a``) below index i.
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                return
            swap(i, imax)
            i = imax

    # Heapify: sift down every internal node, last one first.
    for i in range(n // 2)[::-1]:
        siftdown(i)

    # Repeatedly move the current maximum behind the shrinking heap.
    # The explicit ``> 0`` matters: for empty input ``n`` becomes -1, which is
    # truthy and would otherwise lead to an IndexError in ``swap(0, -1)``.
    while (n := n - 1) > 0:
        swap(0, n)
        siftdown(0)
def sort_together_heapsort_X(*lists):
    """Heapsort any number of parallel lists in place by the first list's values.

    Only the values of the first list are compared; every swap is mirrored on
    all lists so they stay aligned. No auxiliary lists are built. The sort is
    not stable.

    Args:
        *lists: One or more lists of equal length; each is reordered in place.

    Returns:
        None. Empty lists are left untouched.
    """
    a = lists[0]  # the list that provides the sort keys
    n = len(a)    # size of the heap prefix; shrinks during extraction

    def swap(i, j):
        # Exchange positions i and j in every list.
        for lst in lists:
            lst[i], lst[j] = lst[j], lst[i]

    def siftdown(i):
        # Restore the max-heap property (judged by ``a``) below index i.
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                return
            swap(i, imax)
            i = imax

    # Heapify: sift down every internal node, last one first.
    for i in range(n // 2)[::-1]:
        siftdown(i)

    # Repeatedly move the current maximum behind the shrinking heap.
    # The explicit ``> 0`` matters: for empty input ``n`` becomes -1, which is
    # truthy and would otherwise lead to an IndexError in ``swap(0, -1)``.
    while (n := n - 1) > 0:
        swap(0, n)
        siftdown(0)
def sort_together_heapsort_opti(a, b, c):
    """Heapsort three parallel lists in place by ``a``, with everything inlined.

    Same algorithm as a heapsort with ``swap``/``siftdown`` helpers, but the
    helpers are written out inline and the heapify phase uses a ``while``
    loop instead of ``range``. Only ``a`` is compared; the sort is not stable.

    Args:
        a: List providing the sort keys; reordered in place.
        b: List of the same length as ``a``; reordered in place.
        c: List of the same length as ``a``; reordered in place.

    Returns:
        None. Empty lists are left untouched.
    """
    # Avoid inner functions and range-loop to minimize memory.
    # Makes it faster, too. But duplicates code. Not recommended.
    n = len(a)

    # Phase 1 - heapify: sift down every internal node, last one first.
    i0 = n // 2 - 1
    while i0 >= 0:
        i = i0
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                break
            a[i], a[imax] = a[imax], a[i]
            b[i], b[imax] = b[imax], b[i]
            c[i], c[imax] = c[imax], c[i]
            i = imax
        i0 -= 1

    # Phase 2 - extraction: move the maximum behind the shrinking heap.
    # The explicit ``> 0`` matters: for empty input ``n`` becomes -1, which is
    # truthy and would otherwise raise IndexError on ``a[0]``.
    while (n := n - 1) > 0:
        a[0], a[n] = a[n], a[0]
        b[0], b[n] = b[n], b[0]
        c[0], c[n] = c[n], c[0]
        i = 0
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                break
            a[i], a[imax] = a[imax], a[i]
            b[i], b[imax] = b[imax], b[i]
            c[i], c[imax] = c[imax], c[i]
            i = imax
def sort_together_heapsort_opti_X(*lists):
    """Heapsort any number of parallel lists in place, with everything inlined.

    Only the values of the first list are compared; every swap is applied to
    all lists so they stay aligned. The sift-down and swap steps are written
    out inline. The sort is not stable.

    Args:
        *lists: One or more lists of equal length; each is reordered in place.

    Returns:
        None. Empty lists are left untouched.
    """
    # Avoid inner functions and range-loop to minimize memory.
    # Makes it faster, too. But duplicates code. Not recommended.
    a = lists[0]  # the list that provides the sort keys
    n = len(a)

    # Phase 1 - heapify: sift down every internal node, last one first.
    i0 = n // 2 - 1
    while i0 >= 0:
        i = i0
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                break
            for lst in lists:
                lst[i], lst[imax] = lst[imax], lst[i]
            i = imax
        i0 -= 1

    # Phase 2 - extraction: move the maximum behind the shrinking heap.
    # The explicit ``> 0`` matters: for empty input ``n`` becomes -1, which is
    # truthy and would otherwise raise IndexError on ``lst[0]``.
    while (n := n - 1) > 0:
        for lst in lists:
            lst[0], lst[n] = lst[n], lst[0]
        i = 0
        while (kid := 2*i+1) < n:
            imax = kid if a[kid] > a[i] else i
            kid += 1
            if kid < n and a[kid] > a[imax]:
                imax = kid
            if imax == i:
                break
            for lst in lists:
                lst[i], lst[imax] = lst[imax], lst[i]
            i = imax
def sort_multi_by_a_guest_X(a, *lists):
    """Sort ``a`` in place and apply the same permutation to every other list.

    An index list is sorted by the values of ``a``; afterwards ``order[k]`` is
    the original position of the element that belongs at position ``k``.
    Each remaining list is then permuted in place using swaps only.
    """
    order = sorted(range(len(a)), key=lambda k: a[k])
    a.sort()
    for other in lists:
        for dest, src in enumerate(order):
            # A source index below ``dest`` was already swapped away earlier;
            # follow the chain to where that element ended up.
            while src < dest:
                src = order[src]
            other[dest], other[src] = other[src], other[dest]
# Every implementation under test; each is called with the three lists (a, b, c).
funcs = [
    sort_together_sorted_zip,
    sort_together_sorted_zip_2,
    sort_together_sorted_zip_X,
    sort_together_decorate_in_a,
    sort_together_decorate_in_first_X,
    sort_multi_by_a_guest_X,
    sort_together_iter_key,
    sort_together_iter_key_X,
    sort_together_heapsort,
    sort_together_heapsort_X,
    sort_together_heapsort_opti,
    sort_together_heapsort_opti_X,
]

# Benchmark input: random keys in a0; b0 and c0 are shifted copies of a0, so
# each list sorted on its own equals the correctly co-sorted result.
n = 100_000
a0 = [random() for _ in range(n)]
b0 = [x + 1 for x in a0]
c0 = [x + 2 for x in a0]

# Three rounds; each round prints one line per implementation.
for _ in range(3):
    for func in funcs:
        # Run 1: wall-clock time of a single call, followed by a result check.
        a, b, c = a0.copy(), b0.copy(), c0.copy()
        time = timeit(lambda: func(a, b, c), number=1)
        assert a == sorted(a0)
        assert b == sorted(b0)
        assert c == sorted(c0)

        # Run 2: fresh copies, with tracemalloc active only around the call.
        a, b, c = a0.copy(), b0.copy(), c0.copy()
        tm.start()
        func(a, b, c)
        memory = tm.get_traced_memory()[1]  # index 1 is the peak traced size
        tm.stop()

        print(f'{memory:10,} bytes {int(time * 1e3):4} ms {func.__name__}')
    print()