Optimize Python code to generate permutations and indices faster from the provided arguments

2022-04-03 00:00:00 python permutation

Problem description

How can this answer be optimized, using the 256 bytes below as an example?
The example bytes are:

bytez = [197 215 20 156 94 67 20 100 27 208 186 248 71 48 128 75 7 165 148 223 94 163 233 15 161 104 246 66 242 142 118 165 204 0 252 22 233 28 136 197 113 122 72 229 11 91 133 142 20 204 119 211 170 104 63 39 46 68 150 123 148 95 96 95 17 133 243 35 45 66 76 19 41 200 141 120 110 215 140 230 252 182 42 166 59 249 171 97 124 8 138 59 112 191 87 170 218 31 51 74 112 23 37 13 63 96 61 200 110 189 59 18 11 99 94 63 245 107 31 11 217 51 133 35 113 36 154 179 223 92 31 239 20 51 200 102 133 183 240 88 104 29 81 122 28 246 161 90 89 6 241 241 19 40 43 248 78 6 234 40 171 23 143 70 122 246 180 148 183 67 158 198 212 41 0 98 171 81 122 114 229 193 213 212 65 72 120 191 228 32 132 172 88 100 104 119 253 166 159 242 246 6 66 190 31 57 175 105 161 1 109 8 1 50 97 60 101 25 131 93 51 243 203 41 11 140 231 59 131 68 177 58 80 142 9 21 20 106 132 161 187 21 253 234 222 190 91 106 192 149 4 70 77 139 170 172]  
distinct = 156  

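I want both methods, item_at and item_index, to work with a list of 256 bytes instead of strings or integers.

More details:
For item_at, l256b will be provided as a list of 256 bytes serving as the index, and distinct (the number of distinct bytes present in the provided l256b), a value 0<=x<256, will be supplied by the user of the method.
The alphabet and length parameters are not needed, since both are constants: the alphabet is the bytes 0<=x<256 and the length is 256 bytes.
item_at must return a list of 256 bytes that is the permutation for the provided index.

For item_index, l256b will be provided as a list of 256 bytes serving as the item (the permutation), and distinct (the number of distinct bytes present in the provided l256b) will be supplied by the user of the method as a value 0<=d<256.
The alphabet and length parameters are not needed, since both are constants: the alphabet is the bytes 0<=x<256 and the length is 256 bytes.
item_index must return a list of 256 bytes that is the index of the provided permutation.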
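The answer below represents indices as Python ints. If the index really has to be passed around as a list of 256 bytes, as the question asks, a pair of conversion helpers could sit on either side of item_at and item_index. A minimal sketch, assuming big-endian byte order; the names int_to_l256b and l256b_to_int are hypothetical and appear in neither the question nor the answer:

```python
def int_to_l256b(n):
    """Encode a non-negative int index as a list of 256 byte values (big-endian)."""
    # 256 bytes hold any value below 256**256, and there are at most 256**256
    # sequences of 256 bytes, so every index fits. Larger or negative values
    # raise OverflowError.
    return list(n.to_bytes(256, "big"))


def l256b_to_int(l256b):
    """Decode a list of 256 byte values (big-endian) back into an int index."""
    if len(l256b) != 256:
        raise ValueError("expected exactly 256 bytes")
    # bytes() raises ValueError if any element is outside 0..255
    return int.from_bytes(bytes(l256b), "big")


index_bytes = int_to_l256b(12345)
print(index_bytes[-2:], l256b_to_int(index_bytes))  # [48, 57] 12345
```

A caller would then pass l256b_to_int(l256b) into the index-based method and wrap an int result with int_to_l256b.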


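Solution

1 - Working with lists

There are a few details to address. The first is that the solution was designed to work with strings, but it can easily be modified to handle lists.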

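2 - Track only the counts of the prefix

The code in use builds the entire prefix, extending it on every call, which causes a lot of copying. In the function item_index, the prefix is only used to know whether a given symbol has already been used. Instead, you can keep a dictionary recording how many times each symbol appears in the prefix, and then test prefixCount[d] != 0 instead of d in prefix.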
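The difference can be seen in isolation with a small sketch (both function names are hypothetical). Each function reports, for every position, whether the symbol is new, i.e. absent from the prefix before it. The first copies and scans the prefix at every step; the second keeps a per-byte counter, which is the prefixCount idea with a list of 256 counters standing in for the dictionary:

```python
def new_symbol_flags_naive(seq):
    # Copies seq[:i] and scans it at every step: O(n**2) time and copying.
    return [d not in seq[:i] for i, d in enumerate(seq)]


def new_symbol_flags(seq, alphabet_size=256):
    # prefix_count[d] = number of times d occurs in the prefix seen so far
    prefix_count = [0] * alphabet_size
    flags = []
    for d in seq:
        flags.append(prefix_count[d] == 0)  # O(1) check instead of `d in prefix`
        prefix_count[d] += 1
    return flags


sample = [5, 3, 5, 0, 3]
print(new_symbol_flags(sample))  # [True, True, False, True, False]
```

The number of True flags is the number of distinct symbols, which is exactly how used grows along the recursion.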

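3 - Adjust the cache size

The solution uses an LRU cache, and by default this type of cache stores only the 128 most recently used entries. You can decorate the function with lru_cache(maxsize=None), or simply with cache() (Python 3.9+); if you know that the maximum input length is 256, lru_cache(maxsize=256**2) is enough.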
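One way to check that 256**2 entries are enough is to run the recursion once with an unbounded cache and inspect cache_info(). Because distinct + used stays constant along every recursive call, the reachable states are essentially pairs (length, used), at most 257 × 158 of them for length 256 and 156 distinct bytes. A sketch reusing the recurrence from the answer, with 156 taken from the example in the question:

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded; same as functools.cache on Python 3.9+
def count_seq(n_symbols, length, distinct, used=0):
    # Same recurrence as in the answer.
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    return (count_seq(n_symbols, length - 1, distinct, used) * used
            + count_seq(n_symbols, length - 1, distinct - 1, used + 1) * (n_symbols - used))


count_seq(256, 256, 156)
info = count_seq.cache_info()  # nothing is evicted, so currsize == misses
print(info.currsize, info.currsize <= 256**2)
```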

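from functools import lru_cache

@lru_cache(maxsize=256**2)
def count_seq(n_symbols, length, distinct, used=0):
    # Ways to fill `length` more positions, with `used` symbols already seen,
    # so that exactly `distinct` new symbols are introduced.
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    return (
        # reuse one of the `used` symbols: distinct count unchanged
        count_seq(n_symbols, length-1, distinct, used) * used
        # introduce one of the (n_symbols - used) unseen symbols
        + count_seq(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)
    )
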
def item_at(idx, alphabet, length, distinct, used=0, prefix=None):
    # Walk the symbols in alphabet order, skipping whole branches whose
    # count is <= idx, until the branch containing idx is found.
    if prefix is None:
        prefix = []
    if distinct < 0:
        return None
    if length == 0:
        return prefix
    for d in alphabet:
        if d in prefix:
            # d repeats a symbol already in the prefix
            branch_count = count_seq(len(alphabet), length-1, distinct, used)
            if branch_count <= idx:
                idx -= branch_count
            else:
                prefix.append(d)
                return item_at(idx, alphabet, length-1, distinct, used, prefix)
        else:
            # d is a new symbol
            branch_count = count_seq(len(alphabet), length-1, distinct-1, used+1)
            if branch_count <= idx:
                idx -= branch_count
            else:
                prefix.append(d)
                return item_at(idx, alphabet, length-1, distinct-1, used+1, prefix)

def item_index(item, alphabet, length, distinct, used=0, prefixCount=None, idx=0):
    # prefixCount[d] = how many times d occurs in item[:idx]; replaces `d in prefix`
    if prefixCount is None:
        prefixCount = {a: 0 for a in alphabet}
    if distinct < 0:
        return 0
    if length == 0:
        return 0
    offset = 0
    for d in alphabet:
        if prefixCount[d] != 0:
            # d already occurs in the prefix
            if d == item[idx]:
                prefixCount[d] += 1
                return offset + item_index(item, alphabet, length-1, distinct, used,
                                           prefixCount, idx+1)
            offset += count_seq(len(alphabet), length-1, distinct, used)
        else:
            # d would be a new symbol
            if d == item[idx]:
                prefixCount[d] += 1
                return offset + item_index(item, alphabet, length-1, distinct-1, used+1,
                                           prefixCount, idx+1)
            offset += count_seq(len(alphabet), length-1, distinct-1, used+1)

On a modern computer this runs in a few milliseconds.

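Iterative implementation

I wrote a class that you instantiate with an alphabet, a length, and the number of distinct symbols you want; this is possible because distinct + used is constant across all the recursive calls. The results of count_seq are precomputed into a matrix C at construction time, and the methods item_at and item_index are iterative implementations built on the values in C.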
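The meaning of C can be checked against brute force on a tiny alphabet. In this sketch, build_count_table repeats the constructor's loop, and brute_force_entry (a hypothetical helper) pretends that symbols 0..u-1 are already used and enumerates every tail of length r with itertools.product:

```python
from itertools import product


def build_count_table(n_symbols, length, distinct):
    # Same construction as SequenceLookup.__init__.
    c = [0] * distinct + [1, 0]
    C = [c]
    for _ in range(2, length + 1):
        c = [c[d] * d + c[d + 1] * (n_symbols - d) for d in range(distinct + 1)] + [0]
        C.append(c)
    return C


def brute_force_entry(n_symbols, r, u, distinct):
    # Count tails of length r that bring the total to exactly `distinct` symbols,
    # given that symbols 0..u-1 have already been used.
    used = set(range(u))
    return sum(1 for tail in product(range(n_symbols), repeat=r)
               if len(used | set(tail)) == distinct)


C = build_count_table(3, 4, 2)
ok = all(C[r][u] == brute_force_entry(3, r, u, 2)
         for r in range(len(C)) for u in range(len(C[r])))
print(C, ok)
```

In other words, C[r][u] equals count_seq(n_symbols, r, distinct - u, u) from the recursive version, which is why a 2-D table indexed by remaining length and number of used symbols is enough.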

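In my opinion this is less readable, because in the recursive implementation everything is expressed as function calls with a clear conceptual meaning.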

class SequenceLookup:
    def __init__(self, alphabet, length, distinct):
        self.alphabet = list(alphabet)
        self.distinct = distinct
        n_symbols = len(self.alphabet)
        # C[r][u]: ways to fill r remaining positions when u distinct symbols
        # are already used, ending with exactly `distinct` distinct symbols.
        c = [0] * distinct + [1, 0]
        C = [c]
        for _ in range(2, length+1):
            c = [
                c[d] * d + c[d+1] * (n_symbols - d)
                for d in range(distinct+1)
            ] + [0]
            C.append(c)
        self.C = C

    def item_index(self, item):
        length = len(item)
        offset = 0
        seen = set()
        for i, di in enumerate(item):
            # add the branch counts of every symbol that sorts before di
            for d in self.alphabet:
                if d == di:
                    break
                if d in seen:
                    offset += self.C[length-1-i][len(seen)]
                else:
                    offset += self.C[length-1-i][len(seen)+1]
            seen.add(di)
        return offset

    def item_at(self, idx, length):
        seen = set()
        prefix = []
        for i in range(length):
            for d in self.alphabet:
                # `d in prefix` scans the list; `d in seen` gives the same answer in O(1)
                if d in prefix:
                    branch_count = self.C[length-1-i][len(seen)]
                else:
                    branch_count = self.C[length-1-i][len(seen)+1]
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d)
                    seen.add(d)
                    break
        return prefix

bytez=[197, 215, 20, 156, 94, 67, 20, 100, 27, 208, 186, 248, 
       71, 48, 128, 75, 7, 165, 148, 223, 94, 163, 233, 15,
       161, 104, 246, 66, 242, 142, 118, 165, 204, 0, 252,
       22, 233, 28, 136, 197, 113, 122, 72, 229, 11, 91, 133,
       142, 20, 204, 119, 211, 170, 104, 63, 39, 46, 68, 150,
       123, 148, 95, 96, 95, 17, 133, 243, 35, 45, 66, 76, 19,
       41, 200, 141, 120, 110, 215, 140, 230, 252, 182, 42, 
       166, 59, 249, 171, 97, 124, 8, 138, 59, 112, 191, 87, 
       170, 218, 31, 51, 74, 112, 23, 37, 13, 63, 96, 61, 200, 
       110, 189, 59, 18, 11, 99, 94, 63, 245, 107, 31, 11, 
       217, 51, 133, 35, 113, 36, 154, 179, 223, 92, 31, 239, 
       20, 51, 200, 102, 133, 183, 240, 88, 104, 29, 81, 122,
       28, 246, 161, 90, 89, 6, 241, 241, 19, 40, 43, 248, 78,
       6, 234, 40, 171, 23, 143, 70, 122, 246, 180, 148, 183,
       67, 158, 198, 212, 41, 0, 98, 171, 81, 122, 114, 229,
       193, 213, 212, 65, 72, 120, 191, 228, 32, 132, 172, 88,
       100, 104, 119, 253, 166, 159, 242, 246, 6, 66, 190, 31,
       57, 175, 105, 161, 1, 109, 8, 1, 50, 97, 60, 101, 25,
       131, 93, 51, 243, 203, 41, 11, 140, 231, 59, 131, 68,
       177, 58, 80, 142, 9, 21, 20, 106, 132, 161, 187, 21, 253, 
       234, 222, 190, 91, 106, 192, 149, 4, 70, 77, 139, 170, 172]
v = SequenceLookup(range(256), len(bytez), len(set(bytez)))
t = v.item_index(bytez)  # assumed: t is the index of bytez, used as input to the item_at timing

%%timeit
v = SequenceLookup(range(256), len(bytez), len(set(bytez)))

11.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
v.item_index(bytez)

7.57 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
v.item_at(t, 256)

33.6 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
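Outside IPython/Jupyter, where %%timeit is not available, numbers of the same shape can be approximated with the standard timeit module. This is a rough stand-in rather than IPython's implementation: unlike %timeit it does not pick the loop count automatically, and timeit_like is a hypothetical name:

```python
import statistics
import timeit


def timeit_like(stmt, namespace, number=100, repeat=7):
    """Time `stmt` `repeat` times, `number` loops each; return per-loop (mean, stdev) in seconds."""
    runs = timeit.repeat(stmt, repeat=repeat, number=number, globals=namespace)
    per_loop = [total / number for total in runs]
    return statistics.mean(per_loop), statistics.stdev(per_loop)


mean, std = timeit_like("sum(range(1000))", {})
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop "
      f"(mean ± std. dev. of 7 runs, 100 loops each)")
```

With the classes above loaded in a script, timeit_like("v.item_index(bytez)", globals()) would time the same call as the second cell.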

Specialized for bytes

An implementation with the fixed alphabet [0, ..., 255]:

class SequenceLookup:
    def __init__(self, length, distinct):
        self.distinct = distinct
        # Same table as before, with the alphabet size fixed to 256.
        c = [0] * distinct + [1, 0]
        C = [c]
        for _ in range(2, length+1):
            c = [
                c[d] * d + c[d+1] * (256 - d)
                for d in range(distinct+1)
            ] + [0]
            C.append(c)
        self.C = C

    def item_index(self, item):
        length = len(item)
        offset = 0
        seen = set()
        for i, di in enumerate(item):
            for d in range(256):
                if d == di:
                    break
                if d in seen:
                    offset += self.C[length-1-i][len(seen)]
                else:
                    offset += self.C[length-1-i][len(seen)+1]
            seen.add(di)
        return offset

    def item_at(self, idx, length):
        seen = [0] * 256       # seen[d] == 1 once byte d is in the prefix
        prefix = [0] * length
        used = 0               # number of distinct bytes in the prefix so far
        for i in range(length):
            for d in range(256):
                if seen[d] != 0:
                    branch_count = self.C[length-1-i][used]
                else:
                    branch_count = self.C[length-1-i][used+1]
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix[i] = d
                    if seen[d] == 0:
                        used += 1
                    seen[d] = 1
                    break
        return prefix

With this implementation, construction and item_index take essentially the same time, but item_at ran faster in my tests:

6.32 ms ± 91.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This will of course vary, so you may want to try the same algorithm with different data structures yourself.
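As one example of such an experiment (a sketch, not benchmarked here): in item_index, every symbol d < di contributes C[r][used] if it was already seen and C[r][used + 1] otherwise, so only the number k of seen symbols below di matters. Keeping the seen symbols in a sorted list and counting them with bisect.bisect_left replaces the inner loop over all 256 bytes. The functions build_table and item_index_counted are hypothetical names, and the alphabet is assumed to be range(n_symbols):

```python
from bisect import bisect_left


def build_table(length, distinct, n_symbols=256):
    # Same precomputation as SequenceLookup.__init__:
    # C[r][u] = ways to fill r remaining positions with u distinct symbols already used.
    c = [0] * distinct + [1, 0]
    C = [c]
    for _ in range(2, length + 1):
        c = [c[d] * d + c[d + 1] * (n_symbols - d) for d in range(distinct + 1)] + [0]
        C.append(c)
    return C


def item_index_counted(C, item):
    """Same result as SequenceLookup.item_index, for the alphabet range(n_symbols)."""
    length = len(item)
    offset = 0
    seen_sorted = []  # distinct symbols seen so far, in ascending order
    for i, di in enumerate(item):
        used = len(seen_sorted)
        k = bisect_left(seen_sorted, di)  # seen symbols smaller than di
        row = C[length - 1 - i]
        # k smaller symbols repeat a seen one; the other di - k would be new
        offset += k * row[used] + (di - k) * row[used + 1]
        if k == used or seen_sorted[k] != di:
            seen_sorted.insert(k, di)  # first occurrence of di
    return offset


C = build_table(2, 2)
print(item_index_counted(C, [0, 1]), item_index_counted(C, [1, 0]),
      item_index_counted(C, [255, 254]))  # 0 255 65279
```

For the data in the question, item_index_counted(build_table(len(bytez), len(set(bytez))), bytez) should match v.item_index(bytez); whether it is actually faster is worth measuring rather than assuming.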
