Why does this code using streams run so much faster in Java 9 than in Java 8?

2022-01-22 00:00:00 performance java-8 java java-stream java-9


I discovered this while solving Problem 205 of Project Euler. The problem is as follows:

Peter has nine four-sided (pyramidal) dice, each with faces numbered 1, 2, 3, 4. Colin has six six-sided (cubic) dice, each with faces numbered 1, 2, 3, 4, 5, 6.

Peter and Colin roll their dice and compare totals: the highest total wins. The result is a draw if the totals are equal.

What is the probability that Pyramidal Pete beats Cubic Colin? Give your answer rounded to seven decimal places in the form 0.abcdefg
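The probability asked for is (number of (Peter roll, Colin roll) pairs where Peter's total is higher) divided by 4^9 · 6^6 equally likely pairs. The sketch below computes that exact count in a different way from the brute force used in the question. It first tabulates how many ordered rolls produce each total, then multiplies the matching frequencies. It is offered only as an illustration of the arithmetic, not as the approach discussed in this thread. The class and method names (`DiceDistribution`, `sumDistribution`, `probabilityPeterWins`) are hypothetical.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

public class DiceDistribution {
    // ways[s] = number of ordered rolls of `dice` dice (faces 1..faces) whose total is s.
    static long[] sumDistribution(int dice, int faces) {
        long[] ways = new long[dice * faces + 1];
        ways[0] = 1; // with zero dice there is exactly one outcome, total 0
        for (int d = 0; d < dice; d++) {
            long[] next = new long[ways.length];
            for (int s = 0; s < ways.length; s++) {
                if (ways[s] == 0) continue; // skip unreachable totals (also keeps s + f in range)
                for (int f = 1; f <= faces; f++) {
                    next[s + f] += ways[s];
                }
            }
            ways = next;
        }
        return ways;
    }

    // Counts winning pairs by comparing distinct totals only, then rounds to 7 places.
    static BigDecimal probabilityPeterWins() {
        long[] peter = sumDistribution(9, 4);
        long[] colin = sumDistribution(6, 6);
        long wins = 0;
        for (int p = 0; p < peter.length; p++) {
            for (int c = 0; c < colin.length && c < p; c++) { // Peter's total strictly higher
                wins += peter[p] * colin[c];
            }
        }
        long total = (long) Math.pow(4, 9) * (long) Math.pow(6, 6);
        return BigDecimal.valueOf(wins).divide(BigDecimal.valueOf(total), 7, RoundingMode.HALF_UP);
    }

    public static void main(String[] args) {
        System.out.println(probabilityPeterWins());
    }
}
```

Each table has at most 37 entries (totals 0 to 36). The nested loop therefore does only a few hundred multiply-adds, instead of comparing every individual pair of rolls.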

I wrote a naive solution using Guava:

import com.google.common.collect.Sets;
import com.google.common.collect.ImmutableSet;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.Collectors;

public class Problem205 {
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        List<Integer> peter = Sets.cartesianProduct(Collections.nCopies(9, ImmutableSet.of(1, 2, 3, 4)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());
        List<Integer> colin = Sets.cartesianProduct(Collections.nCopies(6, ImmutableSet.of(1, 2, 3, 4, 5, 6)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());

        long startTime2 = System.currentTimeMillis();
        // IMPORTANT BIT HERE! v
        long solutions = peter
                .stream()
                .mapToLong(p -> colin
                        .stream()
                        .filter(c -> p > c)
                        .count())
                .sum();

        // IMPORTANT BIT HERE! ^
        System.out.println("Counting solutions took " + (System.currentTimeMillis() - startTime2) + "ms");

        System.out.println("Solution: " + BigDecimal
                .valueOf(solutions)
                .divide(BigDecimal
                                .valueOf((long) Math.pow(4, 9) * (long) Math.pow(6, 6)),
                        7,
                        RoundingMode.HALF_UP));
        System.out.println("Found in: " + (System.currentTimeMillis() - startTime) + "ms");
    }
}
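The `peter` and `colin` lists above hold one total per possible ordered roll, so their sizes are 4^9 and 6^6. Here is a Guava-free sketch that builds lists with the same values, useful for reproducing the timing without the Guava dependency. It assumes only that the order of the totals does not matter for the counting step. The helper name `allSums` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class AllSums {
    // One entry per ordered roll of `dice` dice with faces 1..faces, holding that roll's total.
    // Produces the same multiset of totals as the Guava cartesianProduct pipeline.
    static List<Integer> allSums(int dice, int faces) {
        List<Integer> sums = new ArrayList<>();
        sums.add(0); // zero dice: a single outcome with total 0
        for (int d = 0; d < dice; d++) {
            List<Integer> next = new ArrayList<>(sums.size() * faces);
            for (int s : sums) {
                for (int f = 1; f <= faces; f++) {
                    next.add(s + f); // extend every partial roll by one more die
                }
            }
            sums = next;
        }
        return sums;
    }

    public static void main(String[] args) {
        List<Integer> peter = allSums(9, 4);
        List<Integer> colin = allSums(6, 6);
        System.out.println(peter.size() + " " + colin.size()); // prints "262144 46656"
    }
}
```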

The code I have highlighted, which uses a simple filter(), count() and sum(), seems to run much faster in Java 9 than in Java 8. Specifically, Java 8 counts the solutions in 37465 ms on my machine, while Java 9 does it in about 16000 ms. The Java 9 timing is the same whether I run the class file compiled with Java 8 or the one compiled with Java 9.
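For scale, here is a rough worked count of the comparisons the highlighted pipeline performs. It is derived from the dice counts in the problem, not measured. Every Peter total is compared with every Colin total:

$$
|\mathit{peter}| \cdot |\mathit{colin}| = 4^9 \cdot 6^6 = 262\,144 \cdot 46\,656 = 12\,230\,590\,464 \approx 1.22 \times 10^{10}
$$

At 37465 ms, that is roughly 3.1 ns per comparison on Java 8. At about 16000 ms, it is roughly 1.3 ns per comparison on Java 9.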

If I replace the streams code with what would seem to be the exact pre-streams equivalent:

long solutions = 0;
for (Integer p : peter) {
    long count = 0;
    for (Integer c : colin) {
        if (p > c) {
            count++;
        }
    }
    solutions += count;
}

It counts the solutions in about 35000ms, with no measurable difference between Java 8 and Java 9.

What am I missing here? Why is the streams code so much faster in Java 9, and why isn't the for loop?


I am running Ubuntu 16.04 LTS 64-bit. My Java 8 version:

java version "1.8.0_131"
Java(TM) SE Runtime Environment (build 1.8.0_131-b11)
Java HotSpot(TM) 64-Bit Server VM (build 25.131-b11, mixed mode)

My Java 9 version:

java version "9"
Java(TM) SE Runtime Environment (build 9+181)
Java HotSpot(TM) 64-Bit Server VM (build 9+181, mixed mode)

Solution

1. Why the stream runs faster on JDK 9

The Stream.count() implementation in JDK 8 is rather dumb: it simply iterates through the whole stream, adding 1L for each element.
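Here is a minimal sketch contrasting the two counting strategies on the question's inner pipeline.

- `countJdk8Style` follows the JDK 8 behaviour described above: an extra `mapToLong(e -> 1L)` stage feeding a `long` sum.
- `countWithCounter` is a hypothetical, simplified stand-in for a dedicated counting operation that just increments a counter per element. It is not the JDK source.

All names are made up for illustration.

```java
import java.util.List;
import java.util.stream.Stream;

public class CountStyles {
    // JDK 8 style: every element passes through a map stage and a long-sum reduction.
    static <T> long countJdk8Style(Stream<T> stream) {
        return stream.mapToLong(e -> 1L).sum();
    }

    // Simplified illustration of a dedicated counting sink: one increment per element.
    // (Sequential streams only; the array is just a mutable holder for the lambda.)
    static <T> long countWithCounter(Stream<T> stream) {
        long[] counter = {0};
        stream.forEach(e -> counter[0]++);
        return counter[0];
    }

    public static void main(String[] args) {
        List<Integer> colin = List.of(3, 8, 12, 20, 25); // hypothetical sample totals
        int p = 15;
        System.out.println(countJdk8Style(colin.stream().filter(c -> p > c)));   // prints 3
        System.out.println(countWithCounter(colin.stream().filter(c -> p > c))); // prints 3
    }
}
```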

This was fixed in JDK 9, where count() uses a dedicated counting operation instead of mapToLong(e -> 1L).sum(). Even though the bug report talks about SIZED streams, the new code improves non-sized streams, too.
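The SIZED / non-sized distinction can be observed with `peek()`. A minimal sketch, assuming JDK 9 or later:

- For a SIZED pipeline, the `Stream.count()` documentation allows the implementation to compute the count directly from the source without running intermediate operations. The `peek()` counter may therefore stay at 0.
- `filter()` makes the size unknown, so the pipeline has to be traversed element by element. The question's `colin.stream().filter(c -> p > c).count()` is in this case.

The class and method names are hypothetical.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class SizedCount {
    // SIZED pipeline: peek() does not change the known size, so count() may skip traversal.
    static long countSized(List<Integer> list, AtomicInteger visits) {
        return list.stream().peek(x -> visits.incrementAndGet()).count();
    }

    // Non-sized pipeline: after filter() the size is unknown, so every element is visited.
    static long countFiltered(List<Integer> list, int bound, AtomicInteger visits) {
        return list.stream()
                .peek(x -> visits.incrementAndGet())
                .filter(x -> x < bound)
                .count();
    }

    public static void main(String[] args) {
        List<Integer> colin = List.of(3, 8, 12, 20, 25); // hypothetical sample totals
        AtomicInteger sizedVisits = new AtomicInteger();
        AtomicInteger filteredVisits = new AtomicInteger();
        System.out.println(countSized(colin, sizedVisits)
                + " (peek ran " + sizedVisits.get() + " times)");   // on JDK 9+ peek typically runs 0 times
        System.out.println(countFiltered(colin, 15, filteredVisits)
                + " (peek ran " + filteredVisits.get() + " times)"); // 3 (peek ran 5 times)
    }
}
```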

If you replace .count() with the Java 8-style implementation, .mapToLong(e -> 1L).sum(), it becomes slow again, even on JDK 9.
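Here is a sketch of that experiment: the question's pipeline with `count()` next to the same pipeline using `mapToLong(e -> 1L).sum()`.

- It runs on small synthetic lists (hypothetical data, not the real dice totals).
- It uses naive `System.nanoTime()` timing over a few warm-up rounds, so treat the numbers as indicative only. A harness such as JMH is the usual tool for trustworthy measurements.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class CountVsMapToLong {
    // The question's highlighted pipeline, using count().
    static long withCount(List<Integer> peter, List<Integer> colin) {
        return peter.stream()
                .mapToLong(p -> colin.stream().filter(c -> p > c).count())
                .sum();
    }

    // Same pipeline with count() replaced by the JDK 8-style mapToLong(e -> 1L).sum().
    static long withMapToLongSum(List<Integer> peter, List<Integer> colin) {
        return peter.stream()
                .mapToLong(p -> colin.stream().filter(c -> p > c).mapToLong(e -> 1L).sum())
                .sum();
    }

    public static void main(String[] args) {
        // Hypothetical synthetic totals in the dice ranges (Peter 9..36, Colin 6..36).
        List<Integer> peter = IntStream.range(0, 4_000).map(i -> 9 + i % 28).boxed().collect(Collectors.toList());
        List<Integer> colin = IntStream.range(0, 4_000).map(i -> 6 + i % 31).boxed().collect(Collectors.toList());

        for (int round = 0; round < 5; round++) { // several rounds so the JIT can warm up
            long t0 = System.nanoTime();
            long a = withCount(peter, colin);
            long t1 = System.nanoTime();
            long b = withMapToLongSum(peter, colin);
            long t2 = System.nanoTime();
            System.out.printf("count(): %d in %d ms, mapToLong(e -> 1L).sum(): %d in %d ms%n",
                    a, (t1 - t0) / 1_000_000, b, (t2 - t1) / 1_000_000);
        }
    }
}
```

Both methods always return the same number. Only the terminal counting step differs.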

2. Why the naive loop runs slowly

When you put all your code in the main method, it cannot be JIT-compiled efficiently. The method is executed only once: it starts running in the interpreter, and later, when the JVM detects a hot loop, it switches from interpreted to compiled code on the fly. This is called on-stack replacement (OSR).
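OSR can be observed directly with HotSpot's `-XX:+PrintCompilation` diagnostic output, where OSR compilations are marked with a `%` character. A minimal sketch: the hot loop lives in a method invoked exactly once, as in the question's `main`, so HotSpot can only compile it via OSR. The class name, method name and loop body are hypothetical.

```java
public class OsrDemo {
    // Invoked exactly once, so its hot loop can only be compiled via on-stack replacement.
    // Run with:  java -XX:+PrintCompilation OsrDemo
    // and look for a line containing '%' and OsrDemo::countInOneCall.
    static long countInOneCall(int n) {
        long count = 0;
        for (int i = 0; i < n; i++) {
            if ((i & 7) > 2) { // true for 5 out of every 8 consecutive values of i
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countInOneCall(200_000_000)); // prints 125000000
    }
}
```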

OSR compilations are often not as well optimized as regular method compilations. I've explained this in detail earlier; see this answer and this one.

The JIT will produce better code if you put the inner loop in a separate method:

    long solutions = 0;
    for (Integer p : peter) {
        solutions += countLargerThan(colin, p);
    }

    ...

    private static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

In this case, the countLargerThan method will be compiled normally, and the performance will be better than with streams on both JDK 8 and JDK 9.
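A self-contained version of the refactoring above, with the elided surrounding code filled in by a hypothetical `solutions` wrapper. The outer loop still runs in a method entered once. The hot inner loop, however, is now an ordinary method called once per Peter total (262144 times with the real data), which the JIT can compile in the normal way.

```java
import java.util.List;

public class SeparateMethod {
    // Hypothetical wrapper around the answer's outer loop.
    static long solutions(List<Integer> peter, List<Integer> colin) {
        long solutions = 0;
        for (Integer p : peter) {
            solutions += countLargerThan(colin, p); // p is unboxed to int here
        }
        return solutions;
    }

    // The answer's inner loop: counts Colin totals strictly smaller than p.
    private static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Hypothetical small inputs: 5 beats {3}; 10 beats {3, 7}.
        System.out.println(solutions(List.of(5, 10), List.of(3, 7, 12))); // prints 3
    }
}
```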
