An introduction to Apple's deep learning framework `MLX`, with a build walkthrough

Introduction to MLX

On December 6, 2023 (Beijing time), Apple machine learning research open-sourced MLX on GitHub. The project is hosted at https://github.com/ml-explore/mlx.

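MLX is a deep learning framework that Apple has optimized specifically for Apple silicon. It is billed as simplifying the way researchers design and deploy models on Mac, iPad, and iPhone.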

Some of the key features of MLX are:

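Familiar APIs: MLX has a Python API that closely follows NumPy. MLX also has a fully featured C++ API that closely mirrors the Python API. Higher-level packages such as mlx.nn and mlx.optimizers have APIs that closely follow PyTorch, to simplify building more complex models.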

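Composable function transformations: MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.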

Lazy computation: Computations in MLX are lazy. Arrays are only materialized when they are needed.

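Dynamic graph construction: Computation graphs in MLX are built dynamically. Changing the shapes of function arguments does not trigger slow compilations, and debugging is simple and intuitive.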

Multi-device: Operations can run on any of the supported devices (currently the CPU and the GPU).

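Unified memory: A notable difference between MLX and other frameworks is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without transferring data.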
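To make the lazy computation item more concrete, here is a toy sketch in plain Python. It is not MLX's API: the `Lazy` class and its `eval` method are hypothetical names invented for illustration. It only shows the general idea that building an expression records work, and the work runs when a result is requested.

```python
class Lazy:
    """Toy deferred value (hypothetical, not part of MLX)."""

    def __init__(self, compute):
        self._compute = compute  # zero-argument callable that produces the value
        self._value = None
        self.evaluated = False

    def __add__(self, other):
        # Record the addition as a new node; nothing is computed here.
        return Lazy(lambda: [a + b for a, b in zip(self.eval(), other.eval())])

    def eval(self):
        # Compute on first request, then reuse the cached result.
        if not self.evaluated:
            self._value = self._compute()
            self.evaluated = True
        return self._value


a = Lazy(lambda: [1.0, 2.0, 3.0])
b = Lazy(lambda: [10.0, 20.0, 30.0])
c = a + b                # builds a node only
print(c.evaluated)       # False
print(c.eval())          # [11.0, 22.0, 33.0]
print(c.evaluated)       # True
```

A real framework tracks a full computation graph and schedules it on a device. The toy keeps only the deferral itself.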

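MLX is designed by machine learning researchers at Apple machine learning research, for machine learning researchers. The framework is intended to be user-friendly while still being efficient for training and deploying models. Its design is also conceptually simple, so that researchers can easily extend and improve MLX and quickly explore new ideas.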

MLX should not be dismissed as reinventing the wheel. Having shipped its own GPUs, Apple naturally wants to tap their computing potential.

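Apple's chip architecture differs from earlier mainstream architectures in that it uses a unified memory model. Frameworks such as TensorFlow usually run on hardware with separate CPU and GPU memory, where data has to be moved between the two. With unified memory that step goes away, which greatly simplifies the programming model.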

Building the MLX examples

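Building MLX requires Xcode, so it is not well suited to building inside a container and should be built directly on macOS. To keep the various source trees downloaded to the machine organized, ghq is used to fetch the repository.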

Download the source code

ghq get --shallow https://github.com/ml-explore/mlx

Browse the project structure in Finder.

open -R "$(ghq list --full-path https://github.com/ml-explore/mlx)"

Browse the project structure on the command line.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && ls -al

The contents are as follows.

total 168

drwxr-xr-x 26 huzhenghui staff 832 12 16 11:41 .

drwxr-xr-x 4 huzhenghui staff 128 12 16 08:09 ..

-rw-r--r--@ 1 huzhenghui staff 6148 12 16 08:25 .DS_Store

drwxr-xr-x 3 huzhenghui staff 96 12 16 08:00 .circleci

-rw-r--r-- 1 huzhenghui staff 2552 12 16 08:00 .clang-format

drwxr-xr-x 13 huzhenghui staff 416 12 16 08:00 .git

drwxr-xr-x 4 huzhenghui staff 128 12 16 08:00 .github

-rw-r--r-- 1 huzhenghui staff 733 12 16 08:00 .gitignore

-rw-r--r-- 1 huzhenghui staff 433 12 16 08:00 .pre-commit-config.yaml

-rw-r--r-- 1 huzhenghui staff 12320 12 16 08:00 ACKNOWLEDGMENTS.md

-rw-r--r-- 1 huzhenghui staff 6533 12 16 08:00 CMakeLists.txt

-rw-r--r-- 1 huzhenghui staff 5544 12 16 08:00 CODE_OF_CONDUCT.md

-rw-r--r-- 1 huzhenghui staff 1292 12 16 08:00 CONTRIBUTING.md

-rw-r--r-- 1 huzhenghui staff 1066 12 16 08:00 LICENSE

-rw-r--r-- 1 huzhenghui staff 69 12 16 08:00 MANIFEST.in

-rw-r--r-- 1 huzhenghui staff 3523 12 16 08:00 README.md

drwxr-xr-x 5 huzhenghui staff 160 12 16 08:00 benchmarks

drwxr-xr-x 3 huzhenghui staff 96 12 16 08:00 cmake

drwxr-xr-x 9 huzhenghui staff 288 12 16 08:00 docs

drwxr-xr-x 5 huzhenghui staff 160 12 16 08:00 examples

drwxr-xr-x 35 huzhenghui staff 1120 12 16 08:00 mlx

-rw-r--r-- 1 huzhenghui staff 1364 12 16 08:00 mlx.pc.in

-rw-r--r-- 1 huzhenghui staff 118 12 16 08:00 pyproject.toml

drwxr-xr-x 6 huzhenghui staff 192 12 16 08:00 python

-rw-r--r-- 1 huzhenghui staff 6887 12 16 08:00 setup.py

drwxr-xr-x 21 huzhenghui staff 672 12 16 08:00 tests

Create a ./build directory for the build.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && mkdir -p build

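Run CMake. Whether the examples are built is controlled by the MLX_BUILD_EXAMPLES option. CMake options are set with -D on the command line. An exported environment variable is not read unless the project's CMakeLists.txt asks for it explicitly, so the option is passed directly here.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && cmake .. -DMLX_BUILD_EXAMPLES=ON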

The output is as follows.

-- The CXX compiler identification is AppleClang 15.0.0.15000100

-- Detecting CXX compiler ABI info

-- Detecting CXX compiler ABI info - done

-- Check for working CXX compiler: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++ - skipped

-- Detecting CXX compile features

-- Detecting CXX compile features - done

-- Building MLX for arm64 processor on Darwin

-- Building METAL sources

-- Building with SDK for macOS version 14.2

-- Accelerate found /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.2.sdk/System/Library/Frameworks/Accelerate.framework

CMake Deprecation Warning at build/_deps/doctest-src/CMakeLists.txt:1 (cmake_minimum_required):

Compatibility with CMake < 3.5 will be removed from a future version of

CMake.

Update the VERSION argument value or use a ... suffix to tell

CMake that the project does not need compatibility with older versions.

-- Configuring done (24.2s)

-- Generating done (0.1s)

-- Build files have been written to: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build

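Run make to compile, using --jobs to build in parallel for speed. Note that --jobs without a number places no limit on the number of simultaneous jobs. Pass a count (for example, --jobs 8) to cap the load on the machine.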

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make --jobs

The result is as follows.

[ 1%] Building unary.air

[ 3%] Building reduce.air

[ 4%] Building indexing.air

[ 5%] Building arange.air

[ 6%] Building sort.air

[ 6%] Building softmax.air

[ 9%] Building copy.air

[ 9%] Building random.air

[ 10%] Building scan.air

[ 11%] Building gemv.air

[ 12%] Building conv.air

[ 13%] Building gemm.air

[ 15%] Building binary.air

[ 16%] Building arg_reduce.air

[ 17%] Building mlx.metallib

[ 17%] Built target mlx-metallib

[ 18%] Building CXX object CMakeFiles/mlx.dir/mlx/device.cpp.o

[ 19%] Building CXX object CMakeFiles/mlx.dir/mlx/allocator.cpp.o

[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/scheduler.cpp.o

[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/graph_utils.cpp.o

[ 23%] Building CXX object CMakeFiles/mlx.dir/mlx/transforms.cpp.o

[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/primitives.cpp.o

[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/fft.cpp.o

[ 26%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/reduce.cpp.o

[ 27%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/fft.cpp.o

[ 29%] Building CXX object CMakeFiles/mlx.dir/mlx/load.cpp.o

[ 30%] Building CXX object CMakeFiles/mlx.dir/mlx/array.cpp.o

[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/conv.cpp.o

[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/indexing.cpp.o

[ 33%] Building CXX object CMakeFiles/mlx.dir/mlx/dtype.cpp.o

[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/ops.cpp.o

[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/binary.cpp.o

[ 37%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/scan.cpp.o

[ 38%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/erf.cpp.o

[ 39%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/conv.cpp.o

[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/primitives.cpp.o

[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/copy.cpp.o

[ 43%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/metal.cpp.o

[ 44%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/reduce.cpp.o

[ 45%] Building CXX object CMakeFiles/mlx.dir/mlx/utils.cpp.o

[ 46%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/softmax.cpp.o

[ 47%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/copy.cpp.o

[ 48%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/arg_reduce.cpp.o

[ 50%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/primitives.cpp.o

[ 51%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/primitives.cpp.o

[ 52%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/load.cpp.o

[ 53%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/conv.cpp.o

[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/matmul.cpp.o

[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/device.cpp.o

[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/allocator.cpp.o

[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/fft.cpp.o

[ 59%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/threefry.cpp.o

[ 60%] Building CXX object CMakeFiles/mlx.dir/mlx/random.cpp.o

[ 61%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/sort.cpp.o

[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/softmax.cpp.o

[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/scan.cpp.o

[ 65%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/matmul.cpp.o

[ 66%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/indexing.cpp.o

[ 67%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/reduce.cpp.o

[ 68%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/sort.cpp.o

[ 69%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/softmax.cpp.o

[ 70%] Linking CXX static library libmlx.a

[ 70%] Built target mlx

[ 74%] Building CXX object examples/cpp/CMakeFiles/linear_regression.dir/linear_regression.cpp.o

[ 74%] Building CXX object examples/cpp/CMakeFiles/tutorial.dir/tutorial.cpp.o

[ 74%] Building CXX object tests/CMakeFiles/tests.dir/tests.cpp.o

[ 76%] Building CXX object tests/CMakeFiles/tests.dir/array_tests.cpp.o

[ 76%] Building CXX object tests/CMakeFiles/tests.dir/allocator_tests.cpp.o

[ 77%] Building CXX object examples/cpp/CMakeFiles/logistic_regression.dir/logistic_regression.cpp.o

[ 79%] Building CXX object tests/CMakeFiles/tests.dir/autograd_tests.cpp.o

[ 80%] Building CXX object tests/CMakeFiles/tests.dir/arg_reduce_tests.cpp.o

[ 81%] Building CXX object tests/CMakeFiles/tests.dir/blas_tests.cpp.o

[ 82%] Building CXX object tests/CMakeFiles/tests.dir/eval_tests.cpp.o

[ 83%] Building CXX object tests/CMakeFiles/tests.dir/graph_optimize_tests.cpp.o

[ 84%] Building CXX object tests/CMakeFiles/tests.dir/creations_tests.cpp.o

[ 86%] Building CXX object tests/CMakeFiles/tests.dir/device_tests.cpp.o

[ 87%] Building CXX object tests/CMakeFiles/tests.dir/fft_tests.cpp.o

[ 88%] Building CXX object tests/CMakeFiles/tests.dir/load_tests.cpp.o

[ 89%] Building CXX object tests/CMakeFiles/tests.dir/ops_tests.cpp.o

[ 90%] Building CXX object tests/CMakeFiles/tests.dir/metal_tests.cpp.o

[ 91%] Building CXX object tests/CMakeFiles/tests.dir/vmap_tests.cpp.o

[ 93%] Building CXX object tests/CMakeFiles/tests.dir/scheduler_tests.cpp.o

[ 94%] Building CXX object tests/CMakeFiles/tests.dir/utils_tests.cpp.o

[ 95%] Building CXX object tests/CMakeFiles/tests.dir/random_tests.cpp.o

[ 96%] Linking CXX executable linear_regression

[ 97%] Linking CXX executable tutorial

[ 98%] Linking CXX executable logistic_regression

[ 98%] Built target linear_regression

[ 98%] Built target logistic_regression

[ 98%] Built target tutorial

[100%] Linking CXX executable tests

[100%] Built target tests

The output shows that three examples were built.

[ 96%] Linking CXX executable linear_regression

[ 97%] Linking CXX executable tutorial

[ 98%] Linking CXX executable logistic_regression

[ 98%] Built target linear_regression

[ 98%] Built target logistic_regression

[ 98%] Built target tutorial

If these lines do not appear, the MLX_BUILD_EXAMPLES option was not enabled correctly.

Run the tests.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make test

The test run reports a failure.

Running tests...

Test project /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build

Start 1: tests

1/1 Test #1: tests ............................***Failed 1.06 sec

0% tests passed, 1 tests failed out of 1

Total Test time (real) = 1.06 sec

The following tests FAILED:

1 - tests (Failed)

Errors while running CTest

Output from these tests are in: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log

Use "--rerun-failed --output-on-failure" to re-run the failed cases verbosely.

make: *** [test] Error 8

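There is no need to worry. CTest runs the whole doctest binary as a single test, so a handful of failed assertions is enough for it to report 0% passed. Check the log file for the details.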

cat /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log

The test log reads as follows.

Start testing: Dec 16 11:49 CST

----------------------------------------------------------

1/1 Testing: tests

1/1 Test: tests

Command: "/Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests/tests"

Directory: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests

"tests" start time: Dec 16 11:49 CST

Output:

----------------------------------------------------------

[doctest] doctest version is "2.4.9"

[doctest] run with "--help" for options

===============================================================================

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:221:

TEST CASE: test grad

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:246: ERROR: CHECK_EQ( dfdx(x).item(), std::exp(1.0f) ) is NOT correct!

values: CHECK_EQ( 2.71828, 2.71828 )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:248: ERROR: CHECK_EQ( d2fdx2(x).item(), std::exp(1.0f) ) is NOT correct!

values: CHECK_EQ( 2.71828, 2.71828 )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:250: ERROR: CHECK_EQ( d3fdx3(x).item(), std::exp(1.0f) ) is NOT correct!

values: CHECK_EQ( 2.71828, 2.71828 )

===============================================================================

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:358:

TEST CASE: test op vjps

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:402: ERROR: CHECK_EQ( out.second.item(), 2.0f * std::exp(1.0f) ) is NOT correct!

values: CHECK_EQ( 5.43656, 5.43656 )

===============================================================================

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:467:

TEST CASE: test reduction ops

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:680: WARNING: WARN_EQ( logsumexp(x).item(), -inf ) is NOT correct!

values: WARN_EQ( nan, -inf )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:686: WARNING: WARN_EQ( logsumexp(x).item(), inf ) is NOT correct!

values: WARN_EQ( nan, inf )

===============================================================================

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:741:

TEST CASE: test arithmetic unary ops

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:858: ERROR: CHECK( array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item() ) is NOT correct!

values: CHECK( false )

===============================================================================

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1929:

TEST CASE: test power

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1945: ERROR: CHECK_EQ( (x ^ 0.5).item(), std::pow(2.0f, 0.5f) ) is NOT correct!

values: CHECK_EQ( 1.41421, 1.41421 )

===============================================================================

[doctest] test cases: 113 | 109 passed | 4 failed | 0 skipped

[doctest] assertions: 1962 | 1956 passed | 6 failed |

[doctest] Status: FAILURE!

Test time = 1.06 sec

----------------------------------------------------------

Test Failed.

"tests" end time: Dec 16 11:49 CST

"tests" time elapsed: 00:00:01

----------------------------------------------------------

End testing: Dec 16 11:49 CST
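Several of the failures above report lines such as CHECK_EQ( 2.71828, 2.71828 ), where both sides look identical. A likely explanation, not confirmed here, is that the check compares floats for exact equality while the log prints only about six significant digits, so two values that differ in the last bit look the same. The sketch below reproduces that effect with standard-library Python. The helper names `to_float32` and `next_float32` are made up for this illustration.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def next_float32(x):
    """Return the next representable float32 above a positive float32 x."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0]


a = to_float32(math.e)
b = next_float32(a)  # differs from a only in the last bit

print(format(a, "g"), format(b, "g"))     # both print as 2.71828
print(a == b)                             # False: exact comparison fails
print(math.isclose(a, b, rel_tol=1e-6))   # True: tolerance-based comparison
```

If this is indeed the cause, a tolerance-based comparison would be the usual remedy for such checks.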

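A high proportion of the unit tests pass: 109 of 113 test cases (about 96.5%) and 1956 of 1962 assertions (about 99.7%).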

[doctest] test cases: 113 | 109 passed | 4 failed | 0 skipped

[doctest] assertions: 1962 | 1956 passed | 6 failed |
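For repeated runs, the pass rates can be pulled out of the summary lines automatically. The following is a minimal sketch that assumes the summary keeps the format shown above. `pass_rates` is a hypothetical helper, not part of doctest or CTest.

```python
import re

# Summary lines copied from the test log above.
SUMMARY = """\
[doctest] test cases: 113 | 109 passed | 4 failed | 0 skipped
[doctest] assertions: 1962 | 1956 passed | 6 failed |
"""

# Captures the kind of count, the total, and the passed and failed numbers.
PATTERN = re.compile(
    r"\[doctest\] (test cases|assertions):\s*(\d+) \|\s*(\d+) passed \|\s*(\d+) failed"
)


def pass_rates(text):
    """Map each summary kind to its passed/total ratio."""
    rates = {}
    for kind, total, passed, _failed in PATTERN.findall(text):
        rates[kind] = int(passed) / int(total)
    return rates


rates = pass_rates(SUMMARY)
for kind, rate in rates.items():
    print(f"{kind}: {rate:.1%}")
```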

Install. Because this writes to directories that need elevated permissions, sudo is required.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && sudo make install

The output is as follows.

[ 17%] Built target mlx-metallib

[ 70%] Built target mlx

[ 93%] Built target tests

[ 95%] Built target tutorial

[ 97%] Built target linear_regression

[100%] Built target logistic_regression

Install the project...

-- Install configuration: ""

List the installed files.

ls -al /usr/local/lib

The MLX-related files are visible.

drwxr-xr-x 5 root wheel 160 12 16 08:13 .

drwxr-xr-x 6 root wheel 192 12 16 08:13 ..

drwxr-xr-x 3 root wheel 96 12 16 08:13 cmake

-rw-r--r-- 1 root wheel 66139056 12 16 08:13 libmlx.a

-rw-r--r-- 1 root wheel 61726901 12 16 08:10 mlx.metallib

Run the linear_regression example.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/linear_regression

The output is as follows.

Loss array(4.64954e-05, dtype=float32), |w - w*| = 0.00363933, Throughput 2685.41 (it/s).
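To show what the Loss and |w - w*| figures in this output mean, here is a minimal plain-Python sketch of linear regression by gradient descent. It is not the MLX example's source. The sizes, learning rate, iteration count, and the absence of noise are all assumptions chosen to keep the sketch small, and the gradient is written out by hand rather than obtained by automatic differentiation.

```python
import math
import random

random.seed(0)  # fixed seed so the run is repeatable

num_features, num_examples = 3, 50    # assumed sizes, kept small
learning_rate, num_iters = 0.1, 500   # assumed hyperparameters

# Ground-truth weights w* and noise-free targets y = X w*.
w_star = [random.gauss(0, 1) for _ in range(num_features)]
X = [[random.gauss(0, 1) for _ in range(num_features)]
     for _ in range(num_examples)]
y = [sum(xi * wi for xi, wi in zip(row, w_star)) for row in X]


def loss_and_grad(w):
    """Return 0.5 * mean squared error and its gradient with respect to w."""
    residuals = [sum(xi * wi for xi, wi in zip(row, w)) - t
                 for row, t in zip(X, y)]
    loss = 0.5 * sum(r * r for r in residuals) / num_examples
    grad = [sum(r * row[j] for r, row in zip(residuals, X)) / num_examples
            for j in range(num_features)]
    return loss, grad


w = [0.0] * num_features
for _ in range(num_iters):
    _, grad = loss_and_grad(w)
    w = [wi - learning_rate * gi for wi, gi in zip(w, grad)]

final_loss, _ = loss_and_grad(w)
error_norm = math.sqrt(sum((wi - si) ** 2 for wi, si in zip(w, w_star)))
print(f"Loss {final_loss:.3e}, |w - w*| = {error_norm:.3e}")
```

Here the loss measures how well the fitted weights reproduce the targets, and |w - w*| is the distance between the learned weights and the ones used to generate the data.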

Run the logistic_regression example.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/logistic_regression

The output is as follows.

Loss array(0.0289869, dtype=float32), Accuracy, array(1, dtype=float32), Throughput 2251.92 (it/s).

Run the tutorial example.

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/tutorial

The output is as follows.

array([[1, 1],

[1, 1]], dtype=float32)
