About parallel inference on one node with 4 GPUs
Hi, how can I run inference with aifs-single-1.0 on 4 A100 80 GB GPUs on one node? I tried to use ParallelRunner (from anemoi.inference.runners.parallel import ParallelRunner), but it requires different parameters.
ParallelRunner requires information about the communication group to be used. If you are using Slurm, this is filled in automatically. If not, see here:
https://anemoi.readthedocs.io/projects/inference/en/latest/inference/parallel.html
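Under Slurm, srun gives every task its own SLURM_PROCID (global rank), SLURM_LOCALID (rank within the node) and SLURM_NTASKS (world size), which is presumably what "filled in automatically" refers to. The sketch below only shows how those values can be read; the function name rank_info_from_slurm is hypothetical and is not ParallelRunner's actual code.

```python
import os
from collections.abc import Mapping
from typing import NamedTuple


class RankInfo(NamedTuple):
    rank: int        # global rank across all nodes
    local_rank: int  # rank within the node, usually used to pick the GPU
    world_size: int  # total number of tasks


def rank_info_from_slurm(env: Mapping[str, str] = os.environ) -> RankInfo:
    """Read the per-task layout that srun exports to every task.

    Outside Slurm (no SLURM_PROCID), fall back to a single process.
    """
    if "SLURM_PROCID" not in env:
        return RankInfo(0, 0, 1)
    return RankInfo(
        rank=int(env["SLURM_PROCID"]),
        local_rank=int(env.get("SLURM_LOCALID", "0")),
        world_size=int(env.get("SLURM_NTASKS", "1")),
    )
```

Passing the environment as a mapping keeps the helper easy to test without launching srun.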
I am trying to run AIFS Single 1.0 using the guide provided, but that guide only covers the simple runner. I checked the code in anemoi-inference 0.4.9 (file anemoi/inference/runners/parallel.py) and wrote a Python script to run the inference in parallel using Slurm. My system is an HPE/Cray system with a Cray Slingshot interconnect, and every node has two GPUs. I want to run in parallel mode across many nodes. The following is the script I am trying to make:
import argparse
import datetime
import os
import json
import logging
import torch
import numpy as np
import earthkit.data as ekd
import earthkit.regrid as ekr
import pickle
import time
from anemoi.inference.config import Configuration
from anemoi.inference.runners.parallel import ParallelRunner
from anemoi.inference.outputs.printer import print_state
from ecmwf.opendata import Client as OpendataClient

LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]


# --- Data Retrieval Function ---
def get_open_data(date, param, levelist=[]):
    """Fetch ECMWF Open Data and preprocess it."""
    LOG.info(f"Downloading {param} from ECMWF Open Data for {date}")
    fields = {}
    data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
    for f in data:
        values = ekr.interpolate(np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1), {"grid": (0.25, 0.25)}, {"grid": "N320"})
        name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
        LOG.info(f"Downloaded: {name} - Shape: {values.shape} - Path: {f.path}")
        fields[name] = values
    return fields


def parse_args():
    parser = argparse.ArgumentParser(description="Run ECMWF AIFS forecast in parallel")
    parser.add_argument("--start_date", type=str, default=OpendataClient().latest().strftime("%Y-%m-%d"))
    parser.add_argument("--start_time", type=str, default="00:00:00")
    parser.add_argument("--lead_time", type=str, default="10d")
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--input", choices=["opendata", "cds"], default="opendata")
    parser.add_argument("--chunks", type=int, default=8)
    parser.add_argument("--verbosity", type=int, default=1)
    parser.add_argument("--output_frequency", type=str, default=None)
    return parser.parse_args()


def get_rank_info():
    """Detects SLURM environment or falls back to local multi-GPU execution."""
    if "SLURM_PROCID" in os.environ:
        global_rank = int(os.environ.get("SLURM_PROCID", 0))
        local_rank = int(os.environ.get("SLURM_LOCALID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))
        LOG.info(f"Running under SLURM: Rank {global_rank}, Local Rank {local_rank}, World Size {world_size}")
    else:
        # No SLURM detected, use all available GPUs
        world_size = torch.cuda.device_count()
        global_rank = local_rank = 0  # Single-node execution
        LOG.info(f"No SLURM detected. Using {world_size} GPUs locally.")
    return global_rank, local_rank, world_size


def set_device(local_rank):
    """Ensure correct GPU assignment based on SLURM_LOCALID."""
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        LOG.info(f"Assigned GPU {local_rank} to process.")


def barrier_sync():
    """Synchronize processes safely."""
    if torch.distributed.is_initialized():
        try:
            torch.distributed.barrier()
        except RuntimeError as e:
            LOG.error(f"Barrier synchronization failed: {e}")


def broadcast_input_state(input_state, device):
    """Broadcast input state from rank 0 to all other ranks."""
    rank = torch.distributed.get_rank()
    if rank == 0:
        obj_bytes = pickle.dumps(input_state)
        byte_tensor = torch.tensor(list(obj_bytes), dtype=torch.uint8, device=device)
        size_tensor = torch.tensor([byte_tensor.numel()], dtype=torch.long, device=device)
    else:
        size_tensor = torch.tensor([0], dtype=torch.long, device=device)
    torch.distributed.broadcast(size_tensor, src=0)
    if rank != 0:
        # Only receiving ranks allocate a buffer; rank 0 keeps its filled tensor
        byte_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
    torch.distributed.broadcast(byte_tensor, src=0)
    if rank != 0:
        obj_bytes = byte_tensor.cpu().numpy().tobytes()
        input_state = pickle.loads(obj_bytes)
    return input_state


def main():
    args = parse_args()
    global_rank, local_rank, world_size = get_rank_info()
    set_device(local_rank)  # Ensure correct GPU assignment
    start_datetime = datetime.datetime.strptime(f"{args.start_date} {args.start_time}", "%Y-%m-%d %H:%M:%S")
    # Initialize the context before anything else
    context = Configuration(
        device=f"cuda:{local_rank}" if args.cuda else "cpu",
        checkpoint={"huggingface": "ecmwf/aifs-single-1.0"},
        runner="parallel",
        lead_time=args.lead_time,
        verbosity=args.verbosity,
        output=args.output_frequency if args.output_frequency else "none",
    )
    # Initialize the ParallelRunner (which handles distributed initialization)
    runner = ParallelRunner(context)
    if global_rank == 0:
        # ----- Rank 0 does all preprocessing -----
        fields_t0 = get_open_data(start_datetime, PARAM_SFC)
        fields_tminus = get_open_data(start_datetime - datetime.timedelta(hours=6), PARAM_SFC)
        fields = {
            key: np.stack([fields_tminus[key], fields_t0[key]], axis=0)
            for key in fields_t0
        }
        soil_t0 = get_open_data(start_datetime, PARAM_SOIL, SOIL_LEVELS)
        soil_tminus = get_open_data(start_datetime - datetime.timedelta(hours=6), PARAM_SOIL, SOIL_LEVELS)
        for k, v in soil_t0.items():
            fields[f"{k}_t0"] = v
        for k, v in soil_tminus.items():
            fields[f"{k}_tminus"] = v
        pl_t0 = get_open_data(start_datetime, PARAM_PL, LEVELS)
        pl_tminus = get_open_data(start_datetime - datetime.timedelta(hours=6), PARAM_PL, LEVELS)
        for k, v in pl_t0.items():
            fields[f"{k}_t0"] = v
        for k, v in pl_tminus.items():
            fields[f"{k}_tminus"] = v
        mapping = {'sot_1': 'stl1', 'sot_2': 'stl2', 'vsw_1': 'swvl1', 'vsw_2': 'swvl2'}
        for k, v in soil_t0.items():
            fields[mapping.get(k, k)] = v
        for level in LEVELS:
            gh_key = f"gh_{level}"
            if gh_key in fields:
                fields[f"z_{level}"] = fields.pop(gh_key) * 9.80665
            if f"{gh_key}_t0" in fields:
                fields[f"z_{level}_t0"] = fields.pop(f"{gh_key}_t0") * 9.80665
            if f"{gh_key}_tminus" in fields:
                fields[f"z_{level}_tminus"] = fields.pop(f"{gh_key}_tminus") * 9.80665
        input_state = {
            "date": start_datetime.strftime("%Y-%m-%d %H:%M:%S"),
            "fields": fields
        }
        LOG.info("Rank 0: Finished preprocessing.")
    else:
        LOG.info(f"Rank {global_rank}: Waiting for preprocessing to finish...")
        input_state = None
    # Every rank must enter the same collectives in the same order;
    # a barrier entered only by ranks != 0 would never complete.
    barrier_sync()
    input_state = broadcast_input_state(input_state, torch.device(f"cuda:{local_rank}" if args.cuda else "cpu"))
    for state in runner.run(input_state=input_state, lead_time=args.lead_time):
        if global_rank == 0:
            print_state(state)
    LOG.info("Parallel forecast complete!")


if __name__ == "__main__":
    main()
And the following is the Slurm job I am submitting:
#!/bin/bash
#SBATCH --job-name=aifs_forecast
#SBATCH --output=aifs_forecast.log
#SBATCH --partition=clm
#SBATCH --nodes=2 # Use both clm1 and clm2
#SBATCH --ntasks=4 # Total tasks (2 per node)
#SBATCH --ntasks-per-node=2 # 1 task per GPU
#SBATCH --gres=gpu:2 # 2 GPUs per node
#SBATCH --time=02:00:00
module load cray-python
export PYTHONFAULTHANDLER=1
export NCCL_DEBUG=INFO
# Optimize NCCL for Slingshot (disable InfiniBand, force socket-based communication)
export NCCL_IB_DISABLE=1
export NCCL_NET=Socket
export NCCL_SOCKET_IFNAME=hsn0 # Use Cray Slingshot adapter
# Make PyTorch raise an error when an NCCL collective times out, instead of hanging
export NCCL_BLOCKING_WAIT=1
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=300
export NCCL_BUFFSIZE=8388608 # Increase buffer size for large messages
export NCCL_P2P_DISABLE=1 # Disable direct GPU-to-GPU (P2P) transfers inside a node
export NCCL_SHM_DISABLE=1 # Disable the shared-memory transport inside a node
# Note: the batch shell expands the next lines once, not once per srun task;
# each task still gets its own SLURM_PROCID and SLURM_LOCALID from srun.
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export WORLD_SIZE=$SLURM_NTASKS
# Set master address and a job-specific port dynamically
export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n1)
export MASTER_PORT=$((29500 + $SLURM_JOB_ID % 1000))
# Restrict visible GPUs (also expanded once here, not per task)
export CUDA_VISIBLE_DEVICES=$LOCAL_RANK
# Launch one process per GPU across both nodes
srun --mpi=pmi2 python aifs.py --cuda --start_date 2025-05-13 --start_time 00:00:00 --lead_time 10d --input opendata
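With --nodes=2 and --ntasks-per-node=2, the four tasks map onto nodes and GPUs as sketched below, assuming Slurm's default block distribution (consecutive ranks fill one node before moving to the next). The helper names are hypothetical; master_port simply reproduces the MASTER_PORT arithmetic from the batch script.

```python
def task_layout(nodes: int, tasks_per_node: int) -> list[tuple[int, int, int]]:
    """(global_rank, node_index, local_rank) for each task, block distribution assumed."""
    return [
        (rank, rank // tasks_per_node, rank % tasks_per_node)
        for rank in range(nodes * tasks_per_node)
    ]


def master_port(job_id: int, base: int = 29500) -> int:
    """Same formula as the batch script: 29500 + SLURM_JOB_ID % 1000."""
    return base + job_id % 1000


# Two nodes with two GPUs each: local_rank is the GPU index on that node
for rank, node, local in task_layout(2, 2):
    print(f"rank {rank}: node {node}, GPU {local}")
```

The modulo keeps the port in the range 29500 to 30499, so concurrent jobs are unlikely (but not guaranteed) to collide.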
I am struggling to make this run, and the documentation is very sparse, so I need to go into the code to understand how parallel inference works. I have spent many days on this, so if someone already has a working script (on Slurm), I would appreciate receiving a copy. Thank you.
Hi ajjaji,
It should not be necessary to rewrite all this code yourself to use parallel inference. You can use anemoi-inference as documented; to use parallel inference, you just have to set "runner: parallel" in your config file and launch anemoi-inference via srun.
If that doesn't work, it would be easier to debug your issues if you post the error messages you are receiving.
Dear Cathalobrien,
Thank you for your feedback. I followed your suggestion; I had completely forgotten about using the anemoi API. I created a YAML file:
runner: parallel
world_size: 2
checkpoint:
  huggingface: "ecmwf/aifs-single-1.0"
device: cuda
input:
  grib: "./ecmwf.test.grib"
output:
  grib: "./forecast_output.grib"
lead_time: 240
I prepared an input GRIB file, ecmwf.grib, containing the necessary parameters for t = t0 and t = t0 - 6h. I renamed and scaled gh so that the GRIB file contains z (9.80665 × gh). I also renamed the parameters "sot" and "vsw" to "stl1", "stl2", "swvl1" and "swvl2". grib_ls now gives the following listing:
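The gh-to-z conversion described above is z = g0 × gh with g0 = 9.80665 m s⁻² (standard gravity), which turns geopotential height in geopotential metres into geopotential in m² s⁻². A minimal sketch; the function name is illustrative only:

```python
G0 = 9.80665  # standard gravity, m s**-2


def geopotential_from_height(gh: float) -> float:
    """Convert geopotential height (gpm) to geopotential (m**2 s**-2)."""
    return gh * G0


# Example: an illustrative 500 hPa height of 5500 gpm
z500 = geopotential_from_height(5500.0)
```

The same factor works element-wise on NumPy arrays, which is what the earlier script does with fields.pop(gh_key) * 9.80665.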
ai@clm1:~/aifs/yaml> grib_ls ecmwf.grib
ecmwf.grib
edition centre date dataType gridType stepRange typeOfLevel level shortName packingType
2 ecmf 20250514 fc regular_ll 0 heightAboveGround 10 10u grid_ccsds
2 ecmf 20250514 fc regular_ll 0 heightAboveGround 10 10v grid_ccsds
2 ecmf 20250514 fc regular_ll 0 heightAboveGround 2 2t grid_ccsds
2 ecmf 20250514 fc regular_ll 0 heightAboveGround 2 2d grid_ccsds
2 ecmf 20250514 fc regular_ll 0 meanSea 0 msl grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 skt grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 sp grid_ccsds
2 ecmf 20250514 fc regular_ll 0 entireAtmosphere 0 tcw grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 lsm grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 sdsgso grid_ccsds
2 ecmf 20250514 fc regular_ll 0 surface 0 unknown grid_ccsds
2 ecmf 20250514 fc regular_ll 0 soilLayer 1 swvl1 grid_ccsds
2 ecmf 20250514 fc regular_ll 0 soilLayer 2 swvl2 grid_ccsds
2 ecmf 20250514 fc regular_ll 0 soilLayer 1 stl1 grid_ccsds
2 ecmf 20250514 fc regular_ll 0 soilLayer 2 stl2 grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 1000 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 925 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 850 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 700 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 600 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 500 z grid_ccsds
2 ecmf 20250514 fc regular_ll 0 isobaricInhPa 400 z grid_ccsds
.........
I submitted the following batch job:
#!/bin/bash
#SBATCH --partition clm
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --gpus-per-node=2
#SBATCH --time=1:00:00
#SBATCH --output=parallel_inf.out
export PYTHONFAULTHANDLER=1
export NCCL_DEBUG=INFO
# Optimize NCCL for Slingshot (disable InfiniBand, force socket-based communication)
export NCCL_IB_DISABLE=1
export NCCL_NET=Socket
export NCCL_SOCKET_IFNAME=hsn0 # Use Cray Slingshot adapter
# Make PyTorch raise an error when an NCCL collective times out, instead of hanging
export NCCL_BLOCKING_WAIT=1
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=300
export NCCL_BUFFSIZE=8388608 # Increase buffer size for large messages
export NCCL_P2P_DISABLE=1 # Disable direct GPU-to-GPU (P2P) transfers inside a node
export NCCL_SHM_DISABLE=1 # Disable the shared-memory transport inside a node
# Note: the batch shell expands the next lines once, not once per srun task;
# each task still gets its own SLURM_PROCID and SLURM_LOCALID from srun.
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export WORLD_SIZE=$SLURM_NTASKS
# Set master address and a job-specific port dynamically
export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n1)
export MASTER_PORT=$((29500 + $SLURM_JOB_ID % 1000))
# Restrict visible GPUs (also expanded once here, not per task)
export CUDA_VISIBLE_DEVICES=$LOCAL_RANK
module load cray-python
srun anemoi-inference --debug run aifs.yaml
But I am getting the following error:
clm2:3622073:3622267 [0] NCCL INFO Channel 01/0 : 2[0] -> 3[1] [send] via NET/Socket/0
clm2:3622073:3622267 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 0
2025-05-15 18:13:55 ERROR Data check failed for Create input state
2025-05-15 18:13:55 ERROR Expected (94 variables) x (2 dates) = 188 fields, got 180
Traceback (most recent call last):
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/utils/cli.py", line 229, in cli_main
cmd.run(args)
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/commands/run.py", line 60, in run
runner.execute()
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/runners/default.py", line 102, in execute
input_state = input.create_input_state(date=self.config.date)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/inputs/gribfile.py", line 66, in create_input_state
return self._create_input_state(ekd.from_source("file", self.path), variables=None, date=date)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/inputs/ekd.py", line 375, in _create_input_state
return self._create_state(
^^^^^^^^^^^^^^^^^^^
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/inputs/ekd.py", line 249, in _create_state
fields = self._filter_and_sort(fields, variables=variables, dates=dates, title="Create input state")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/inputs/ekd.py", line 166, in _filter_and_sort
check_data(title, data, variables, dates)
File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/checks.py", line 76, in check_data
raise ValueError(msg)
ValueError: Expected (94 variables) x (2 dates) = 188 fields, got 180
2025-05-15 18:13:55 ERROR
💣 Expected (94 variables) x (2 dates) = 188 fields, got 180
2025-05-15 18:13:55 ERROR 💣 Exiting
name │ 2025-05-14T18:00:00 │ 2025-05-15T00:00:00
───────┼─────────────────────┼────────────────────
10u │ ✅ │ ✅
10v │ ✅ │ ✅
2d │ ✅ │ ✅
2t │ ✅ │ ✅
lsm │ ✅ │ ✅
msl │ ✅ │ ✅
q_100 │ ✅ │ ✅
q_1000 │ ✅ │ ✅
q_150 │ ✅ │ ✅
q_200 │ ✅ │ ✅
q_250 │ ✅ │ ✅
q_300 │ ✅ │ ✅
q_400 │ ✅ │ ✅
q_50 │ ✅ │ ✅
q_500 │ ✅ │ ✅
q_600 │ ✅ │ ✅
q_700 │ ✅ │ ✅
q_850 │ ✅ │ ✅
q_925 │ ✅ │ ✅
sdor │ ✅ │ ✅
skt │ ✅ │ ✅
slor │ ✅ │ ✅
sp │ ✅ │ ✅
stl1 │ ❌ │ ❌
stl2 │ ❌ │ ❌
swvl1 │ ❌ │ ❌
swvl2 │ ❌ │ ❌
t_100 │ ✅ │ ✅
t_1000 │ ✅ │ ✅
t_150 │ ✅ │ ✅
t_200 │ ✅ │ ✅
t_250 │ ✅ │ ✅
t_300 │ ✅ │ ✅
t_400 │ ✅ │ ✅
t_50 │ ✅ │ ✅
t_500 │ ✅ │ ✅
t_600 │ ✅ │ ✅
t_700 │ ✅ │ ✅
t_850 │ ✅ │ ✅
t_925 │ ✅ │ ✅
tcw │ ✅ │ ✅
u_100 │ ✅ │ ✅
u_1000 │ ✅ │ ✅
u_150 │ ✅ │ ✅
u_200 │ ✅ │ ✅
u_250 │ ✅ │ ✅
u_300 │ ✅ │ ✅
u_400 │ ✅ │ ✅
u_50 │ ✅ │ ✅
u_500 │ ✅ │ ✅
u_600 │ ✅ │ ✅
u_700 │ ✅ │ ✅
u_850 │ ✅ │ ✅
u_925 │ ✅ │ ✅
v_100 │ ✅ │ ✅
v_1000 │ ✅ │ ✅
v_150 │ ✅ │ ✅
v_200 │ ✅ │ ✅
v_250 │ ✅ │ ✅
v_300 │ ✅ │ ✅
v_400 │ ✅ │ ✅
v_50 │ ✅ │ ✅
v_500 │ ✅ │ ✅
v_600 │ ✅ │ ✅
v_700 │ ✅ │ ✅
v_850 │ ✅ │ ✅
v_925 │ ✅ │ ✅
w_100 │ ✅ │ ✅
w_1000 │ ✅ │ ✅
w_150 │ ✅ │ ✅
w_200 │ ✅ │ ✅
w_250 │ ✅ │ ✅
w_300 │ ✅ │ ✅
w_400 │ ✅ │ ✅
w_50 │ ✅ │ ✅
w_500 │ ✅ │ ✅
w_600 │ ✅ │ ✅
w_700 │ ✅ │ ✅
w_850 │ ✅ │ ✅
w_925 │ ✅ │ ✅
z │ ✅ │ ✅
z_100 │ ✅ │ ✅
z_1000 │ ✅ │ ✅
z_150 │ ✅ │ ✅
z_200 │ ✅ │ ✅
z_250 │ ✅ │ ✅
z_300 │ ✅ │ ✅
z_400 │ ✅ │ ✅
z_50 │ ✅ │ ✅
z_500 │ ✅ │ ✅
z_600 │ ✅ │ ✅
z_700 │ ✅ │ ✅
z_850 │ ✅ │ ✅
z_925 │ ✅ │ ✅
───────┴─────────────────────┴────────────────────
clm2:3622074:3622074 [1] NCCL INFO cudaDriverVersion 12080
clm2:3622074:3622074 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to hsn0
clm2:3622074:3622074 [1] NCCL INFO Bootstrap: Using hsn0:10.253.2.12<0>
It seems that renaming and scaling "gh" to "z" was successful, but anemoi-inference still does not recognize the renamed soil parameters "stl1", "stl2", "swvl1" and "swvl2".
Am I missing something, or doing something wrong?
Thank you for your support.
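The numbers in the error follow from a variables × dates tally: 94 variables at the two input dates (t0 - 6h and t0, the two columns of the table) give 188 expected fields, and the four soil variables missing at both dates account for the 8 absent fields (188 - 8 = 180). The sketch below reproduces that bookkeeping, assuming each (variable, date) pair is one field; it is not anemoi-inference's check_data, and the helper names are hypothetical.

```python
from datetime import datetime, timedelta


def input_dates(t0: datetime, step_hours: int = 6) -> list[datetime]:
    """The two input dates, oldest first: t0 - step_hours and t0."""
    return [t0 - timedelta(hours=step_hours), t0]


def tally_fields(variables, dates, present):
    """Compare expected (variable, date) pairs with those present in the input.

    Returns (expected_count, found_count, sorted_missing_pairs).
    """
    expected = {(v, d) for v in variables for d in dates}
    found = expected & set(present)
    missing = sorted(expected - found)
    return len(expected), len(found), missing


# Small example using the dates from the table in the error output
dates = input_dates(datetime(2025, 5, 15, 0))
variables = ["2t", "msl", "stl1", "swvl1"]
present = {(v, d) for v in ["2t", "msl"] for d in dates}
expected_n, found_n, missing = tally_fields(variables, dates, present)
```

Listing the missing pairs, rather than just the counts, points straight at the variables whose names did not match.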
Hi,
The GRIB file you have prepared for input appears to be on a regular_ll grid, while AIFS is trained on reduced_gg, so that will likely be the next error you encounter. It also appears that you have an unknown param included.
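A quick way to see that a regular_ll file cannot be used as-is is to compare point counts: a global 0.25° regular lat-lon grid (both poles included) has 1440 × 721 = 1,038,240 points, whereas the N320 reduced Gaussian grid has 542,080 (the gridsize in the N320 grid description later in the thread). A minimal sketch; the helper name is illustrative:

```python
def regular_ll_points(step_deg: float) -> int:
    """Number of points on a global regular lat-lon grid, both poles included."""
    n_lon = round(360 / step_deg)      # longitudes 0 .. 360 - step
    n_lat = round(180 / step_deg) + 1  # latitudes 90 .. -90 inclusive
    return n_lon * n_lat


N320_POINTS = 542080  # reduced Gaussian N320, as in the griddes gridsize

ll_points = regular_ll_points(0.25)
```

Checking the number of values per field against the grid the checkpoint expects catches this mismatch before inference starts.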
If you do not need to run in real time, and are okay with running inference from ERA5 rather than OPER, it may be easier to run from the CDS input:
runner: parallel
world_size: 2
checkpoint:
  huggingface: "ecmwf/aifs-single-1.0"
date: 2020-01-01
input:
  cds:
    dataset: 'reanalysis-era5-complete'
    ...
Otherwise, please open an issue on https://github.com/ecmwf/anemoi-inference/issues providing information about the grib file used as input.
Thank you for your suggestion. I got past the input-data issue by using the following YAML file:
runner: parallel
world_size: 2
checkpoint:
  huggingface: "ecmwf/aifs-single-1.0"
device: cuda
input:
  grib:
    path: './opendata.grib'
    namer:
      rules:
        - [ { shortName: sot, level: 1 }, stl1 ]
        - [ { shortName: sot, level: 2 }, stl2 ]
        - [ { shortName: vsw, level: 1 }, swvl1 ]
        - [ { shortName: vsw, level: 2 }, swvl2 ]
output:
  grib: "./forecast_output.grib"
lead_time: 240
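The namer rules above map GRIB metadata to the variable names the checkpoint expects. The sketch below illustrates one plausible reading of those rules (the first rule whose keys all match wins, otherwise a default name is kept); it is a hypothetical re-implementation for illustration, not anemoi-inference's namer.

```python
from collections.abc import Mapping, Sequence
from typing import Any

# (match criteria, resulting variable name), mirroring the YAML rules
Rule = tuple[Mapping[str, Any], str]

RULES: list[Rule] = [
    ({"shortName": "sot", "level": 1}, "stl1"),
    ({"shortName": "sot", "level": 2}, "stl2"),
    ({"shortName": "vsw", "level": 1}, "swvl1"),
    ({"shortName": "vsw", "level": 2}, "swvl2"),
]


def apply_rules(metadata: Mapping[str, Any], rules: Sequence[Rule], default: str) -> str:
    """Return the name of the first rule whose keys all match, else `default`."""
    for match, name in rules:
        if all(metadata.get(key) == value for key, value in match.items()):
            return name
    return default
```

Matching on both shortName and level is what distinguishes the two soil layers, since sot and vsw each appear at levels 1 and 2.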
I am also using the following script to download the ECMWF open data and convert it from regular_ll to reduced_gg (N320):
#!/usr/bin/env python3
import argparse
import datetime
import sys
import os
import subprocess
from earthkit.data import from_source
# --- Constants ---
PARAM_SFC = ["10u", "10v", "2t", "2d", "msl", "skt", "sp", "tcw", "lsm", "z", "sdor", "slor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]

# --- Parse Arguments ---
parser = argparse.ArgumentParser(description="Download ECMWF Open Data and save processed output.")
parser.add_argument("--date", required=True, help="Date in YYYY-MM-DD format")
parser.add_argument("--time", type=int, choices=[0, 6, 12, 18], required=True, help="Hour of day (0, 6, 12, or 18)")
parser.add_argument("--output", required=True, help="Output GRIB file path")
args = parser.parse_args()

# --- Format date ---
try:
    DATE = datetime.datetime.strptime(args.date, "%Y-%m-%d").replace(hour=args.time)
except ValueError:
    print("Error: Invalid date format. Use YYYY-MM-DD.", file=sys.stderr)
    sys.exit(1)

# --- Download function ---
def get_open_data(date, param, levelist=None):
    try:
        kwargs = {
            "date": date,
            "time": date.hour,
            "param": param,
        }
        if levelist:
            kwargs["levelist"] = levelist
        ds = from_source("ecmwf-open-data", **kwargs)
        return ds
    except Exception as e:
        print(f"Failed to download {param} @ {levelist if levelist else 'surface'}: {e}", file=sys.stderr)
        return None

# --- Download Data ---
print(f"Downloading ECMWF Open Data for {DATE} UTC and 6 hours before")
datasets = []
for dt in [DATE - datetime.timedelta(hours=6), DATE]:
    ds_sfc = get_open_data(dt, PARAM_SFC)
    if ds_sfc:
        datasets.append(ds_sfc)
    ds_soil = get_open_data(dt, PARAM_SOIL, levelist=SOIL_LEVELS)
    if ds_soil:
        datasets.append(ds_soil)
    ds_pl = get_open_data(dt, PARAM_PL, levelist=LEVELS)
    if ds_pl:
        datasets.append(ds_pl)
if not datasets:
    print("Error: No data was successfully downloaded.", file=sys.stderr)
    sys.exit(1)

# --- Merge datasets ---
combined_ds = datasets[0]
for ds in datasets[1:]:
    combined_ds += ds

# --- Save to GRIB ---
print(f"Writing output to {args.output}")
combined_ds.write(args.output)

# --- Helper to run subprocesses ---
def run_command(cmd_args):
    try:
        subprocess.run(cmd_args, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running command: {' '.join(cmd_args)}: {e}", file=sys.stderr)
        sys.exit(1)

# --- Prepare temporary filenames ---
tmp1 = "tmp1.grib"
tmp2 = "tmp2.grib"
griddes_file = "n320.griddes"
final_output = args.output

# --- Scale and rename gh → z ---
print("Renaming and scaling gh to z...")
run_command([
    "grib_set",
    "-w", "shortName=gh",
    "-s", "shortName=z,scaleValuesBy=9.80665",
    args.output,
    tmp1
])

# --- Write N320 grid definition ---
print("Writing N320 grid definition...")
n320_grid_def = """
gridtype = gaussian_reduced
gridsize = 542080
xsize = 2
ysize = 640
xname = lon
xlongname = "longitude"
xunits = "degrees_east"
yname = lat
ylongname = "latitude"
yunits = "degrees_north"
numLPE = 320
xvals = 0 359.719
yvals = 89.7848769072186 89.5062027382071 89.225882847612 88.9451911183168 88.6643583418232 88.3834573122484 88.1025181389376
87.8215555071107 87.5405774264113 87.2595886348395 86.9785921135966 86.697589831922 86.4165831427363 86.135573006184
85.8545601224852 85.5735450142966 85.2925280796269 85.0115096268976 84.7304898988027 84.4494690889289 84.1684473535766
83.887424820323 83.60640159433 83.3253777630591 83.044353399845 82.7633285666368 82.4823033161248 82.2012776934076
81.9202517373121 81.6392254814469 81.3581989550503 81.0771721836786 80.7961451897673 80.5151179930943 80.2340906111611
79.9530630595119 79.6720353519991 79.3910075010082 79.1099795176461 78.8289514119024 78.5479231927865 78.2668948684445
77.985866446261 77.704837932945 77.4238093346063 77.1427806568205 76.8617519046863 76.5807230828753 76.2996941956759
76.0186652470318 75.737636240576 75.4566071796606 75.1755780673836 74.8945489066125 74.6135197000054 74.3324904500298
74.0514611589791 73.7704318289882 73.4894024620462 73.2083730600095 72.9273436246123 72.6463141574761 72.3652846601196
72.0842551339657 71.8032255803498 71.522196000526 71.2411663956731 70.9601367669007 70.6791071152534 70.3980774417163
70.1170477472186 69.8360180326378 69.5549882988032 69.2739585464991 68.9929287764678 68.7118989894126 68.4308691859999
68.1498393668621 67.8688095325994 67.5877796837818 67.3067498209512 67.025719944623 66.7446900552877 66.4636601534124
66.182630239442 65.901600313801 65.620570376894 65.3395404291075 65.0585104708102 64.7774805023547 64.4964505240781
64.2154205363026 63.9343905393366 63.6533605334754 63.3723305190018 63.0913004961869 62.8102704652906 62.5292404265619
62.2482103802401 61.9671803265548 61.6861502657264 61.4051201979667 61.1240901234793 60.84306004246 60.5620299550971
60.2809998615717 59.9999697620583 59.7189396567247 59.4379095457329 59.1568794292387 58.8758493073924 58.5948191803389
58.313789048218 58.0327589111646 57.751728769309 57.4706986227768 57.1896684716895 56.9086383161644 56.6276081563149
56.3465779922506 56.0655478240774 55.784517651898 55.5034874758116 55.222457295914 54.9414271122983 54.6603969250545
54.3793667342697 54.0983365400284 53.8173063424125 53.5362761415012 53.2552459373716 52.9742157300982 52.6931855197534
52.4121553064074 52.1311250901284 51.8500948709826 51.5690646490341 51.2880344243453 51.0070041969769 50.7259739669878
50.4449437344351 50.1639134993745 49.8828832618601 49.6018530219444 49.3208227796786 49.0397925351124 48.7587622882944
48.4777320392715 48.1967017880897 47.9156715347937 47.6346412794268 47.3536110220315 47.072580762649 46.7915505013196
46.5105202380823 46.2294899729755 45.9484597060362 45.6674294373009 45.3863991668049 45.1053688945827 44.824338620668
44.5433083450937 44.2622780678919 43.9812477890938 43.70021750873 43.4191872268304 43.138156943424 42.8571266585394
42.5760963722043 42.2950660844458 42.0140357952906 41.7330055047645 41.4519752128929 41.1709449197006 40.8899146252118
40.6088843294502 40.3278540324389 40.0468237342006 39.7657934347576 39.4847631341314 39.2037328323433 38.9227025294142
38.6416722253643 38.3606419202135 38.0796116139814 37.798581306687 37.5175509983491 37.236520688986 36.9554903786155
36.6744600672554 36.3934297549228 36.1123994416347 35.8313691274075 35.5503388122576 35.2693084962008 34.9882781792528
34.7072478614288 34.4262175427438 34.1451872232127 33.8641569028498 33.5831265816694 33.3020962596853 33.0210659369112
32.7400356133606 32.4590052890465 32.1779749639819 31.8969446381796 31.6159143116519 31.3348839844112 31.0538536564695
30.7728233278385 30.4917929985299 30.2107626685552 29.9297323379256 29.648702006652 29.3676716747454 29.0866413422163
28.8056110090754 28.5245806753329 28.2435503409989 27.9625200060835 27.6814896705965 27.4004593345475 27.1194289979461
26.8383986608016 26.5573683231233 26.2763379849202 25.9953076462013 25.7142773069753 25.433246967251 25.1522166270369
24.8711862863414 24.5901559451728 24.3091256035393 24.0280952614489 23.7470649189096 23.4660345759292 23.1850042325154
22.9039738886758 22.622943544418 22.3419131997493 22.0608828546771 21.7798525092085 21.4988221633507 21.2177918171107
20.9367614704954 20.6557311235117 20.3747007761663 20.0936704284658 19.812640080417 19.5316097320262 19.2505793833
18.9695490342446 18.6885186848663 18.4074883351714 18.126457985166 17.8454276348561 17.5643972842478 17.2833669333468
17.0023365821593 16.7213062306908 16.4402758789473 16.1592455269342 15.8782151746574 15.5971848221222 15.3161544693343
15.0351241162991 14.754093763022 14.4730634095083 14.1920330557633 13.9110027017923 13.6299723476005 13.348941993193
13.067911638575 12.7868812837514 12.5058509287275 12.224820573508 11.943790218098 11.6627598625023 11.3817295067259
11.1006991507735 10.8196687946499 10.53863843836 10.2576080819084 9.97657772529973 9.69554736853877 9.41451701163009
9.13348665457832 8.85245629738799 8.57142594006364 8.2903955826098 8.009365225031 7.72833486733162 7.44730450951617
7.16627415158902 6.8852437935546 6.60421343541724 6.32318307718136 6.04215271885122 5.76112236043118 5.48009200192553
5.19906164333854 4.91803128467449 4.63700092593762 4.35597056713218 4.07494020826237 3.79390984933242 3.51287949034651
3.23184913130885 2.95081877222359 2.66978841309492 2.38875805392701 2.10772769472399 1.82669733549001 1.5456669762292
1.2646366169457 0.983606257643684 0.702575898327214 0.421545539000452 0.140515179667507 -0.140515179667507 -0.421545539000452
-0.702575898327214 -0.983606257643684 -1.2646366169457 -1.5456669762292 -1.82669733549001 -2.10772769472399 -2.38875805392701
-2.66978841309492 -2.95081877222359 -3.23184913130885 -3.51287949034651 -3.79390984933242 -4.07494020826237 -4.35597056713218
-4.63700092593762 -4.91803128467449 -5.19906164333854 -5.48009200192553 -5.76112236043118 -6.04215271885122 -6.32318307718136
-6.60421343541724 -6.8852437935546 -7.16627415158902 -7.44730450951617 -7.72833486733162 -8.009365225031 -8.2903955826098
-8.57142594006364 -8.85245629738799 -9.13348665457832 -9.41451701163009 -9.69554736853877 -9.97657772529973 -10.2576080819084
-10.53863843836 -10.8196687946499 -11.1006991507735 -11.3817295067259 -11.6627598625023 -11.943790218098 -12.224820573508
-12.5058509287275 -12.7868812837514 -13.067911638575 -13.348941993193 -13.6299723476005 -13.9110027017923 -14.1920330557633
-14.4730634095083 -14.754093763022 -15.0351241162991 -15.3161544693343 -15.5971848221222 -15.8782151746574 -16.1592455269342
-16.4402758789473 -16.7213062306908 -17.0023365821593 -17.2833669333468 -17.5643972842478 -17.8454276348561 -18.126457985166
-18.4074883351714 -18.6885186848663 -18.9695490342446 -19.2505793833 -19.5316097320262 -19.812640080417 -20.0936704284658
-20.3747007761663 -20.6557311235117 -20.9367614704954 -21.2177918171107 -21.4988221633507 -21.7798525092085 -22.0608828546771
-22.3419131997493 -22.622943544418 -22.9039738886758 -23.1850042325154 -23.4660345759292 -23.7470649189096 -24.0280952614489
-24.3091256035393 -24.5901559451728 -24.8711862863414 -25.1522166270369 -25.433246967251 -25.7142773069753 -25.9953076462013
-26.2763379849202 -26.5573683231233 -26.8383986608016 -27.1194289979461 -27.4004593345475 -27.6814896705965 -27.9625200060835
-28.2435503409989 -28.5245806753329 -28.8056110090754 -29.0866413422163 -29.3676716747454 -29.648702006652 -29.9297323379256
-30.2107626685552 -30.4917929985299 -30.7728233278385 -31.0538536564695 -31.3348839844112 -31.6159143116519 -31.8969446381796
-32.1779749639819 -32.4590052890465 -32.7400356133606 -33.0210659369112 -33.3020962596853 -33.5831265816694 -33.8641569028498
-34.1451872232127 -34.4262175427438 -34.7072478614288 -34.9882781792528 -35.2693084962008 -35.5503388122576 -35.8313691274075
-36.1123994416347 -36.3934297549228 -36.6744600672554 -36.9554903786155 -37.236520688986 -37.5175509983491 -37.798581306687
-38.0796116139814 -38.3606419202135 -38.6416722253643 -38.9227025294142 -39.2037328323433 -39.4847631341314 -39.7657934347576
-40.0468237342006 -40.3278540324389 -40.6088843294502 -40.8899146252118 -41.1709449197006 -41.4519752128929 -41.7330055047645
-42.0140357952906 -42.2950660844458 -42.5760963722043 -42.8571266585394 -43.138156943424 -43.4191872268304 -43.70021750873
-43.9812477890938 -44.2622780678919 -44.5433083450937 -44.824338620668 -45.1053688945827 -45.3863991668049 -45.6674294373009
-45.9484597060362 -46.2294899729755 -46.5105202380823 -46.7915505013196 -47.072580762649 -47.3536110220315 -47.6346412794268
-47.9156715347937 -48.1967017880897 -48.4777320392715 -48.7587622882944 -49.0397925351124 -49.3208227796786 -49.6018530219444
-49.8828832618601 -50.1639134993745 -50.4449437344351 -50.7259739669878 -51.0070041969769 -51.2880344243453 -51.5690646490341
-51.8500948709826 -52.1311250901284 -52.4121553064074 -52.6931855197534 -52.9742157300982 -53.2552459373716 -53.5362761415012
-53.8173063424125 -54.0983365400284 -54.3793667342697 -54.6603969250545 -54.9414271122983 -55.222457295914 -55.5034874758116
-55.784517651898 -56.0655478240774 -56.3465779922506 -56.6276081563149 -56.9086383161644 -57.1896684716895 -57.4706986227768
-57.751728769309 -58.0327589111646 -58.313789048218 -58.5948191803389 -58.8758493073924 -59.1568794292387 -59.4379095457329
-59.7189396567247 -59.9999697620583 -60.2809998615717 -60.5620299550971 -60.84306004246 -61.1240901234793 -61.4051201979667
-61.6861502657264 -61.9671803265548 -62.2482103802401 -62.5292404265619 -62.8102704652906 -63.0913004961869 -63.3723305190018
-63.6533605334754 -63.9343905393366 -64.2154205363026 -64.4964505240781 -64.7774805023547 -65.0585104708102 -65.3395404291075
-65.620570376894 -65.901600313801 -66.182630239442 -66.4636601534124 -66.7446900552877 -67.025719944623 -67.3067498209512
-67.5877796837818 -67.8688095325994 -68.1498393668621 -68.4308691859999 -68.7118989894126 -68.9929287764678 -69.2739585464991
-69.5549882988032 -69.8360180326378 -70.1170477472186 -70.3980774417163 -70.6791071152534 -70.9601367669007 -71.2411663956731
-71.522196000526 -71.8032255803498 -72.0842551339657 -72.3652846601196 -72.6463141574761 -72.9273436246123 -73.2083730600095
-73.4894024620462 -73.7704318289882 -74.0514611589791 -74.3324904500298 -74.6135197000054 -74.8945489066125 -75.1755780673836
-75.4566071796606 -75.737636240576 -76.0186652470318 -76.2996941956759 -76.5807230828753 -76.8617519046863 -77.1427806568205
-77.4238093346063 -77.704837932945 -77.985866446261 -78.2668948684445 -78.5479231927865 -78.8289514119024 -79.1099795176461
-79.3910075010082 -79.6720353519991 -79.9530630595119 -80.2340906111611 -80.5151179930943 -80.7961451897673 -81.0771721836786
-81.3581989550503 -81.6392254814469 -81.9202517373121 -82.2012776934076 -82.4823033161248 -82.7633285666368 -83.044353399845
-83.3253777630591 -83.60640159433 -83.887424820323 -84.1684473535766 -84.4494690889289 -84.7304898988027 -85.0115096268976
-85.2925280796269 -85.5735450142966 -85.8545601224852 -86.135573006184 -86.4165831427363 -86.697589831922 -86.9785921135966
-87.2595886348395 -87.5405774264113 -87.8215555071107 -88.1025181389376 -88.3834573122484 -88.6643583418232 -88.9451911183168
-89.225882847612 -89.5062027382071 -89.7848769072186
reducedPoints = 18 25 36 40 45 50 60 64 72 72 75 81 90 96 100 108 120 120 125 135 144 144 150 160 180 180 180 192 192 200
216 216 216 225 240 240 240 250 256 270 270 288 288 288 300 300 320 320 320 324 360 360 360 360 360 360 375
375 384 384 400 400 405 432 432 432 432 450 450 450 480 480 480 480 480 486 500 500 500 512 512 540 540 540
540 540 576 576 576 576 576 576 600 600 600 600 640 640 640 640 640 640 640 648 648 675 675 675 675 720 720
720 720 720 720 720 720 720 729 750 750 750 750 768 768 768 768 800 800 800 800 800 800 810 810 864 864 864
864 864 864 864 864 864 864 864 900 900 900 900 900 900 900 900 960 960 960 960 960 960 960 960 960 960 960
960 960 960 972 972 1000 1000 1000 1000 1000 1000 1000 1000 1024 1024 1024 1024 1024 1024 1080 1080 1080
1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1125 1125 1125 1125 1125 1125 1125 1125 1125 1125
1125 1125 1125 1125 1152 1152 1152 1152 1152 1152 1152 1152 1152 1200 1200 1200 1200 1200 1200 1200 1200
1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1215 1215 1215 1215 1215 1215 1215 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280
1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1280 1215 1215 1215
1215 1215 1215 1215 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200
1200 1152 1152 1152 1152 1152 1152 1152 1152 1152 1125 1125 1125 1125 1125 1125 1125 1125 1125 1125 1125
1125 1125 1125 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1080 1024 1024 1024 1024
1024 1024 1000 1000 1000 1000 1000 1000 1000 1000 972 972 960 960 960 960 960 960 960 960 960 960 960 960
960 960 900 900 900 900 900 900 900 900 864 864 864 864 864 864 864 864 864 864 864 810 810 800 800 800 800
800 800 768 768 768 768 750 750 750 750 729 720 720 720 720 720 720 720 720 720 675 675 675 675 648 648 640
640 640 640 640 640 640 600 600 600 600 576 576 576 576 576 576 540 540 540 540 540 512 512 500 500 500 486
480 480 480 480 480 450 450 450 432 432 432 432 405 400 400 384 384 375 375 360 360 360 360 360 360 324 320
320 320 300 300 288 288 288 270 270 256 250 240 240 240 225 216 216 216 200 192 192 180 180 180 160 150 144
144 135 125 120 120 108 100 96 90 81 75 72 72 64 60 50 45 40 36 25 18
"""
with open(griddes_file, "w") as f:
    f.write(n320_grid_def)
# --- Remap using CDO ---
print("Remapping using CDO to N320 grid...")
run_command([
    "cdo", f"remapbil,{griddes_file}", tmp1, tmp2
])
# --- Set longitudeOfFirstGridPointInDegrees = 0 ---
print("Setting longitudeOfFirstGridPointInDegrees = 0...")
run_command([
    "grib_set", "-s", "longitudeOfFirstGridPointInDegrees=0", tmp2, final_output
])
# --- Cleanup ---
print("Cleaning up temporary files...")
for f in [tmp1, tmp2, griddes_file]:
    if os.path.exists(f):
        os.remove(f)
print(f"Final GRIB output written to: {final_output}")
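Because the N320 grid description above is typed in by hand, a quick consistency check before calling cdo can catch a truncated or mis-pasted reducedPoints list. The sketch below is hypothetical: check_reduced_points is not part of the original script, and it only assumes that the values follow "reducedPoints =" and run until the next "key =" entry or the end of the text. It checks that the row counts are symmetric about the equator and that their number and sum match what you expect. For the N320 reduced Gaussian grid used by aifs-single-1.0, that means 640 rows and 542,080 points.

```python
import re


def check_reduced_points(griddes_text, expected_rows, expected_points):
    """Return a list of problems found in the reducedPoints entry (empty if OK).

    Hypothetical helper: it assumes the numbers follow 'reducedPoints =' and
    stop at the first non-digit, non-whitespace character (the next key).
    """
    match = re.search(r"reducedPoints\s*=\s*([\d\s]+)", griddes_text)
    if match is None:
        raise ValueError("no reducedPoints entry found")
    points = [int(v) for v in match.group(1).split()]

    problems = []
    if len(points) != expected_rows:
        problems.append(f"expected {expected_rows} rows, found {len(points)}")
    # A reduced Gaussian grid has the same row lengths in both hemispheres.
    if points != points[::-1]:
        problems.append("row counts are not symmetric about the equator")
    if sum(points) != expected_points:
        problems.append(f"expected {expected_points} points, found {sum(points)}")
    return problems
```

In the script it could be called as check_reduced_points(n320_grid_def, 640, 542080) just before writing griddes_file, aborting if the returned list is not empty.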
My parallel inference now gets further, but it stops here:
clm1:3209846:3210501 [0] NCCL INFO 2 coll channels, 2 collnet channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
clm1:3209846:3210501 [0] NCCL INFO CC Off, workFifoBytes 1048576
clm1:3209846:3210501 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so. Using internal tuner plugin.
clm1:3209846:3210501 [0] NCCL INFO ncclCommInitRankConfig comm 0x78c0a90 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId ad000 commId 0x410cad32f624b7cb - Init COMPLETE
clm1:3209846:3210501 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 4 total 0.33 (kernels 0.27, alloc 0.01, bootstrap 0.04, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
clm1:3209846:3210512 [0] NCCL INFO [Proxy Progress] Device 0 CPU core 114
clm1:3209846:3210509 [0] NCCL INFO Channel 00/0 : 3[1] -> 0[0] [receive] via NET/Socket/0
clm1:3209846:3210509 [0] NCCL INFO Channel 01/0 : 3[1] -> 0[0] [receive] via NET/Socket/0
clm1:3209846:3210509 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] [send] via NET/Socket/0
clm1:3209846:3210509 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] [send] via NET/Socket/0
clm1:3209846:3210509 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 0
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/lus/ai/.local/bin/anemoi-inference", line 8, in <module>
[rank0]: sys.exit(main())
[rank0]: ^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/main.py", line 38, in main
[rank0]: cli_main(version, doc, COMMANDS)
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/utils/cli.py", line 229, in cli_main
[rank0]: cmd.run(args)
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/commands/run.py", line 60, in run
[rank0]: runner.execute()
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/runners/default.py", line 114, in execute
[rank0]: for state in self.run(input_state=input_state, lead_time=lead_time):
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/runner.py", line 222, in run
[rank0]: yield from self.forecast(lead_time, input_tensor, input_state)
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/runner.py", line 590, in forecast
[rank0]: y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/inference/runners/parallel.py", line 117, in predict_step
[rank0]: return model.predict_step(input_tensor_torch, self.model_comm_group)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/models/interface/__init__.py", line 129, in predict_step
[rank0]: y_hat = self(x, model_comm_group=model_comm_group, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/models/models/encoder_processor_decoder.py", line 284, in forward
[rank0]: x_data_latent, x_skip = self._assemble_input(x, batch_size)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/anemoi/models/models/encoder_processor_decoder.py", line 171, in _assemble_input
[rank0]: self.node_attributes(self._graph_name_data, batch_size=batch_size),
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/lus/ai/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1940, in __getattr__
[rank0]: raise AttributeError(
[rank0]: AttributeError: 'AnemoiModelEncProcDec' object has no attribute 'node_attributes'
clm1:3209847:3210653 [1] NCCL INFO comm 0x78bf7e0 rank 1 nranks 4 cudaDev 1 busId d9000 - Destroy COMPLETE
clm1:3209846:3210693 [0] NCCL INFO comm 0x78c0a90 rank 0 nranks 4 cudaDev 0 busId ad000 - Destroy COMPLETE
srun: error: clm1: tasks 0-1: Exited with exit code 1
srun: Terminating StepId=2173385.0
srun: error: clm2: tasks 2-3: Terminated
srun: Force Terminated StepId=2173385.0
The attribute "node_attributes" is missing from the definition of AnemoiModelEncProcDec!
I am using the latest versions of anemoi-inference and anemoi-models.
The versions given in the notebook do not support parallel inference (they only work with the simple runner), and the latest ones seem to have some mismatch with the model.
Could you please tell me which versions of anemoi-inference and anemoi-models are compatible with aifs-single-1.0 and can run it in parallel mode?
Regards
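The AttributeError above is most likely a mismatch between the installed anemoi-models code and the model object stored in the checkpoint. Before launching a long Slurm job, it can help to confirm which anemoi versions are actually installed. A small sketch using the standard library's importlib.metadata; find_mismatches is a hypothetical helper, not part of anemoi, and the lookup parameter exists only so the function can be tested without the packages installed.

```python
from importlib import metadata


def find_mismatches(pins, lookup=metadata.version):
    """Compare installed package versions against expected ones.

    `pins` maps distribution names (as used by pip) to expected versions.
    Returns {name: (installed_or_None, expected)} for every difference.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None  # not installed at all
        if installed != wanted:
            mismatches[name] = (installed, wanted)
    return mismatches
```

For example, find_mismatches({"anemoi-inference": "0.5.4", "anemoi-models": "0.3.1"}) returns an empty dict when the environment matches the combination reported to work later in this thread.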
I was able to run aifs-single-1.0 in the following environment: anemoi-inference==0.5.4 and anemoi-models==0.3.1. However, I had to make the following change:
(aifs) arw@hp3:~/ai/aifs/lib/python3.11/site-packages/anemoi/models/interface> diff __init__.py __init__.py.orig
11d10
< from typing import Optional
16d14
< from torch.distributed.distributed_c10d import ProcessGroup
89,91c87
<     def predict_step(
<         self, batch: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None, **kwargs
<     ) -> torch.Tensor:
---
>     def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
115c111
<             y_hat = self(x, model_comm_group)
---
>             y_hat = self(x)
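Why this patch is needed: ParallelRunner.predict_step calls model.predict_step(input_tensor_torch, self.model_comm_group) (see the traceback above), so the interface's predict_step must accept a second argument and pass it on to the model's forward. The sketch below uses hypothetical, torch-free stand-in classes (OldInterface, PatchedInterface) to show the difference in signatures; it is an illustration, not the anemoi-models code.

```python
class OldInterface:
    """Stand-in for the unpatched interface: predict_step takes only the batch."""

    def __call__(self, x, model_comm_group=None):
        # Stand-in for forward(): echo what it received.
        return (x, model_comm_group)

    def predict_step(self, batch):
        return self(batch)


class PatchedInterface(OldInterface):
    """Stand-in for the patched interface: the comm group is optional and forwarded."""

    def predict_step(self, batch, model_comm_group=None, **kwargs):
        return self(batch, model_comm_group)


def call_like_parallel_runner(model, batch, comm_group):
    # Mirrors the call in runners/parallel.py: the group is a second positional argument.
    return model.predict_step(batch, comm_group)
```

With the old signature the call fails with a TypeError (too many positional arguments); with the patched one the group reaches forward, and calls without a group still work.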
(aifs) arw@hp3:~/ai/aifs/lib/python3.11/site-packages/anemoi/models/interface>
For people who want to run on CPUs, or on GPUs older than the Ampere generation, the following file must also be changed:
ai@clm1:~/.local/lib/python3.11/site-packages/anemoi/inference> diff runner.py runner.py.orig
453,472d452
<         # Force anemoi to use standard PyTorch attention instead of flash_attn
<         import torch.nn.functional as F
<         import anemoi.models.layers.attention as attention_module
<
<         attention_module.attn_func = F.scaled_dot_product_attention
<         attention_module._FLASH_ATTENTION_AVAILABLE = False
<
<         from anemoi.models.layers.processor import TransformerProcessor
<         model.model.processor = TransformerProcessor(
<             num_layers=16,
<             window_size=1024,
<             num_channels=1024,
<             num_chunks=2,
<             activation='GELU',
<             num_heads=16,
<             mlp_hidden_ratio=4,
<             dropout_p=0.0,
<             attention_implementation="scaled_dot_product_attention").to(self.device)
<         torch.cuda.empty_cache()
<
ai@clm1:~/.local/lib/python3.11/site-packages/anemoi/inference>
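The runner.py patch above always falls back to PyTorch's scaled_dot_product_attention. A hypothetical helper that makes the choice explicit is sketched below, under the assumption that flash-attn needs a CUDA GPU with compute capability 8.0 (Ampere) or newer. The backend name "flash_attention" is an illustrative label, not necessarily a value anemoi accepts. In a real script the inputs could come from torch.cuda.get_device_capability(), which returns a (major, minor) tuple, and importlib.util.find_spec("flash_attn") is not None. One further point, as far as I can tell from PyTorch semantics: TransformerProcessor(...) builds a module with freshly initialized parameters. Unless the trained processor weights are restored afterwards (for example via load_state_dict, if the parameter names match), it is worth checking that the forecast still uses the checkpoint's processor.

```python
def choose_attention(device_type, compute_capability=None, flash_attn_installed=False):
    """Pick an attention backend name (hypothetical helper).

    Assumption: flash-attn requires CUDA with compute capability >= (8, 0);
    everything else (CPU, Turing/Volta GPUs, missing package) uses SDPA.
    """
    if (
        device_type == "cuda"
        and flash_attn_installed
        and compute_capability is not None
        and tuple(compute_capability) >= (8, 0)
    ):
        return "flash_attention"
    return "scaled_dot_product_attention"
```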
On my (modest) system, with 2 nodes and 4 GPUs (NVIDIA RTX 2000, 16 GB memory each), the 240-hour inference runs in 4 minutes:
-rwxr-xr-x 1 arw users 481 May 29 10:21 aifs.yaml
-rw-r--r-- 1 arw users 80636026 May 29 10:21 aifs-20250519-0-0.grib
-rw-r--r-- 1 arw users 89081792 May 29 10:21 aifs-20250519-0-6.grib
-rw-r--r-- 1 arw users 89436194 May 29 10:21 aifs-20250519-0-12.grib
-rw-r--r-- 1 arw users 89540638 May 29 10:22 aifs-20250519-0-18.grib
-rw-r--r-- 1 arw users 89578639 May 29 10:22 aifs-20250519-0-24.grib
-rw-r--r-- 1 arw users 89671068 May 29 10:22 aifs-20250519-0-30.grib
-rw-r--r-- 1 arw users 89608199 May 29 10:22 aifs-20250519-0-36.grib
-rw-r--r-- 1 arw users 89567816 May 29 10:22 aifs-20250519-0-42.grib
-rw-r--r-- 1 arw users 89604512 May 29 10:22 aifs-20250519-0-48.grib
-rw-r--r-- 1 arw users 89668115 May 29 10:22 aifs-20250519-0-54.grib
-rw-r--r-- 1 arw users 89704542 May 29 10:22 aifs-20250519-0-60.grib
-rw-r--r-- 1 arw users 89749054 May 29 10:22 aifs-20250519-0-66.grib
-rw-r--r-- 1 arw users 89775864 May 29 10:22 aifs-20250519-0-72.grib
-rw-r--r-- 1 arw users 89815870 May 29 10:23 aifs-20250519-0-78.grib
-rw-r--r-- 1 arw users 89643308 May 29 10:23 aifs-20250519-0-84.grib
-rw-r--r-- 1 arw users 89606007 May 29 10:23 aifs-20250519-0-90.grib
-rw-r--r-- 1 arw users 89567323 May 29 10:23 aifs-20250519-0-96.grib
-rw-r--r-- 1 arw users 89572340 May 29 10:23 aifs-20250519-0-102.grib
-rw-r--r-- 1 arw users 89645787 May 29 10:23 aifs-20250519-0-108.grib
-rw-r--r-- 1 arw users 89663735 May 29 10:23 aifs-20250519-0-114.grib
-rw-r--r-- 1 arw users 89696430 May 29 10:23 aifs-20250519-0-120.grib
-rw-r--r-- 1 arw users 89717784 May 29 10:23 aifs-20250519-0-126.grib
-rw-r--r-- 1 arw users 89735053 May 29 10:23 aifs-20250519-0-132.grib
-rw-r--r-- 1 arw users 89743118 May 29 10:24 aifs-20250519-0-138.grib
-rw-r--r-- 1 arw users 89763779 May 29 10:24 aifs-20250519-0-144.grib
-rw-r--r-- 1 arw users 89765108 May 29 10:24 aifs-20250519-0-150.grib
-rw-r--r-- 1 arw users 89768911 May 29 10:24 aifs-20250519-0-156.grib
-rw-r--r-- 1 arw users 89774781 May 29 10:24 aifs-20250519-0-162.grib
-rw-r--r-- 1 arw users 89791168 May 29 10:24 aifs-20250519-0-168.grib
-rw-r--r-- 1 arw users 89789617 May 29 10:24 aifs-20250519-0-174.grib
-rw-r--r-- 1 arw users 89658467 May 29 10:24 aifs-20250519-0-180.grib
-rw-r--r-- 1 arw users 89682889 May 29 10:24 aifs-20250519-0-186.grib
-rw-r--r-- 1 arw users 89623713 May 29 10:24 aifs-20250519-0-192.grib
-rw-r--r-- 1 arw users 89587138 May 29 10:25 aifs-20250519-0-198.grib
-rw-r--r-- 1 arw users 89592916 May 29 10:25 aifs-20250519-0-204.grib
-rw-r--r-- 1 arw users 89603081 May 29 10:25 aifs-20250519-0-210.grib
-rw-r--r-- 1 arw users 89614808 May 29 10:25 aifs-20250519-0-216.grib
-rw-r--r-- 1 arw users 89632799 May 29 10:25 aifs-20250519-0-222.grib
-rw-r--r-- 1 arw users 89640945 May 29 10:25 aifs-20250519-0-228.grib
-rw-r--r-- 1 arw users 89648423 May 29 10:25 aifs-20250519-0-234.grib
-rwxr-xr-x 1 arw users 17462 May 29 10:25 aifs.gpu.out
-rw-r--r-- 1 arw users 89660666 May 29 10:25 aifs-20250519-0-240.grib
(aifs) arw@hp3:
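As a rough check of the 4-minute figure: a 240-hour forecast with 6-hour steps is 240 / 6 = 40 autoregressive steps. The listing shows output written from 10:21 to 10:25, so about 240 s / 40 = 6 s per step. The timestamps only have minute resolution, so this is approximate. A hypothetical helper that derives the same numbers from the ls -l lines, assuming file names end in "-<step>.grib":

```python
import re
from datetime import datetime


def forecast_throughput(ls_lines, step_hours=6):
    """Return (number of model steps, elapsed seconds, seconds per step).

    Hypothetical helper: parses 'HH:MM <prefix>-<step>.grib' from `ls -l`
    lines and ignores anything else, so the result is accurate to ~1 minute.
    """
    steps, times = [], []
    for line in ls_lines:
        m = re.search(r"(\d{2}:\d{2}) \S+-(\d+)\.grib$", line)
        if m:
            times.append(datetime.strptime(m.group(1), "%H:%M"))
            steps.append(int(m.group(2)))
    if not steps:
        raise ValueError("no GRIB output lines found")
    n_steps = (max(steps) - min(steps)) // step_hours
    elapsed = (max(times) - min(times)).total_seconds()
    return n_steps, elapsed, elapsed / n_steps
```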
This is the batch job script:
(aifs) arw@hp3:~/ai/aifs/yaml> cat aifs.gpu.sh
#!/bin/bash
#SBATCH --job-name=aifs_infer
#SBATCH --partition=ai
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --gpus-per-node=2
#SBATCH --cpus-per-task=16
#SBATCH --time=1:00:00
#SBATCH --output=aifs.gpu.out
#SBATCH --error=aifs.gpu.out
##SBATCH --exclusive # Avoid sharing nodes (reduces contention)
#SBATCH --hint=multithread # Let system use hyperthreading efficiently
module unload cray-python
module load cray-python/3.11.7
source ~/ai/aifs/bin/activate
set -x
# 1. Networking
export MASTER_ADDR=$(ip -4 addr show hsn0 | grep -oP '(?<=inet\s)\d+(.\d+){3}' | head -n1)
#export MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1)
export MASTER_PORT=$((29500 + $SLURM_JOB_ID % 1000))
# 2. PyTorch DDP setup
export WORLD_SIZE=$SLURM_NTASKS
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export ANEMOI_INFERENCE_NUM_CHUNKS=2 # Adjust based on memory/workload
# 3. NCCL and CUDA optimizations
export CUDA_VISIBLE_DEVICES=$LOCAL_RANK
export CUDA_LAUNCH_BLOCKING=1
export PYTHONFAULTHANDLER=1
export NCCL_DEBUG=WARN # INFO is verbose, use WARN in prod
export NCCL_NET=Socket
export NCCL_SOCKET_IFNAME=hsn0
export NCCL_IB_DISABLE=1 # Use if no Infiniband or want simpler setup
export NCCL_P2P_DISABLE=0 # Enable peer-to-peer comm if hardware supports it
export NCCL_SHM_DISABLE=0 # Shared memory enabled improves intra-node comm
export NCCL_BUFFSIZE=16777216 # Increase NCCL buffer size for better throughput
export TORCH_NCCL_BLOCKING_WAIT=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=900
export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:64,expandable_segments:True'
# 4. CPU/memory binding (optional advanced tuning)
export OMP_NUM_THREADS=4
export KMP_AFFINITY=granularity=fine,compact,1,0
# 5. YAML configuration for inference
cat << YAML > aifs.yaml
runner: parallel
checkpoint:
  huggingface: "ecmwf/aifs-single-1.0"
post_processors:
  - accumulate_from_start_of_forecast
device: cuda
input:
  grib:
    path: './opendata.grib'
    namer:
      rules:
        - [ { shortName: sot, level: 1 }, stl1 ]
        - [ { shortName: sot, level: 2 }, stl2 ]
        - [ { shortName: vsw, level: 1 }, swvl1 ]
        - [ { shortName: vsw, level: 2 }, swvl2 ]
output:
  grib:
    path: 'aifs-{date}-{time}-{step}.grib'
lead_time: 240
YAML
# 6. Run the inference with srun (one task per GPU)
srun anemoi-inference run aifs.yaml
(aifs) arw@hp3:~/ai/aifs/yaml>
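Section 2 of the script maps Slurm's per-task variables onto the names torch.distributed conventionally reads, and derives MASTER_PORT from the job ID. The sketch below shows that mapping for a single task. torch_env_from_slurm is a hypothetical name; the port formula is the one used in the script. One caveat: exports in a batch script run once, in the batch step on the first node, not once per srun task. The exported RANK and LOCAL_RANK therefore carry the batch step's values into every task. This is presumably harmless here, since, as noted earlier in the thread, ParallelRunner fills in the communication settings automatically under Slurm.

```python
def torch_env_from_slurm(slurm_env, base_port=29500):
    """Derive torch.distributed-style variables from one task's Slurm variables.

    Hypothetical helper mirroring the exports in the batch script. It must be
    evaluated per task (inside srun), where SLURM_PROCID differs between tasks.
    """
    job_id = int(slurm_env["SLURM_JOB_ID"])
    return {
        "WORLD_SIZE": slurm_env["SLURM_NTASKS"],
        "RANK": slurm_env["SLURM_PROCID"],
        "LOCAL_RANK": slurm_env["SLURM_LOCALID"],
        # Same formula as the script: 29500 + (job ID mod 1000).
        "MASTER_PORT": str(base_port + job_id % 1000),
    }
```

For job 2173385 from the log above, every task would get MASTER_PORT 29885, while RANK runs from 0 to 3 across the four tasks.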
Please refer to my get_data.py script above to prepare the initial data for the aifs-single-1.0 model.
I posted this in case someone is stuck on this model.
Good luck