|
发表于 2024-2-24 09:57:58
|
显示全部楼层
本帖最后由 zhangjinxuan 于 2024-2-24 10:00 编辑
严格来说不会有 $O(n)$ 的时间复杂度,这个代码还是 $O(n \log_{10} n)$ 的复杂度。
不过你可以使用数位DP将时间复杂度做到 $O(\log_{10} n)$,虽然并不是楼主需要的 $O(n)$,但是这个算法比 $O(n)$ 算法优秀得多,可以解决 $n=10^{18}$ 甚至更大的情况。
- #include <bits/stdc++.h>
- using namespace std;
- #define int long long
// Decimal digits of the bound currently being processed, most significant
// digit first, stored in d[1..m]; d[0] is unused.
// A 64-bit signed value has up to 19 decimal digits, so 20 slots are needed
// (the previous size of 18 overflowed for 18- and 19-digit inputs).
int d[20], m;
// Inclusive range [l, r] whose digit sum is reported by main.
long long l, r;
// Digit statistics produced by the digit DP.
//
// The counters are declared as `long long` explicitly so that the struct does
// not silently depend on a `#define int long long` macro being active.
struct Info {
    // a[k]: how many times digit k has been counted. calc() only adds to a[k]
    // once a non-zero digit has been placed, so leading zeros are excluded.
    long long a[10];
    // Number of integers counted. In the memo table f, the value -1 marks a
    // cell that has not been computed yet.
    long long numtot;

    // Component-wise sum of two statistics.
    Info operator + (const Info &i) const {
        Info res = {{}, numtot + i.numtot};
        for (int j = 0; j <= 9; ++j) res.a[j] = a[j] + i.a[j];
        return res;
    }

    // Component-wise difference of two statistics.
    Info operator - (const Info &i) const {
        Info res = {{}, numtot - i.numtot};
        for (int j = 0; j <= 9; ++j) res.a[j] = a[j] - i.a[j];
        return res;
    }
} f[21][2][2];  // memo table indexed as f[position][zero][limited]
- Info calc(int i, int zero, int limited) {
- if (i == m + 1) return {{}, zero};
- if (f[i][zero][limited].numtot != -1) return f[i][zero][limited];
- Info res = {{}, 0};
- if (i == 1) {
- for (int j = 0; j <= d[1]; ++j) {
- Info tmp = calc(i + 1, (j != 0), (j != d[1]));
- res = res + tmp;
- if (j != 0) {
- res.a[j] += tmp.numtot;
- }
- }
- } else {
- if (limited) {
- for (int j = 0; j <= 9; ++j) {
- Info tmp = calc(i + 1, (bool)(zero | j), 1);
- res = res + tmp;
- if ((bool)(zero | j)) {
- res.a[j] += tmp.numtot;
- }
- }
- } else { // ???üì? 0 ~ d[i] μ?????£?2¢?ú×?òa?ó
- for (int j = 0; j <= d[i]; ++j) {
- Info tmp = calc(i + 1, (bool)(zero | j), (j != d[i]));
- res = res + tmp;
- if ((bool)(zero | j)) {
- res.a[j] += tmp.numtot;
- }
- }
- }
- }
- return f[i][zero][limited] = res;
- }
- signed main() {
- int t = 1;
- while (t--) {
- l = 1;
- scanf("%lld", &l, &r);
- m = 0;
- for (long long tr = r; tr; tr /= 10)
- d[++m] = tr % 10;
- for (int i = 1, j = m; i < j; ++i, --j) swap(d[i], d[j]);
- for (int i = 0; i < 21; ++i) for (int j = 0; j < 2; ++j) for (int k = 0; k < 2; ++k) f[i][j][k].numtot = -1;
- Info tmp = calc(1, 0, 0);
- if (l != 1) {
- m = 0;
- memset(d, 255, sizeof(d));
- for (long long tr = l - 1; tr; tr /= 10)
- d[++m] = tr % 10;
- for (int i = 1, j = m; i < j; ++i, --j) swap(d[i], d[j]);
- for (int i = 0; i < 21; ++i) for (int j = 0; j < 2; ++j) for (int k = 0; k < 2; ++k) f[i][j][k].numtot = -1;
- tmp = tmp - calc(1, 0, 0);
- }
- long long res = 0;
- for (int i = 1; i <= 9; ++i) res += tmp.a[i] * i;
- printf("%lld\n", res);
- }
- // printf("%lld\n", tmp.numtot);
- return 0;
- }
- /*
- 01
- */
复制代码 |
|