55from enum import Enum
66
77
8+ class SageMakerPlatform (str , Enum ):
9+ """Simple enum to define environment variables injected by the SageMaker platform."""
10+
11+ PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
12+
13+
814class SageMakerInference (str , Enum ):
9- """Simple enum to define the mapping between dictionary key and environement variable."""
15+ """Simple enum to define the mapping between dictionary key and environment variable."""
1016
1117 BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
1218 REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
@@ -28,7 +34,7 @@ def __init__(self):
2834 SageMakerInference .CODE_DIRECTORY : os .getenv (SageMakerInference .CODE_DIRECTORY , None ),
2935 SageMakerInference .CODE : os .getenv (SageMakerInference .CODE , "inference.handler" ),
3036 SageMakerInference .LOG_LEVEL : os .getenv (SageMakerInference .LOG_LEVEL , 10 ),
31- SageMakerInference .PORT : 8080 ,
37+ SageMakerInference .PORT : self . _resolve_port () ,
3238 }
3339
3440 def __str__ (self ):
@@ -57,3 +63,10 @@ def logging_level(self):
5763 @property
5864 def port (self ):
5965 return self ._environment_variables .get (SageMakerInference .PORT )
66+
67+ def _resolve_port (self ) -> int :
68+ if os .getenv (SageMakerPlatform .PLATFORM_PORT , None ):
69+ return int (os .getenv (SageMakerPlatform .PLATFORM_PORT ))
70+ if os .getenv (SageMakerInference .PORT , None ):
71+ return int (os .getenv (SageMakerInference .PORT ))
72+ return 8080
0 commit comments