Description
Describe the issue
When compiling an ONNX model to an OM file on multiple NPUs concurrently, we observed a reproducible race condition: the fastest card finishes OM generation first and starts writing the file, while the slower cards detect that the OM file has appeared and immediately try to load it. Because the fast card is still writing the file, the slower cards read an incomplete OM, which causes model-load failures or accuracy issues.
We reproduced this reliably with the script below: card 0 compiles first while all other cards continuously poll the directory, and as soon as the OM file appears the other cards start inference and hit the error.
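For reference, the usual way a writer avoids exposing a partially written file is to write to a temporary file in the same directory and atomically rename it into place once it is complete. The sketch below shows that pattern in Python; `atomic_write_bytes` is a hypothetical helper for illustration only, not part of ONNX Runtime or CANN, and it assumes nothing about how the CANN EP currently writes its OM file.

```python
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Write `data` to `path` so a concurrent reader never sees a partial file.

    The bytes go to a temporary file in the same directory (same filesystem),
    are flushed and fsync'd, and the temp file is then renamed over `path`
    with os.replace. A reader sees either no file or the complete file.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # The ".tmp" suffix keeps the in-progress file out of a "*.om" glob.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the temp file if anything failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

On POSIX filesystems the rename is atomic when source and target are in the same directory, which is why the temp file is created next to the final path rather than in the system temp directory.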
Urgency
No response
Target platform
Linux x86_64
CANN: 8.3.RC1
torch: 2.5.1+cpu
torch_npu: 2.5.1
onnxruntime-cann: 1.24.0 (requires building ONNX Runtime from source)
Build script
inswapper ONNX model: https://huggingface.co/Devia/G/blob/main/inswapper_128.onnx
Python script:
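import os
import time
import glob

import numpy as np
import torch
import torch_npu  # registers the NPU backend with torch
import onnxruntime as ort

RESWAPPER_PATH = "./insightface/inswapper_128.onnx"
OM_DIR = "./"        # directory where the generated OM file appears
POLL_INTERVAL = 0.2  # seconds between directory scans
MAX_WAIT = 300       # stop waiting after this many seconds


def list_om():
    """Return the set of OM files currently present in OM_DIR."""
    return set(glob.glob(os.path.join(OM_DIR, "*.om")))


def main():
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.npu.set_device(local_rank)

    sess = ort.InferenceSession(
        RESWAPPER_PATH,
        providers=[
            (
                "CANNExecutionProvider",
                {
                    "device_id": local_rank,
                    "arena_extend_strategy": "kNextPowerOfTwo",
                    "enable_cann_graph": True,
                    "precision_mode": "must_keep_origin_dtype",
                    "op_select_impl_mode": "high_precision",
                },
            ),
        ],
    )

    target = np.random.randn(1, 3, 128, 128).astype(np.float32)
    source = np.random.randn(1, 512).astype(np.float32)
    # The model has two inputs; index 1 is needed for the second one,
    # otherwise both entries share the same key and `target` is dropped.
    feed = {sess.get_inputs()[0].name: target, sess.get_inputs()[1].name: source}

    init = list_om()
    if local_rank == 0:
        # Rank 0 runs first, compiling the model and writing the OM file.
        sess.run(None, feed)
    else:
        # Other ranks wait for a new OM file to appear and then run right away,
        # racing with rank 0 while it is still writing that file.
        start = time.time()
        while True:
            now = list_om()
            if now - init:
                break
            if time.time() - start > MAX_WAIT:
                break
            time.sleep(POLL_INTERVAL)
        out = sess.run(None, feed)


if __name__ == "__main__":
    main()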
Shell script:
torchrun --nproc_per_node=8 xxx.py
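Until the writer side is fixed, one reader-side mitigation for the polling loop in the script is to wait until the new OM file's size and modification time stop changing before calling `sess.run`. The helper below (`wait_for_stable_file`, a hypothetical name, not an ONNX Runtime or CANN API) is only a heuristic: a writer that stalls for longer than `interval * stable_checks` would still look finished, so it narrows the race rather than closing it.

```python
import os
import time


def wait_for_stable_file(path, interval=0.2, stable_checks=3, timeout=300.0):
    """Return True once `path` is non-empty and its (size, mtime) has not
    changed for `stable_checks` consecutive polls; return False on timeout.

    Heuristic only: a writer that pauses for longer than
    interval * stable_checks is indistinguishable from a finished one.
    """
    deadline = time.monotonic() + timeout
    last = None    # (size, mtime_ns) seen on the previous poll
    unchanged = 0  # consecutive polls with an identical, non-empty stat
    while time.monotonic() < deadline:
        try:
            st = os.stat(path)
            current = (st.st_size, st.st_mtime_ns)
        except FileNotFoundError:
            current = None
        if current is not None and current[0] > 0 and current == last:
            unchanged += 1
            if unchanged >= stable_checks:
                return True
        else:
            unchanged = 0
        last = current
        time.sleep(interval)
    return False
```

In the repro, the non-zero ranks could call it on each path in `now - init` after the polling loop exits and before `sess.run`.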
Error / output
Status Message: CANN error executing aclmdLoadFromFile()
Visual Studio Version
No response
GCC / Compiler Version
No response