
ONNX preloaded dlls are incompatible with CUDNN torch version #24266

lorenzomammana opened this issue Apr 1, 2025 · 5 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments


lorenzomammana commented Apr 1, 2025

Describe the issue

Hi,
I'm trying to do something I'm starting to think is not possible: using a different cuDNN version for torch and for onnxruntime.

In the past I was using torch 2.1.2+cu121 and onnxruntime 1.20.0. With this configuration, since torch shipped with cuDNN 8, I had to install cuDNN 9 manually to use onnxruntime.
While doing this I realized that cuDNN 9.1 makes a few models much slower than cuDNN 9.6, so I upgraded to 9.6 and everything worked fine.

Now we have decided to upgrade torch to 2.4.1+cu121, which ships with cuDNN 9.1. That copy gets picked up by onnxruntime, making a few of my models effectively slower.

So I'm trying to use the new preload_dlls API to load cuDNN 9.6 before importing torch, but when I import torch I get:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\__init__.py", line 137, in <module>
    raise err
OSError: [WinError 127] The specified procedure could not be found. Error loading "C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cudnn_cnn64_9.dll" or one of its dependencies.

This is the list of preloaded DLLs before importing torch:

C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_adv64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_engines_precompiled64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cufft64_11.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cublas64_12.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cublasLt64_12.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_ops64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_heuristic64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_engines_runtime_compiled64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_graph64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cudart64_12.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn64_9.dll
C:\Users\orobix\AppData\Local\Programs\Python\Python310\vcruntime140_1.dll
C:\Windows\System32\msvcp140.dll
C:\Windows\System32\msvcp140_1.dll

And after importing torch:

List of loaded DLLs:
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cudnn_adv64_9.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cublasLt64_12.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cublas64_12.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\nvrtc64_120_0.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_adv64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_engines_precompiled64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cufft64_11.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cublas64_12.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cublasLt64_12.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_ops64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_heuristic64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_engines_runtime_compiled64_9.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn_graph64_9.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cudart64_12.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\cudnn64_9.dll
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\cudart64_12.dll
C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6\cudnn64_9.dll
C:\Users\orobix\Desktop\Axon-0.32.0\python\axon-default-ns\torch\lib\caffe2_nvrtc.dll
C:\Users\orobix\AppData\Local\Programs\Python\Python310\vcruntime140.dll
C:\Windows\System32\msvcp140.dll
C:\Users\orobix\AppData\Local\Programs\Python\Python310\vcruntime140_1.dll
C:\Windows\System32\msvcp140_1.dll
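For reference, a minimal sketch of how such a list can be produced, assuming psutil is installed (the thread does not show the actual helper used):

import psutil

def loaded_dlls(substring=""):
    # Print the DLLs currently mapped into this process, optionally filtered by name.
    for m in psutil.Process().memory_maps():
        if m.path.lower().endswith(".dll") and substring.lower() in m.path.lower():
            print(m.path)

loaded_dlls("cudnn")  # e.g. show which cudnn*_9.dll copies are loaded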

Am I correct to assume that cudnn_cnn64_9.dll is causing the error because it is incompatible? Any clue why the other libraries from torch are still loaded correctly?

I've tried both preloading CUDA and cuDNN, and preloading only cuDNN while using torch's CUDA, but the issue is the same.

To reproduce

On Windows with CUDA 12.6 and cuDNN 9.6.0 installed:

import onnxruntime

# CUDA_DIRECTORY and CUDNN_DIRECTORY are placeholders for the local
# CUDA 12.6 and cuDNN 9.6 bin directories.
onnxruntime.preload_dlls(cuda=True, cudnn=False, directory=CUDA_DIRECTORY)
onnxruntime.preload_dlls(cuda=False, cudnn=True, directory=CUDNN_DIRECTORY)
import torch

Urgency

It heavily slows down (roughly 2x) a few of the models we use.

Platform

Windows

OS Version

Windows 10

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.6 on Windows, CUDA 12.1 from torch


tianleiwu commented Apr 1, 2025

See example 2 in https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#compatibility-with-pytorch
That will help you use the same cuDNN DLLs as PyTorch. Try installing PyTorch 2.6.0+cu124 or 2.6.0+cu126; it uses cuDNN 9.7.

If you want to use different versions of cuDNN 9.x for PyTorch and ONNX Runtime, I have no idea whether that can work, since there are potential conflicts.
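For context, a minimal sketch of the import-order idea that example describes, assuming the goal is to reuse the cuDNN DLLs bundled with torch (the model path is a placeholder):

import torch          # torch loads its bundled CUDA/cuDNN DLLs into the process first
import onnxruntime

# With torch's cuDNN already loaded, onnxruntime resolves the same copies.
session = onnxruntime.InferenceSession(
    "model.onnx",  # placeholder model path
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)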

lorenzomammana (Author) commented

Hi @tianleiwu,
I looked at example 2, but maybe I don't understand the goal of preload_dlls.

The documentation states:
"The onnxruntime-gpu package is designed to work seamlessly with PyTorch, provided both are built against the same major version of CUDA and cuDNN"

If you don't preload anything, onnxruntime already falls back to the cuDNN shipped with torch, so in which situation should preload be used if torch is already available?

I understand that I can upgrade torch to the latest version, but that is exactly what I was trying to avoid.


tianleiwu commented Apr 1, 2025

If you do not use preload_dlls(), onnxruntime will try to find the CUDA and cuDNN DLLs based on the PATH setting. It is unlikely that a user adds the PyTorch lib directory to PATH, and on Windows PyTorch always loads from its own lib directory. If onnxruntime is imported before torch on Windows, it is recommended to use preload_dlls() to avoid potential conflicts.
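A sketch of that PATH-based lookup, assuming the standalone cuDNN 9.6 install used earlier in this thread:

import os

cudnn_bin = r"C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6"
# Put the desired cuDNN bin directory at the front of PATH before onnxruntime
# searches for the DLLs.
os.environ["PATH"] = cudnn_bin + os.pathsep + os.environ["PATH"]
import onnxruntime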

The potential conflicts between different versions of cuDNN 9.x:
https://docs.nvidia.com/deeplearning/cudnn/backend/v9.8.0/api/overview.html
[Image: cuDNN sub-library dependency diagram from the linked documentation]
We can see that the sub-DLLs have linker dependencies on each other. For some reason, there are conflicts when mixing different versions of the sub-DLLs.

preload_dlls() is a Python function; its source starts here:

def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, directory=None):

@lorenzomammana, you can try modifying the function and see whether you can make it work for your purpose.
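As a rough illustration of that suggestion (not the project's recommended approach), the extra sub-DLL could also be preloaded manually with ctypes before importing torch, using the paths from earlier in this thread:

import ctypes, os
import onnxruntime

CUDNN_BIN = r"C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6"
# Map the 9.6 copy of the sub-DLL that later fails to load from torch\lib.
ctypes.WinDLL(os.path.join(CUDNN_BIN, "cudnn_cnn64_9.dll"))
onnxruntime.preload_dlls(cuda=False, cudnn=True, directory=CUDNN_BIN)
import torch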

satyajandhyala added the ep:CUDA label on Apr 2, 2025
lorenzomammana (Author) commented

Hi @tianleiwu, I've done a few more experiments.

First, as you suggested, I tried changing the preload_dlls function to also preload the library that gives the error, but I can't cheat it this way and the error arises again.

Then I tried preloading different versions of cuDNN without preloading the CUDA DLLs (so falling back to cu121 from torch).

The results are the following:
cuDNN 9.0 -> failure on cudnn_cnn64_9.dll
cuDNN 9.5 -> failure on cudnn_cnn64_9.dll
cuDNN 9.6 -> failure on cudnn_cnn64_9.dll

cuDNN 9.1.1 -> success in loading torch, but then I get:

[Image: screenshot of the runtime error that follows]

Honestly, I would have expected that preloading 9.1.1, which is in the same minor range as my torch version (9.1.0), would have worked 😅
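For what it's worth, which cuDNN build torch itself ends up using can be checked with a standard torch API:

import torch

# Returns the loaded cuDNN version as an integer, e.g. 90100 for cuDNN 9.1.0.
print(torch.backends.cudnn.version())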

tianleiwu (Contributor) commented

Try adding all cudnn_*.dll files to the preload list (currently only cudnn_cnn64_9.dll is added). That might work around the issue.
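A quick sketch of that workaround done manually, instead of editing preload_dlls(), assuming the cuDNN 9.6 location used earlier in this thread:

import ctypes, glob, os

CUDNN_BIN = r"C:\Program Files\NVIDIA\CUDNN\v9.6\bin\12.6"
# Preload every cuDNN sub-DLL from one install so none of them is later
# resolved from torch\lib with a mismatched version.
for dll in sorted(glob.glob(os.path.join(CUDNN_BIN, "cudnn_*.dll"))):
    ctypes.WinDLL(dll)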
