Рубрики
Без рубрики

Умножение матриц в Java

Узнайте, как выполнять умножение матриц в Java с использованием различных реализаций.

Автор оригинала: François Dupire.

1. Обзор

В этом уроке мы рассмотрим, как мы можем умножить две матрицы в Java.

Поскольку концепция матрицы изначально не существует в языке, мы реализуем ее сами, а также будем работать с несколькими библиотеками, чтобы увидеть, как они обрабатывают умножение матриц.

В конце концов, мы проведем небольшой сравнительный анализ различных решений, которые мы исследовали, чтобы определить самое быстрое из них.

2. Пример

Давайте начнем с примера, на который мы сможем ссылаться в этом уроке.

Во-первых, мы представим матрицу 3×2:

Давайте теперь представим себе вторую матрицу, на этот раз две строки по четыре столбца:

Затем умножение первой матрицы на вторую матрицу, что приведет к матрице 3×4:

Напомним, что этот результат получается путем вычисления каждой ячейки результирующей матрицы по этой формуле :

Где r – количество строк матрицы A , c – количество столбцов матрицы B и n – количество столбцов матрицы A , которое должно соответствовать количеству строк матрицы B .

3. Умножение матриц

3.1. Собственная Реализация

Давайте начнем с вашей собственной реализации матриц.

Мы будем держать его простым и просто использовать двумерные двойные массивы :

double[][] firstMatrix = {
  new double[]{1d, 5d},
  new double[]{2d, 3d},
  new double[]{1d, 7d}
};

double[][] secondMatrix = {
  new double[]{1d, 2d, 3d, 7d},
  new double[]{5d, 2d, 8d, 1d}
};

Это две матрицы нашего примера. Давайте создадим тот, который ожидается в результате их умножения:

double[][] expected = {
  new double[]{26d, 12d, 43d, 12d},
  new double[]{17d, 10d, 30d, 17d},
  new double[]{36d, 16d, 59d, 14d}
};

Теперь, когда все настроено, давайте реализуем алгоритм умножения. Сначала мы создадим пустой массив результатов и переберем его ячейки, чтобы сохранить ожидаемое значение в каждой из них:

double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
    double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

    for (int row = 0; row < result.length; row++) {
        for (int col = 0; col < result[row].length; col++) {
            result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
        }
    }

    return result;
}

Наконец, давайте реализуем вычисление одной ячейки. Для этого мы будем использовать формулу, показанную ранее в презентации примера :

double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
    double cell = 0;
    for (int i = 0; i < secondMatrix.length; i++) {
        cell += firstMatrix[row][i] * secondMatrix[i][col];
    }
    return cell;
}

Наконец, давайте проверим, что результат алгоритма соответствует нашему ожидаемому результату:

double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

3.2. EJML

Первая библиотека, которую мы рассмотрим, – это EJML, который расшифровывается как Эффективная матричная библиотека Java . На момент написания этого учебника это одна из самых последних обновленных библиотек матриц Java . Его цель состоит в том, чтобы быть как можно более эффективным в отношении вычислений и использования памяти.

Нам придется добавить зависимость в библиотеку в вашем pom.xml :


    org.ejml
    ejml-all
    0.38

Мы будем использовать практически тот же шаблон, что и раньше: создадим две матрицы в соответствии с нашим примером и проверим, что результат их умножения совпадает с тем, который мы рассчитали ранее.

Итак, давайте создадим наши матрицы с помощью EJML. Для достижения этой цели мы будем использовать класс Simple Matrix , предлагаемый библиотекой .

Он может принимать двухмерный двойной массив в качестве входных данных для своего конструктора:

SimpleMatrix firstMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 5d},
    new double[] {2d, 3d},
    new double[] {1d ,7d}
  }
);

SimpleMatrix secondMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 2d, 3d, 7d},
    new double[] {5d, 2d, 8d, 1d}
  }
);

А теперь давайте определим нашу ожидаемую матрицу для умножения:

SimpleMatrix expected = new SimpleMatrix(
  new double[][] {
    new double[] {26d, 12d, 43d, 12d},
    new double[] {17d, 10d, 30d, 17d},
    new double[] {36d, 16d, 59d, 14d}
  }
);

Теперь, когда все готово, давайте посмотрим, как умножить две матрицы вместе. Класс Simple Matrix предлагает multi() метод принимая другую Simple Matrix в качестве параметра и возвращая умножение двух матриц:

SimpleMatrix actual = firstMatrix.mult(secondMatrix);

Давайте проверим, соответствует ли полученный результат ожидаемому.

Поскольку Простая матрица не переопределяет метод equals () , мы не можем полагаться на него для проверки. Но, он предлагает альтернативу: isIdentical() метод , который принимает не только другой параметр матрицы, но и double отказоустойчивость, чтобы игнорировать небольшие различия из-за двойной точности:

assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

Это завершает умножение матриц с помощью библиотеки EJML. Давайте посмотрим, что предлагают другие.

3.3. ND4J

Теперь давайте попробуем библиотеку ND4J . ND4J является вычислительной библиотекой и является частью проекта deeplearning4j . Помимо прочего, ND4J предлагает функции матричного вычисления.

Прежде всего, мы должны получить зависимость от библиотеки :


    org.nd4j
    nd4j-native
    1.0.0-beta4

Обратите внимание, что мы используем бета-версию здесь, потому что, похоже, в выпуске GA есть некоторые ошибки.

Для краткости мы не будем переписывать два массива измерений double и просто сосредоточимся на том, как они используются с каждой библиотекой. Таким образом, с помощью ND4J мы должны создать INDArray . Для этого мы вызовем метод Nd4j.create() factory и передадим ему массив double , представляющий нашу матрицу :

INDArray matrix = Nd4j.create(/* a two dimensions double array */);

Как и в предыдущем разделе, мы создадим три матрицы: две, которые мы собираемся умножить вместе, и одну, которая будет ожидаемым результатом.

После этого мы хотим фактически выполнить умножение между первыми двумя матрицами, используя метод INDArray.mmol() :

INDArray actual = firstMatrix.mmul(secondMatrix);

Затем мы снова проверяем, соответствует ли фактический результат ожидаемому. На этот раз мы можем положиться на проверку равенства:

assertThat(actual).isEqualTo(expected);

Это демонстрирует, как библиотека ND4J может использоваться для выполнения матричных вычислений.

3.4. Apache Commons

Давайте теперь поговорим о модуле Apache Commons Math3 , который предоставляет нам математические вычисления, включая манипуляции с матрицами.

Опять же, нам придется указать зависимость в вашем pom.xml :


    org.apache.commons
    commons-math3
    3.6.1

После настройки мы можем использовать интерфейс Real Matrix и его Array2DRowRealMatrix реализацию для создания наших обычных матриц. Конструктор класса реализации принимает в качестве параметра двумерный массив double :

RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */);

Что касается умножения матриц, интерфейс Real Matrix предлагает multiply() метод взятие другого Real Matrix параметра:

RealMatrix actual = firstMatrix.multiply(secondMatrix);

Мы можем, наконец, убедиться, что результат равен тому, что мы ожидаем:

assertThat(actual).isEqualTo(expected);

Давайте посмотрим следующую библиотеку!

3.5. LA4J

Этот называется LA4J, что означает Линейная алгебра для Java .

Давайте добавим зависимость и для этого:


    org.la4j
    la4j
    0.6.0

Теперь LA4J работает почти так же, как и другие библиотеки. Он предлагает Матрицу интерфейс с Basic2DMatrix реализацией , которая принимает двумерный двойной массив в качестве входных данных:

Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */);

Как и в модуле Apache Commons Math3, метод умножения multiply() и принимает другую матрицу в качестве параметра:

Matrix actual = firstMatrix.multiply(secondMatrix);

Еще раз, мы можем проверить, соответствует ли результат нашим ожиданиям:

assertThat(actual).isEqualTo(expected);

Теперь давайте взглянем на нашу последнюю библиотеку: Кольт.

3.6. Кольт

Colt – это библиотека, разработанная ЦЕРНОМ. Он обеспечивает функции, обеспечивающие высокую производительность научных и технических вычислений.

Как и в предыдущих библиотеках, мы должны получить правильную зависимость :


    colt
    colt
    1.2.0

Чтобы создать матрицы с помощью Colt, мы должны использовать класс DoubleFactory2D . Он поставляется с тремя заводскими экземплярами: плотный, разреженный и сжатый ряд . Каждый из них оптимизирован для создания соответствующей матрицы.

Для нашей цели мы будем использовать экземпляр dense . На этот раз вызывается метод make() и он снова принимает двумерный двойной массив , создавая объект DoubleMatrix2D :

DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */);

Как только наши матрицы будут созданы, мы захотим их умножить. На этот раз в объекте матрицы нет метода для этого. Мы должны создать экземпляр класса Algebra , который имеет multi() метод , принимающий две матрицы для параметров:

Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);

Затем мы можем сравнить фактический результат с ожидаемым:

assertThat(actual).isEqualTo(expected);

4. Бенчмаркинг

Теперь, когда мы закончили с изучением различных возможностей умножения матриц, давайте проверим, какие из них наиболее эффективны.

4.1. Малые матрицы

Давайте начнем с небольших матриц. Здесь матрицы 3×2 и 2×4.

Для реализации теста производительности мы будем использовать библиотеку бенчмаркинга JMH . Давайте настроим класс бенчмаркинга со следующими параметрами:

public static void main(String[] args) throws Exception {
    Options opt = new OptionsBuilder()
      .include(MatrixMultiplicationBenchmarking.class.getSimpleName())
      .mode(Mode.AverageTime)
      .forks(2)
      .warmupIterations(5)
      .measurementIterations(10)
      .timeUnit(TimeUnit.MICROSECONDS)
      .build();

    new Runner(opt).run();
}

Таким образом, JMH выполнит два полных запуска для каждого метода , аннотированного @Benchmark , каждый с пятью итерациями прогрева (не учитываемыми в среднем вычислении) и десятью итерациями измерения. Что касается измерений, то он соберет среднее время выполнения различных библиотек в микросекундах.

Затем мы должны создать объект состояния, содержащий наши массивы:

@State(Scope.Benchmark)
public class MatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public MatrixProvider() {
        firstMatrix =
          new double[][] {
            new double[] {1d, 5d},
            new double[] {2d, 3d},
            new double[] {1d ,7d}
          };

        secondMatrix =
          new double[][] {
            new double[] {1d, 2d, 3d, 7d},
            new double[] {5d, 2d, 8d, 1d}
          };
    }
}

Таким образом, мы удостоверяемся, что инициализация массивов не является частью бенчмаркинга. После этого нам все равно придется создавать методы, которые выполняют умножение матриц, используя в качестве источника данных объект Matrix Provider . Мы не будем повторять код здесь, как мы видели каждую библиотеку ранее.

Наконец, мы запустим процесс бенчмаркинга, используя наш метод main . Это дает нам следующий результат:

Benchmark                                                           Mode  Cnt   Score   Error  Units
MatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20   1,008 ± 0,032  us/op
MatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20   0,219 ± 0,014  us/op
MatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   0,226 ± 0,013  us/op
MatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20   0,389 ± 0,045  us/op
MatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   0,427 ± 0,016  us/op
MatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20  12,670 ± 2,582  us/op

Как мы видим, EJML и Colt работают очень хорошо примерно с пятой долей микросекунды на операцию, где ND4j менее эффективен с чуть более чем десятью микросекундами на операцию . В других библиотеках есть спектакли, расположенные между ними.

Кроме того, стоит отметить, что при увеличении числа итераций прогрева с 5 до 10 производительность увеличивается для всех библиотек.

4.2. Большие матрицы

Теперь, что произойдет, если мы возьмем большие матрицы, например 3000×3000? Чтобы проверить, что происходит, давайте сначала создадим другой класс состояний, предоставляющий сгенерированные матрицы такого размера:

@State(Scope.Benchmark)
public class BigMatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public BigMatrixProvider() {}

    @Setup
    public void setup(BenchmarkParams parameters) {
        firstMatrix = createMatrix();
        secondMatrix = createMatrix();
    }

    private double[][] createMatrix() {
        Random random = new Random();

        double[][] result = new double[3000][3000];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < result[row].length; col++) {
                result[row][col] = random.nextDouble();
            }
        }
        return result;
    }
}

Как мы видим, мы создадим двумерные двойные массивы размером 3000×3000, заполненные случайными вещественными числами.

Теперь давайте создадим класс бенчмаркинга:

public class BigMatrixMultiplicationBenchmarking {
    public static void main(String[] args) throws Exception {
        Map parameters = parseParameters(args);

        ChainedOptionsBuilder builder = new OptionsBuilder()
          .include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
          .mode(Mode.AverageTime)
          .forks(2)
          .warmupIterations(10)
          .measurementIterations(10)
          .timeUnit(TimeUnit.SECONDS);

        new Runner(builder.build()).run();
    }

    @Benchmark
    public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
        return HomemadeMatrix
          .multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
    }

    @Benchmark
    public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
        SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
        SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.mult(secondMatrix);
    }

    @Benchmark
    public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
        RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
        RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
        Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
        INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());

        return firstMatrix.mmul(secondMatrix);
    }

    @Benchmark
    public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
        DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;

        DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
        DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());

        Algebra algebra = new Algebra();
        return algebra.mult(firstMatrix, secondMatrix);
    }
}

Когда мы проводим этот бенчмаркинг, мы получаем совершенно другие результаты:

Benchmark                                                              Mode  Cnt    Score    Error  Units
BigMatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20  511.140 ± 13.535   s/op
BigMatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20  197.914 ±  2.453   s/op
BigMatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   25.830 ±  0.059   s/op
BigMatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20  497.493 ±  2.121   s/op
BigMatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   35.523 ±  0.102   s/op
BigMatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20    0.548 ±  0.006   s/op

Как мы видим, самодельные реализации и библиотека Apache теперь намного хуже, чем раньше, и для выполнения умножения двух матриц требуется почти 10 минут.

Кольт занимает чуть больше 3 минут, что лучше, но все равно очень долго. EJML и LA4J работают довольно хорошо, поскольку они работают почти за 30 секунд. Но это ND4J, который выигрывает этот бенчмаркинг менее чем за секунду на бэкэнде процессора .

4.3. Анализ

Это показывает нам, что результаты бенчмаркинга действительно зависят от характеристик матриц, и поэтому сложно указать одного победителя.

5. Заключение

В этой статье мы узнали, как умножать матрицы в Java, либо самостоятельно, либо с помощью внешних библиотек. Изучив все решения, мы провели сравнительный анализ всех из них и увидели, что, за исключением ND4J, все они довольно хорошо работают на небольших матрицах. С другой стороны, на больших матрицах лидирует ND4J.

Как обычно, полный код этой статьи можно найти на GitHub .