@@ -2159,17 +2159,18 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
21592159 }
21602160}
21612161
2162+ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
2163+ std::vector<uint32_t> indices;
21622164
2163- static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
21642165 for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
21652166 vk::MemoryType memory_type = mem_props->memoryTypes[i];
21662167 if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
21672168 (flags & memory_type.propertyFlags) == flags &&
21682169 mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
2169- return static_cast<int32_t> (i);
2170+ indices.push_back (i);
21702171 }
21712172 }
2172- return UINT32_MAX ;
2173+ return indices ;
21732174}
21742175
21752176static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
@@ -2212,22 +2213,24 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
22122213 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
22132214 const auto & req_flags = *it;
22142215
2215- uint32_t memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
2216+ const std::vector< uint32_t> memory_type_indices = ggml_vk_find_memory_properties (&mem_props, &mem_req, req_flags);
22162217
2217- if (memory_type_index == UINT32_MAX ) {
2218+ if (memory_type_indices.empty() ) {
22182219 continue;
22192220 }
22202221 buf->memory_property_flags = req_flags;
22212222
2222- try {
2223- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
2224- break;
2225- } catch (const vk::SystemError& e) {
2226- // loop and retry
2227- // during last attempt throw the exception
2228- if (it + 1 == req_flags_list.end()) {
2229- device->device.destroyBuffer(buf->buffer);
2230- throw e;
2223+ for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
2224+ try {
2225+ buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
2226+ break;
2227+ } catch (const vk::SystemError& e) {
2228+ // loop and retry
2229+ // during last attempt throw the exception
2230+ if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
2231+ device->device.destroyBuffer(buf->buffer);
2232+ throw e;
2233+ }
22312234 }
22322235 }
22332236 }
@@ -13204,25 +13207,28 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total
1320413207 vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
1320513208 vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
1320613209 vk::PhysicalDeviceMemoryProperties2 memprops = {};
13207- bool membudget_supported = vk_instance.device_supports_membudget[device];
13210+ const bool membudget_supported = vk_instance.device_supports_membudget[device];
13211+ const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1320813212
1320913213 if (membudget_supported) {
1321013214 memprops.pNext = &budgetprops;
1321113215 }
1321213216 vkdev.getMemoryProperties2(&memprops);
1321313217
13218+ *total = 0;
13219+ *free = 0;
13220+
1321413221 for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
1321513222 const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
1321613223
13217- if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
13218- *total = heap.size;
13224+ if (is_integrated_gpu || ( heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) ) {
13225+ *total + = heap.size;
1321913226
1322013227 if (membudget_supported && i < budgetprops.heapUsage.size()) {
13221- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
13228+ *free + = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
1322213229 } else {
13223- *free = heap.size;
13230+ *free + = heap.size;
1322413231 }
13225- break;
1322613232 }
1322713233 }
1322813234}
0 commit comments