【发布时间】:2018-08-26 00:48:45
【问题描述】:
我已尝试按照本文 (3 way quicksort (C implementation)) 中的描述实现 C 快速选择算法。 但是,我得到的只是性能比默认的 qsort 低 5 到 10 倍(即使是初始改组)。 我试图挖掘此处提供的原始 qsort 源代码 (https://github.com/lattera/glibc/blob/master/stdlib/qsort.c),但它太复杂了。 有人有更简单,更好的算法吗? 欢迎任何想法。 谢谢, 注意:我最初的问题是尝试将数组的第 K 个最小值获取到第一个 Kth 索引。所以我打算调用quickselect K次 编辑 1:这是从上面的链接复制和改编的 Cython 代码
cdef void qswap(void* a, void* b, const size_t size) nogil:
cdef char temp[size]# C99, use malloc otherwise
#char serves as the type for "generic" byte arrays
memcpy(temp, b, size)
memcpy(b, a, size)
memcpy(a, temp, size)
cdef void qshuffle(void* base, size_t num, size_t size) nogil: #implementation of Fisher
cdef int i, j, tmp# create local variables to hold values for shuffle
for i in range(num - 1, 0, -1): # for loop to shuffle
j = c_rand() % (i + 1)#randomise j for shuffle with Fisher Yates
qswap(base + i*size, base + j*size, size)
cdef void partition3(void* base,
size_t *low, size_t *high, size_t size,
QComparator compar) nogil:
# Modified median-of-three and pivot selection.
cdef void *ptr = base
cdef size_t lt = low[0]
cdef size_t gt = high[0] # lt is the pivot
cdef size_t i = lt + 1# (+1 !) we don't compare pivot with itself
cdef int c = 0
while (i <= gt):
c = compar(ptr + i * size, ptr + lt * size)
if (c < 0):# base[i] < base[lt] => swap(i++,lt++)
qswap(ptr + lt * size, ptr + i * size, size)
i += 1
lt += 1
elif (c > 0):#base[i] > base[gt] => swap(i, gt--)
qswap(ptr + i * size, ptr + gt* size, size)
gt -= 1
else:#base[i] == base[gt]
i += 1
#base := [<<<<<lt=====gt>>>>>>]
low[0] = lt
high[0] = gt
cdef void qselectk3(void* base, size_t lo, size_t hi,
size_t size, size_t k,
QComparator compar) nogil:
cdef size_t low = lo
cdef size_t high = hi
partition3(base, &low, &high, size, compar)
if ((k - 1) < low): #k lies in the less-than-pivot partition
high = low - 1
low = lo
elif ((k - 1) >= low and (k - 1) <= high): #k lies in the equals-to-pivot partition
qswap(base, base + size*low, size)
return
else: # k > high => k lies in the greater-than-pivot partition
low = high + 1
high = hi
qselectk3(base, low, high, size, k, compar)
"""
A selection algorithm to find the nth smallest elements in an unordered list.
these elements ARE placed at the nth positions of the input array
"""
cdef void qselect(void* base, size_t num, size_t size,
size_t n,
QComparator compar) nogil:
cdef int k
qshuffle(base, num, size)
for k in range(n):
qselectk3(base + size*k, 0, num - k - 1, size, 1, compar)
我使用 python timeit 来获取方法 pyselect(with N=50) 和 pysort 的性能。 像这样
def testPySelect():
A = np.random.randint(16, size=(10000), dtype=np.int32)
pyselect(A, 50)
timeit.timeit(testPySelect, number=1)
def testPySort():
A = np.random.randint(16, size=(10000), dtype=np.int32)
pysort(A)
timeit.timeit(testPySort, number=1)
【问题讨论】:
-
您需要展示您尝试过的内容,以及您是如何编译它的(编译时没有优化是导致性能问题的第二大常见原因)。
-
我已经习惯了链接中提供的代码。对于编译,我在 Mac OS X Lion 上使用 GCC 5。
-
但是在编译过程中使用了哪些标志/选项?
-
我在项目中也使用了openmp。所以除了numpy,包括,只有openmp编译和链接选项(-fopenmp)。
-
没有优化的编译是导致性能问题的第二大常见原因
标签: algorithm cython quicksort quickselect