如何加快系列生成?

2022-01-17 00:00:00 algorithm numbers optimization c++

问题需要生成一个类似于斐波那契数列的序列的第 n-th 元素.但是,这有点棘手,因为 n 非常大(1 <= n <= 10^9).答案然后模1000000007.序列定义如下:

The problem requires to generate the n-th element of a sequence that is similar to Fibonacci sequence. However, it's a bit tricky because n is very large (1 <= n <= 10^9). The answer then modulo 1000000007. The sequence is defined as follows:

使用生成函数,我得到以下公式:

Using generating function, I obtain the following formula:

如果我使用序列方法,那么答案可以是模数,但运行速度非常慢.事实上,我多次得到time limit exceeded.我还尝试使用表来预生成一些初始值(缓存),但速度仍然不够快.另外,我可以在 array/vector (C++) 中存储的最大元素数量与 10^9 相比太小了,所以我猜这种方法也行不通.
如果我使用直接公式,那么它的运行速度非常快,但仅适用于较小的 n.对于较大的 n,double 将被截断,而且我将无法使用该数字修改我的答案,因为 modulo 仅适用于整数.
我没有想法,我认为必须有一个非常好的技巧来解决这个问题,不幸的是我想不出一个.任何想法将不胜感激.

If I use the sequence approach then the answer can be modulo, but it run extremely slow. In fact, I got time limit exceed many times. I also tried to use a table to pre-generate some initial values (cache), still it was not fast enough. In addition, the maximum number of elements that I can store in an array/vector (C++) is too small compared with 10^9, so I guess this approach doesn't work either.
If I use the direct formula then it run extremely fast but only for n that is small. For n large, double will got truncated, plus I won't be able to mod my answer with that number because modulo only works with integer.
I ran out of idea, and I think there must be a very nice trick to work around this problem, unfortunately I just can't think of one. Any idea would be greatly appreciated.

这是我最初的方法:

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <cmath>
#include <cassert>
#include <bitset>
#include <fstream>
#include <iomanip>
#include <set>
#include <stack>
#include <sstream>
#include <cstdio>
#include <map>
#include <cmath>

using namespace std;

typedef unsigned long long ull;

ull count_fair_coins_by_generating_function(ull n) {
    n--;
    return 
        (sqrt(3.0) + 1)/((sqrt(3.0) - 1) * 2 * sqrt(3.0)) * pow(2 / (sqrt(3.0) - 1), n * 1.0) 
        +
        (1 - sqrt(3.0))/((sqrt(3.0) + 1) * 2 * sqrt(3.0)) * pow(-2 / (sqrt(3.0) + 1), n * 1.0);
}

ull count_fair_coins(ull n) {
    if (n == 1) {
        return 1;
    }
    else if (n == 2) {
        return 3;
    }
    else {
        ull a1 = 1;
        ull a2 = 3;
        ull result;
        for (ull i = 3; i <= n; ++i) {
            result = (2*a2 + 2*a1) % 1000000007;
            a1 = a2;
            a2 = result;
        }

        return result;
    }
}

void inout_my_fair_coins() {
    int test_cases;
    cin >> test_cases;

    map<ull, ull> cache;
    ull n;
    while (test_cases--) {
        cin >> n;
        cout << count_fair_coins_by_generating_function(n) << endl;
        cout << count_fair_coins(n) << endl;
    }
}

int main() {
    inout_my_fair_coins();
    return 0;
} 

更新比赛结束后,我将基于 tskuzzy 想法的解决方案发布给感兴趣的人.再次感谢 tskuzzy.您可以在此处查看原始问题陈述:http://www.codechef.com/problems/CSUMD
首先需要算出1个硬币2个硬币的概率,然后得到一些初始值得到序列.完整的解决方案在这里:

Update Since the contest was over, I posted my solution based on tskuzzy idea for those who are interested. Once again, thanks tskuzzy. You can view the original problem statement here: http://www.codechef.com/problems/CSUMD
First, you need to figure out the probability of those 1 coin and 2 coin, then get some initial values to obtain the sequence. The complete solution is here:

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <cmath>
#include <cassert>
#include <bitset>
#include <fstream>
#include <iomanip>
#include <set>
#include <stack>
#include <sstream>
#include <cstdio>
#include <map>
#include <cmath>

using namespace std;

typedef unsigned long long ull;

const ull special_prime = 1000000007;

/*
    Using generating function for the recurrence:
           | 1                     if n = 1
    a_n =  | 3                     if n = 2
           | 2a_{n-1} + 2a_{n-2}     if n > 2

    This method is probably the fastest one but it won't work 
    because when n is large, double just can't afford it. Plus,
    using this formula, we can't apply mod for floating point number.
    1 <= n <= 21
*/
ull count_fair_coins_by_generating_function(ull n) {
    n--;
    return 
        (sqrt(3.0) + 1)/((sqrt(3.0) - 1) * 2 * sqrt(3.0)) * pow(2 / (sqrt(3.0) - 1), n * 1.0) 
        +
        (1 - sqrt(3.0))/((sqrt(3.0) + 1) * 2 * sqrt(3.0)) * pow(-2 / (sqrt(3.0) + 1), n * 1.0);
}

/*
    Naive approach, it works but very slow. 
    Useful for testing.
*/
ull count_fair_coins(ull n) {
    if (n == 1) {
        return 1;
    }
    else if (n == 2) {
        return 3;
    }
    else {
        ull a1 = 1;
        ull a2 = 3;
        ull result;
        for (ull i = 3; i <= n; ++i) {
            result = (2*a2 + 2*a1) % 1000000007;
            a1 = a2;
            a2 = result;
        }

        return result;
    }
}

struct matrix_2_by_2 {
    ull m[2][2];
    ull a[2][2];
    ull b[2][2];

    explicit matrix_2_by_2(ull a00, ull a01, ull a10, ull a11) {
        m[0][0] = a00;
        m[0][1] = a01;
        m[1][0] = a10;
        m[1][1] = a11;
    }

    matrix_2_by_2 operator *(const matrix_2_by_2& rhs) const {
        matrix_2_by_2 result(0, 0, 0, 0);
        result.m[0][0] = (m[0][0] * rhs.m[0][0]) + (m[0][1] * rhs.m[1][0]);
        result.m[0][1] = (m[0][0] * rhs.m[0][1]) + (m[0][1] * rhs.m[1][1]);
        result.m[1][0] = (m[1][0] * rhs.m[0][0]) + (m[1][1] * rhs.m[1][0]);
        result.m[1][1] = (m[1][0] * rhs.m[0][1]) + (m[1][1] * rhs.m[1][1]);
        return result;
    }

    void square() {
        a[0][0] = b[0][0] = m[0][0];
        a[0][1] = b[0][1] = m[0][1];
        a[1][0] = b[1][0] = m[1][0];
        a[1][1] = b[1][1] = m[1][1];

        m[0][0] = (a[0][0] * b[0][0]) + (a[0][1] * b[1][0]);
        m[0][1] = (a[0][0] * b[0][1]) + (a[0][1] * b[1][1]);
        m[1][0] = (a[1][0] * b[0][0]) + (a[1][1] * b[1][0]);
        m[1][1] = (a[1][0] * b[0][1]) + (a[1][1] * b[1][1]);
    }

    void mod(ull n) {
        m[0][0] %= n;
        m[0][1] %= n;
        m[1][0] %= n;
        m[1][1] %= n;
    }

    /*
        exponentiation by squaring algorithm
                | 1                    if n = 0 
                | (1/x)^n              if n < 0 
        x^n =   | x.x^({(n-1)/2})^2    if n is odd
                | (x^{n/2})^2          if n is even

        The following algorithm calculate a^p % m
        int modulo(int a, int p, int m){
            long long x = 1;
            long long y = a; 

            while (p > 0) {
                if (p % 2 == 1){
                    x = (x * y) % m;
                }

                // squaring the base
                y = (y * y) % m; 
                p /= 2;
            }

            return x % c;
        }

        To apply for matrix, we need an identity which is
        equivalent to 1, then perform multiplication for matrix 
        in similar manner. Thus the algorithm is defined 
        as follows:
    */
    void operator ^=(ull p) {
        matrix_2_by_2 identity(1, 0, 0, 1);

        while (p > 0) {
            if (p % 2) {
                identity = operator*(identity);
                identity.mod(special_prime);
            }

            this->square();
            this->mod(special_prime);
            p /= 2;
        }

        m[0][0] = identity.m[0][0];
        m[0][1] = identity.m[0][1];
        m[1][0] = identity.m[1][0];
        m[1][1] = identity.m[1][1];
    }

    friend
    ostream& operator <<(ostream& out, const matrix_2_by_2& rhs) {
        out << rhs.m[0][0] << ' ' << rhs.m[0][1] << '
';
        out << rhs.m[1][0] << ' ' << rhs.m[1][1] << '
';
        return out;
    }
};

/*
    |a_{n+2}| = |2 2|^n  x |3| 
    |a_{n+1}|   |1 0|      |1|
*/
ull count_fair_coins_by_matrix(ull n) {
    if (n == 1) {
        return 1;
    } else {
        matrix_2_by_2 m(2, 2, 1, 0);
        m ^= (n - 1);
        return (m.m[1][0] * 3 + m.m[1][1]) % 1000000007;
    }
}

void inout_my_fair_coins() {
    int test_cases;
    scanf("%d", &test_cases);

    ull n;
    while (test_cases--) {
        scanf("%llu", &n);
        printf("%d
", count_fair_coins_by_matrix(n));
    }
}

int main() {
    inout_my_fair_coins();
    return 0;
}  

推荐答案

可以将数列的项写成矩阵指数:

You can write the terms of the sequence in terms of matrix exponentials:

可以使用平方求幂快速评估.这导致 O(log n) 解决方案应该在时间限制内很好地解决问题.

which can be quickly evaluated using exponentiation by squaring. This leads to an O(log n) solution which should solve the problem well within the time constraints.

仅供参考,如果您需要与大数相乘(不适用于这种情况,因为答案是取模 1000000007),您应该查看 Karatsuba 算法.这为您提供了二次时间乘法.

Just for future reference, if you are required to do multiplication with large numbers (not applicable in this situation since the answer is taken modulo 1000000007), you should look into the Karatsuba algorithm. This gives you sub-quadratic time multiplication.

相关文章