diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py
index 4f26eeb8502..00719cbd70f 100644
--- a/metaflow/plugins/aws/batch/batch.py
+++ b/metaflow/plugins/aws/batch/batch.py
@@ -245,6 +245,7 @@ def create_job(
         ephemeral_storage=None,
         log_driver=None,
         log_options=None,
+        container_secrets=None,
         offload_command_to_s3=False,
     ):
         job_name = self._job_name(
@@ -303,6 +304,7 @@ def create_job(
                 ephemeral_storage=ephemeral_storage,
                 log_driver=log_driver,
                 log_options=log_options,
+                container_secrets=container_secrets,
             )
             .task_id(attrs.get("metaflow.task_id"))
             .environment_variable("AWS_DEFAULT_REGION", self._client.region())
@@ -427,6 +429,7 @@ def launch_job(
         ephemeral_storage=None,
         log_driver=None,
         log_options=None,
+        container_secrets=None,
     ):
         if queue is None:
             queue = next(self._client.active_job_queues(), None)
@@ -469,6 +472,7 @@ def launch_job(
             ephemeral_storage=ephemeral_storage,
             log_driver=log_driver,
             log_options=log_options,
+            container_secrets=container_secrets,
         )
         self.num_parallel = num_parallel
         self.job = job.execute()
diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 6c8153da3a4..529b76fe40b 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -298,6 +298,25 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if env_deco:
         env.update(env_deco[0].attributes["vars"])
 
+    # Collect ECS-style container secrets from the @secrets decorator, if any.
+    # Each entry has the shape {"name": <env var name>, "value_from": <secret ARN>}
+    # (snake_case input) and is injected at container startup by AWS Batch/ECS via the job definition.
+    container_secrets = []
+    secrets_deco = [deco for deco in node.decorators if deco.name == "secrets"]
+    if secrets_deco:
+        try:
+            for s in secrets_deco[0].attributes.get("sources", []) or []:
+                if isinstance(s, dict):
+                    name = s.get("name")
+                    value_from = s.get("value_from")
+                    if isinstance(name, str) and isinstance(value_from, str):
+                        container_secrets.append(
+                            {"name": name, "value_from": value_from}
+                        )
+        except Exception:
+            # Best-effort only; silently ignore malformed entries to avoid breaking launches.
+            pass
+
     # Add the environment variables related to the input-paths argument
     if split_vars:
         env.update(split_vars)
@@ -366,6 +385,7 @@ def _sync_metadata():
             log_driver=log_driver,
             log_options=log_options,
             num_parallel=num_parallel,
+            container_secrets=container_secrets,
         )
     except Exception:
         traceback.print_exc()
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index bf0f6a824e7..d81698b18dc 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -162,6 +162,7 @@ def _register_job_definition(
         ephemeral_storage,
         log_driver,
         log_options,
+        container_secrets=None,
     ):
         # identify platform from any compute environment associated with the
         # queue
@@ -199,6 +200,19 @@
             "propagateTags": True,
         }
 
+        # Inject ECS secrets, exposed as environment variables at container start, if provided.
+        if container_secrets:
+            norm = []
+            for item in container_secrets:
+                if not isinstance(item, dict):
+                    continue
+                name = item.get("name")
+                value_from = item.get("value_from")
+                if isinstance(name, str) and isinstance(value_from, str):
+                    norm.append({"name": name, "valueFrom": value_from})
+            if norm:
+                job_definition["containerProperties"]["secrets"] = norm
+
         log_options_dict = {}
         if log_options:
             if isinstance(log_options, str):
@@ -480,6 +494,7 @@ def job_def(
         ephemeral_storage,
         log_driver,
         log_options,
+        container_secrets=None,
     ):
         self.payload["jobDefinition"] = self._register_job_definition(
             image,
@@ -502,6 +517,7 @@
             ephemeral_storage,
             log_driver,
             log_options,
+            container_secrets,
         )
         return self
 
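For review context, a minimal usage sketch of the shape this patch consumes. The dict-style `sources` entries with snake_case `name`/`value_from` keys are exactly what the collection loop in `batch_cli.py` above looks for; the flow name, step, and secret ARN below are placeholders, not part of this change:

```python
from metaflow import FlowSpec, batch, secrets, step


class SecretsFlow(FlowSpec):
    # Each dict source names the environment variable to expose and the AWS
    # Secrets Manager secret (or SSM parameter) ARN to resolve it from.
    @secrets(
        sources=[
            {
                "name": "DB_PASSWORD",
                "value_from": "arn:aws:secretsmanager:us-west-2:123456789012:secret:db-password",
            }
        ]
    )
    @batch
    @step
    def start(self):
        import os

        # ECS resolves the secret before the container entrypoint runs, so the
        # variable is already present in the environment here.
        print("DB_PASSWORD set:", "DB_PASSWORD" in os.environ)
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SecretsFlow()
```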
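With that decorator in place, `_register_job_definition` normalizes snake_case `value_from` to the camelCase `valueFrom` that the Batch `RegisterJobDefinition` API expects, so the registered job definition would carry a fragment along these lines (values are illustrative placeholders):

```python
# Illustrative containerProperties fragment after normalization; only the
# "secrets" list is added by this patch.
{
    "containerProperties": {
        "secrets": [
            {
                "name": "DB_PASSWORD",
                "valueFrom": "arn:aws:secretsmanager:us-west-2:123456789012:secret:db-password",
            }
        ],
    }
}
```

Note that ECS, not Metaflow, fetches the secret values at container start, so the task execution role used by the job definition needs permission (e.g. `secretsmanager:GetSecretValue`) on the referenced ARNs.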