@@ -73,10 +73,11 @@ class AzureInstanceTypeInfo:
7373 - gpu_vendor(str, optional): the model of the gpus of this instance
7474 - sgx(bool, optional): True if the instance has SGX support
7575 - family(str, optional): family of the instance
76+ - arch(str, optional): architecture type of the instance x86_64 or arm64
7677 """
7778
7879 def __init__ (self , name = "" , cpu = 1 , mem = 0 , os_disk_space = 0 , res_disk_space = 0 ,
79- gpu = 0 , gpu_model = None , gpu_vendor = None , sgx = False , family = None ):
80+ gpu = 0 , gpu_model = None , gpu_vendor = None , sgx = False , family = None , arch = 'x86_64' ):
8081 self .name = name
8182 self .cpu = cpu
8283 self .mem = mem
@@ -88,6 +89,7 @@ def __init__(self, name="", cpu=1, mem=0, os_disk_space=0, res_disk_space=0,
8889 self .sgx = sgx
8990 self .family = family
9091 self .price = None
92+ self .arch = arch
9193
9294 def set_sgx (self ):
9395 """Guess SGX from instance name"""
@@ -134,6 +136,7 @@ def set_gpu_models(self):
134136 def fromSKU (sku , prices = None ):
135137 """Get an instance type object from SKU Json data"""
136138 gpu = os_disk_space = res_disk_space = mem = cpu = 0
139+ arch = "x86_64"
137140 for elem in sku .capabilities :
138141 if elem .name == "vCPUs" :
139142 cpu = int (elem .value )
@@ -145,7 +148,11 @@ def fromSKU(sku, prices=None):
145148 os_disk_space = int (elem .value )
146149 elif elem .name == "GPUs" :
147150 gpu = int (elem .value )
148- instance_type = AzureInstanceTypeInfo (sku .name , cpu , mem , os_disk_space , res_disk_space , gpu , family = sku .family )
151+ elif elem .name == "CpuArchitectureType" :
152+ if "arm" in elem .value .lower ():
153+ arch = "arm64"
154+ instance_type = AzureInstanceTypeInfo (sku .name , cpu , mem , os_disk_space ,
155+ res_disk_space , gpu , family = sku .family , arch = arch )
149156 instance_type .set_gpu_models ()
150157 instance_type .set_sgx ()
151158 if prices and sku .name in prices :
@@ -315,6 +322,7 @@ def get_instance_type(self, system, credentials, subscription_id):
315322 gpu_model = system .getValue ('gpu.model' )
316323 gpu_vendor = system .getValue ('gpu.vendor' )
317324 sgx = system .getValue ('cpu.sgx' )
325+ arch = system .getValue ('cpu.arch' , 'x86_64' )
318326
319327 instace_types = self .get_instance_type_list (credentials , subscription_id , location )
320328
@@ -323,7 +331,8 @@ def get_instance_type(self, system, credentials, subscription_id):
323331 if instace_type .name == self .INSTANCE_TYPE :
324332 default = instace_type
325333
326- comparison = cpu_op (instace_type .cpu , cpu )
334+ comparison = arch == instace_type .arch
335+ comparison = comparison and cpu_op (instace_type .cpu , cpu )
327336 comparison = comparison and memory_op (instace_type .mem , memory )
328337 comparison = comparison and disk_free_op (instace_type .res_disk_space , disk_free )
329338
0 commit comments