diff --git a/src/binaries/query/ee_main.rs b/src/binaries/query/ee_main.rs
index d85a8a217fb7f..7a6c1b36bfd88 100644
--- a/src/binaries/query/ee_main.rs
+++ b/src/binaries/query/ee_main.rs
@@ -17,7 +17,7 @@
 
 mod entry;
 
-use databend_common_base::mem_allocator::GlobalAllocator;
+use databend_common_base::mem_allocator::TrackingGlobalAllocator;
 use databend_common_base::runtime::Runtime;
 use databend_common_base::runtime::ThreadTracker;
 use databend_common_config::InnerConfig;
@@ -34,7 +34,7 @@ use crate::entry::run_cmd;
 use crate::entry::start_services;
 
 #[global_allocator]
-pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator;
+pub static GLOBAL_ALLOCATOR: TrackingGlobalAllocator = TrackingGlobalAllocator::create();
 
 fn main() {
     let binary_version = (*databend_common_config::DATABEND_COMMIT_VERSION).clone();
diff --git a/src/binaries/query/entry.rs b/src/binaries/query/entry.rs
index 112f8886acda9..911119e9fb753 100644
--- a/src/binaries/query/entry.rs
+++ b/src/binaries/query/entry.rs
@@ -15,7 +15,7 @@
 use std::env;
 use std::time::Duration;
 
-use databend_common_base::mem_allocator::GlobalAllocator;
+use databend_common_base::mem_allocator::TrackingGlobalAllocator;
 use databend_common_base::runtime::set_alloc_error_hook;
 use databend_common_base::runtime::GLOBAL_MEM_STAT;
 use databend_common_config::Commands;
@@ -305,8 +305,8 @@ pub async fn start_services(conf: &InnerConfig) -> Result<(), MainError> {
             "unlimited".to_string()
         }
     });
-    println!("    allocator: {}", GlobalAllocator::name());
-    println!("    config: {}", GlobalAllocator::conf());
+    println!("    allocator: {}", TrackingGlobalAllocator::name());
+    println!("    config: {}", TrackingGlobalAllocator::conf());
     println!();
 
     println!("Cluster: {}", {
diff --git a/src/binaries/query/oss_main.rs b/src/binaries/query/oss_main.rs
index 54b19666e29fb..91d7c806673e4 100644
--- a/src/binaries/query/oss_main.rs
+++ b/src/binaries/query/oss_main.rs
@@ -17,7 +17,7 @@
 
 mod entry;
 
-use databend_common_base::mem_allocator::GlobalAllocator;
+use databend_common_base::mem_allocator::TrackingGlobalAllocator;
 use databend_common_base::runtime::Runtime;
 use databend_common_base::runtime::ThreadTracker;
 use databend_common_config::InnerConfig;
@@ -35,7 +35,7 @@ use crate::entry::run_cmd;
 use crate::entry::start_services;
 
 #[global_allocator]
-pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator;
+pub static GLOBAL_ALLOCATOR: TrackingGlobalAllocator = TrackingGlobalAllocator::create();
 
 fn main() {
     let binary_version = (*databend_common_config::DATABEND_COMMIT_VERSION).clone();
diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs
index 21e49119a148b..593b8efed3415 100644
--- a/src/common/base/src/lib.rs
+++ b/src/common/base/src/lib.rs
@@ -26,6 +26,9 @@
 #![feature(variant_count)]
 #![feature(ptr_alignment_type)]
 #![feature(vec_into_raw_parts)]
+#![feature(slice_ptr_get)]
+#![feature(alloc_layout_extra)]
+#![feature(let_chains)]
 
 pub mod base;
 pub mod containers;
diff --git a/src/common/base/src/mem_allocator/global.rs b/src/common/base/src/mem_allocator/global.rs
index 127965419c135..7c6e3f5acc827 100644
--- a/src/common/base/src/mem_allocator/global.rs
+++ b/src/common/base/src/mem_allocator/global.rs
@@ -19,14 +19,42 @@ use std::alloc::Layout;
 use std::ptr::null_mut;
 use std::ptr::NonNull;
 
+use crate::mem_allocator::tracker::MetaTrackerAllocator;
 use crate::mem_allocator::DefaultAllocator;
 
+pub type DefaultGlobalAllocator = GlobalAllocator<DefaultAllocator>;
+pub type TrackingGlobalAllocator = GlobalAllocator<MetaTrackerAllocator<DefaultAllocator>>;
+
 /// Global allocator, default is JeAllocator.
 #[derive(Debug, Clone, Copy, Default)]
-pub struct GlobalAllocator;
+pub struct GlobalAllocator<T: Allocator> {
+    inner: T,
+}
+
+impl GlobalAllocator<MetaTrackerAllocator<DefaultAllocator>> {
+    pub const fn create() -> GlobalAllocator<MetaTrackerAllocator<DefaultAllocator>> {
+        GlobalAllocator {
+            inner: MetaTrackerAllocator::create(DefaultAllocator::create()),
+        }
+    }
+
+    pub fn name() -> String {
+        DefaultAllocator::name()
+    }
+
+    pub fn conf() -> String {
+        DefaultAllocator::conf()
+    }
+}
+
+impl GlobalAllocator<DefaultAllocator> {
+    pub const fn create() -> GlobalAllocator<DefaultAllocator> {
+        GlobalAllocator {
+            inner: DefaultAllocator::create(),
+        }
+    }
 
-impl GlobalAllocator {
     pub fn name() -> String {
         DefaultAllocator::name()
     }
@@ -36,20 +64,20 @@ impl GlobalAllocator {
     }
 }
 
-unsafe impl Allocator for GlobalAllocator {
+unsafe impl<T: Allocator> Allocator for GlobalAllocator<T> {
     #[inline(always)]
     fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-        DefaultAllocator::default().allocate(layout)
+        self.inner.allocate(layout)
     }
 
     #[inline(always)]
     fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-        DefaultAllocator::default().allocate_zeroed(layout)
+        self.inner.allocate_zeroed(layout)
     }
 
     #[inline(always)]
     unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
-        DefaultAllocator::default().deallocate(ptr, layout)
+        self.inner.deallocate(ptr, layout)
    }
 
     #[inline(always)]
@@ -59,7 +87,7 @@ unsafe impl Allocator for GlobalAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        DefaultAllocator::default().grow(ptr, old_layout, new_layout)
+        self.inner.grow(ptr, old_layout, new_layout)
     }
 
     #[inline(always)]
@@ -69,7 +97,7 @@ unsafe impl Allocator for GlobalAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        DefaultAllocator::default().grow_zeroed(ptr, old_layout, new_layout)
+        self.inner.grow_zeroed(ptr, old_layout, new_layout)
     }
 
     #[inline(always)]
@@ -79,32 +107,30 @@ unsafe impl Allocator for GlobalAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        DefaultAllocator::default().shrink(ptr, old_layout, new_layout)
+        self.inner.shrink(ptr, old_layout, new_layout)
     }
 }
 
-unsafe impl GlobalAlloc for GlobalAllocator {
+unsafe impl<T: Allocator> GlobalAlloc for GlobalAllocator<T> {
     #[inline]
     unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
-        if let Ok(ptr) = GlobalAllocator.allocate(layout) {
-            ptr.as_ptr() as *mut u8
-        } else {
-            null_mut()
+        match self.allocate(layout) {
+            Ok(ptr) => ptr.as_ptr() as *mut u8,
+            Err(_) => null_mut(),
         }
     }
 
     #[inline]
     unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
         let ptr = NonNull::new(ptr).unwrap_unchecked();
-        GlobalAllocator.deallocate(ptr, layout);
+        self.deallocate(ptr, layout);
     }
 
     #[inline]
     unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
-        if let Ok(ptr) = GlobalAllocator.allocate_zeroed(layout) {
-            ptr.as_ptr() as *mut u8
-        } else {
-            null_mut()
+        match self.allocate_zeroed(layout) {
+            Ok(ptr) => ptr.as_ptr() as *mut u8,
+            Err(_) => null_mut(),
         }
     }
 
@@ -115,21 +141,15 @@ unsafe impl GlobalAlloc for GlobalAllocator {
         let ptr = NonNull::new(ptr).unwrap_unchecked();
         let new_layout = Layout::from_size_align(new_size, layout.align()).unwrap();
         match layout.size().cmp(&new_size) {
-            Less => {
-                if let Ok(ptr) = GlobalAllocator.grow(ptr, layout, new_layout) {
-                    ptr.as_ptr() as *mut u8
-                } else {
-                    null_mut()
-                }
-            }
-            Greater => {
-                if let Ok(ptr) = GlobalAllocator.shrink(ptr, layout, new_layout) {
-                    ptr.as_ptr() as *mut u8
-                } else {
-                    null_mut()
-                }
-            }
             Equal => ptr.as_ptr(),
+            Less => match self.grow(ptr, layout, new_layout) {
+                Ok(ptr) => ptr.as_ptr() as *mut u8,
+                Err(_) => null_mut(),
+            },
+            Greater => match self.shrink(ptr, layout, new_layout) {
+                Ok(ptr) => ptr.as_ptr() as *mut u8,
+                Err(_) => null_mut(),
+            },
         }
     }
 }
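
Note: the rewrite above turns GlobalAllocator into a transparent wrapper over its inner allocator, so the same type both backs the #[global_allocator] statics in ee_main.rs/oss_main.rs and works as a scoped allocator. A hypothetical usage sketch, not part of this diff, assuming only the nightly allocator_api feature:

    #![feature(allocator_api)]

    use databend_common_base::mem_allocator::TrackingGlobalAllocator;

    fn demo() {
        // `create()` is const, which is what allows the static
        // #[global_allocator] registration shown in the binaries above.
        static ALLOC: TrackingGlobalAllocator = TrackingGlobalAllocator::create();

        // Vec::new_in is nightly allocator_api; allocations made through it
        // are routed through the tracking wrapper as well.
        let mut v: Vec<u8, &TrackingGlobalAllocator> = Vec::new_in(&ALLOC);
        v.extend_from_slice(b"tracked");
        assert_eq!(v.len(), 7);
    }
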
diff --git a/src/common/base/src/mem_allocator/jemalloc.rs b/src/common/base/src/mem_allocator/jemalloc.rs
index f867e3b2baf15..ba21d7e1be522 100644
--- a/src/common/base/src/mem_allocator/jemalloc.rs
+++ b/src/common/base/src/mem_allocator/jemalloc.rs
@@ -21,6 +21,10 @@
 pub struct JEAllocator;
 
 impl JEAllocator {
+    pub const fn create() -> JEAllocator {
+        JEAllocator
+    }
+
     pub fn name() -> String {
         "jemalloc".to_string()
     }
@@ -44,7 +48,6 @@ pub mod linux {
     use tikv_jemalloc_sys as ffi;
 
     use super::JEAllocator;
-    use crate::runtime::ThreadTracker;
 
     #[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))]
     const ALIGNOF_MAX_ALIGN_T: usize = 8;
@@ -77,8 +80,6 @@ pub mod linux {
     unsafe impl Allocator for JEAllocator {
         #[inline(always)]
         fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            ThreadTracker::alloc(layout.size() as i64)?;
-
             let data_address = if layout.size() == 0 {
                 unsafe { NonNull::new(layout.align() as *mut ()).unwrap_unchecked() }
             } else {
@@ -92,8 +93,6 @@ pub mod linux {
 
         #[inline(always)]
         fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            ThreadTracker::alloc(layout.size() as i64)?;
-
             let data_address = if layout.size() == 0 {
                 unsafe { NonNull::new(layout.align() as *mut ()).unwrap_unchecked() }
             } else {
@@ -108,8 +107,6 @@ pub mod linux {
 
         #[inline(always)]
         unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
-            ThreadTracker::dealloc(layout.size() as i64);
-
             if layout.size() == 0 {
                 debug_assert_eq!(ptr.as_ptr() as usize, layout.align());
             } else {
@@ -127,9 +124,6 @@ pub mod linux {
             debug_assert_eq!(old_layout.align(), new_layout.align());
             debug_assert!(old_layout.size() <= new_layout.size());
 
-            ThreadTracker::dealloc(old_layout.size() as i64);
-            ThreadTracker::alloc(new_layout.size() as i64)?;
-
             let data_address = if new_layout.size() == 0 {
                 NonNull::new(new_layout.align() as *mut ()).unwrap_unchecked()
             } else if old_layout.size() == 0 {
@@ -156,9 +150,6 @@ pub mod linux {
             debug_assert_eq!(old_layout.align(), new_layout.align());
             debug_assert!(old_layout.size() <= new_layout.size());
 
-            ThreadTracker::dealloc(old_layout.size() as i64);
-            ThreadTracker::alloc(new_layout.size() as i64)?;
-
             let data_address = if new_layout.size() == 0 {
                 NonNull::new(new_layout.align() as *mut ()).unwrap_unchecked()
             } else if old_layout.size() == 0 {
@@ -195,9 +186,6 @@ pub mod linux {
             debug_assert_eq!(old_layout.align(), new_layout.align());
             debug_assert!(old_layout.size() >= new_layout.size());
 
-            ThreadTracker::dealloc(old_layout.size() as i64);
-            ThreadTracker::alloc(new_layout.size() as i64)?;
-
             if old_layout.size() == 0 {
                 debug_assert_eq!(ptr.as_ptr() as usize, old_layout.align());
                 let slice = std::slice::from_raw_parts_mut(ptr.as_ptr(), 0);
diff --git a/src/common/base/src/mem_allocator/mmap.rs b/src/common/base/src/mem_allocator/mmap.rs
deleted file mode 100644
index 8224c832ea841..0000000000000
--- a/src/common/base/src/mem_allocator/mmap.rs
+++ /dev/null
@@ -1,391 +0,0 @@
-// Copyright 2021 Datafuse Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at https://mozilla.org/MPL/2.0/.
-
-/// mmap allocator.
-/// For better performance, we use jemalloc as the inner allocator.
-#[derive(Debug, Clone, Copy, Default)]
-pub struct MmapAllocator {
-    #[cfg(feature = "jemalloc")]
-    allocator: crate::mem_allocator::JEAllocator,
-    #[cfg(not(feature = "jemalloc"))]
-    allocator: crate::mem_allocator::StdAllocator,
-}
-
-impl MmapAllocator {
-    pub fn new() -> Self {
-        Self {
-            #[cfg(feature = "jemalloc")]
-            allocator: crate::mem_allocator::JEAllocator,
-            #[cfg(not(feature = "jemalloc"))]
-            allocator: crate::mem_allocator::StdAllocator,
-        }
-    }
-}
-
-#[cfg(target_os = "linux")]
-pub mod linux {
-    use std::alloc::AllocError;
-    use std::alloc::Allocator;
-    use std::alloc::Layout;
-    use std::ptr::null_mut;
-    use std::ptr::NonNull;
-
-    use super::MmapAllocator;
-    use crate::runtime::ThreadTracker;
-
-    // MADV_POPULATE_WRITE is supported since Linux 5.14.
-    const MADV_POPULATE_WRITE: i32 = 23;
-
-    const THRESHOLD: usize = 64 << 20;
-
-    impl MmapAllocator {
-        #[inline(always)]
-        fn mmap_alloc(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            debug_assert!(layout.align() <= page_size());
-            ThreadTracker::alloc(layout.size() as i64)?;
-            const PROT: i32 = libc::PROT_READ | libc::PROT_WRITE;
-            const FLAGS: i32 = libc::MAP_PRIVATE | libc::MAP_ANONYMOUS | libc::MAP_POPULATE;
-            let addr = unsafe { libc::mmap(null_mut(), layout.size(), PROT, FLAGS, -1, 0) };
-            if addr == libc::MAP_FAILED {
-                return Err(AllocError);
-            }
-            let addr = NonNull::new(addr as *mut ()).ok_or(AllocError)?;
-            Ok(NonNull::<[u8]>::from_raw_parts(addr, layout.size()))
-        }
-
-        #[inline(always)]
-        unsafe fn mmap_dealloc(&self, ptr: NonNull<u8>, layout: Layout) {
-            debug_assert!(layout.align() <= page_size());
-            ThreadTracker::dealloc(layout.size() as i64);
-            let result = libc::munmap(ptr.cast().as_ptr(), layout.size());
-            assert_eq!(result, 0, "Failed to deallocate.");
-        }
-
-        #[inline(always)]
-        unsafe fn mmap_grow(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            debug_assert!(old_layout.align() <= page_size());
-            debug_assert!(old_layout.align() == new_layout.align());
-
-            ThreadTracker::dealloc(old_layout.size() as i64);
-            ThreadTracker::alloc(new_layout.size() as i64)?;
-
-            const REMAP_FLAGS: i32 = libc::MREMAP_MAYMOVE;
-            let addr = libc::mremap(
-                ptr.cast().as_ptr(),
-                old_layout.size(),
-                new_layout.size(),
-                REMAP_FLAGS,
-            );
-            if addr == libc::MAP_FAILED {
-                return Err(AllocError);
-            }
-            let addr = NonNull::new(addr as *mut ()).ok_or(AllocError)?;
-            if linux_kernel_version() >= (5, 14, 0) {
-                libc::madvise(addr.cast().as_ptr(), new_layout.size(), MADV_POPULATE_WRITE);
-            }
-            Ok(NonNull::<[u8]>::from_raw_parts(addr, new_layout.size()))
-        }
-
-        #[inline(always)]
-        unsafe fn mmap_shrink(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            debug_assert!(old_layout.align() <= page_size());
-            debug_assert!(old_layout.align() == new_layout.align());
-
-            ThreadTracker::dealloc(old_layout.size() as i64);
-            ThreadTracker::alloc(new_layout.size() as i64)?;
-
-            const REMAP_FLAGS: i32 = libc::MREMAP_MAYMOVE;
-            let addr = libc::mremap(
-                ptr.cast().as_ptr(),
-                old_layout.size(),
-                new_layout.size(),
-                REMAP_FLAGS,
-            );
-            if addr == libc::MAP_FAILED {
-                return Err(AllocError);
-            }
-            let addr = NonNull::new(addr as *mut ()).ok_or(AllocError)?;
-
-            Ok(NonNull::<[u8]>::from_raw_parts(addr, new_layout.size()))
-        }
-    }
-
-    unsafe impl Allocator for MmapAllocator {
-        #[inline(always)]
-        fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            if layout.align() > page_size() {
-                return self.allocator.allocate(layout);
-            }
-            if layout.size() >= THRESHOLD {
-                self.mmap_alloc(layout)
-            } else {
-                self.allocator.allocate(layout)
-            }
-        }
-
-        #[inline(always)]
-        unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
-            if layout.align() > page_size() {
-                return self.allocator.deallocate(ptr, layout);
-            }
-            if layout.size() >= THRESHOLD {
-                self.mmap_dealloc(ptr, layout);
-            } else {
-                self.allocator.deallocate(ptr, layout);
-            }
-        }
-
-        #[inline(always)]
-        fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            if layout.align() > page_size() {
-                return self.allocator.allocate_zeroed(layout);
-            }
-            if layout.size() >= THRESHOLD {
-                self.mmap_alloc(layout)
-            } else {
-                self.allocator.allocate_zeroed(layout)
-            }
-        }
-
-        unsafe fn grow(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            if old_layout.align() > page_size() {
-                return self.allocator.grow(ptr, old_layout, new_layout);
-            }
-            if old_layout.size() >= THRESHOLD {
-                self.mmap_grow(ptr, old_layout, new_layout)
-            } else if new_layout.size() >= THRESHOLD {
-                let addr = self.mmap_alloc(new_layout)?;
-                std::ptr::copy_nonoverlapping(
-                    ptr.as_ptr(),
-                    addr.cast().as_ptr(),
-                    old_layout.size(),
-                );
-                self.allocator.deallocate(ptr, old_layout);
-                Ok(addr)
-            } else {
-                self.allocator.grow(ptr, old_layout, new_layout)
-            }
-        }
-
-        unsafe fn grow_zeroed(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            if old_layout.align() > page_size() {
-                return self.allocator.grow_zeroed(ptr, old_layout, new_layout);
-            }
-            if old_layout.size() >= THRESHOLD {
-                self.mmap_grow(ptr, old_layout, new_layout)
-            } else if new_layout.size() >= THRESHOLD {
-                let addr = self.mmap_alloc(new_layout)?;
-                std::ptr::copy_nonoverlapping(
-                    ptr.as_ptr(),
-                    addr.cast().as_ptr(),
-                    old_layout.size(),
-                );
-                self.allocator.deallocate(ptr, old_layout);
-                Ok(addr)
-            } else {
-                self.allocator.grow_zeroed(ptr, old_layout, new_layout)
-            }
-        }
-
-        unsafe fn shrink(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            if old_layout.align() > page_size() {
-                return self.allocator.shrink(ptr, old_layout, new_layout);
-            }
-            if new_layout.size() >= THRESHOLD {
-                self.mmap_shrink(ptr, old_layout, new_layout)
-            } else if old_layout.size() >= THRESHOLD {
-                let addr = self.allocator.allocate(new_layout)?;
-                std::ptr::copy_nonoverlapping(
-                    ptr.as_ptr(),
-                    addr.cast().as_ptr(),
-                    old_layout.size(),
-                );
-                self.mmap_dealloc(ptr, old_layout);
-                Ok(addr)
-            } else {
-                self.allocator.shrink(ptr, old_layout, new_layout)
-            }
-        }
-    }
-
-    #[inline(always)]
-    fn page_size() -> usize {
-        use std::sync::atomic::AtomicUsize;
-        use std::sync::atomic::Ordering;
-        const INVALID: usize = 0;
-        static CACHE: AtomicUsize = AtomicUsize::new(INVALID);
-        let fetch = CACHE.load(Ordering::Relaxed);
-        if fetch == INVALID {
-            let result = unsafe { libc::sysconf(libc::_SC_PAGE_SIZE) as usize };
-            debug_assert_eq!(result.count_ones(), 1);
-            CACHE.store(result, Ordering::Relaxed);
-            result
-        } else {
-            fetch
-        }
-    }
-
-    #[inline(always)]
-    fn linux_kernel_version() -> (u16, u8, u8) {
-        use std::sync::atomic::AtomicU32;
-        use std::sync::atomic::Ordering;
-        const INVALID: u32 = 0;
-        static CACHE: AtomicU32 = AtomicU32::new(INVALID);
-        let fetch = CACHE.load(Ordering::Relaxed);
-        let code = if fetch == INVALID {
-            let mut uname = unsafe { std::mem::zeroed::<libc::utsname>() };
-            assert_ne!(-1, unsafe { libc::uname(&mut uname) });
-            let mut length = 0usize;
-
-            // refer: https://semver.org/, here we stop at \0 and _
-            while length < uname.release.len()
-                && uname.release[length] != 0
-                && uname.release[length] != 95
-            {
-                length += 1;
-            }
-            // fallback to (5.13.0)
-            let fallback_version = 5u32 << 16 | 13u32 << 8;
-            #[allow(clippy::unnecessary_cast)]
-            let slice = unsafe { &*(&uname.release[..length] as *const _ as *const [u8]) };
-            let result = match std::str::from_utf8(slice) {
-                Ok(ver) => match semver::Version::parse(ver) {
-                    Ok(semver) => {
-                        (semver.major.min(65535) as u32) << 16
-                            | (semver.minor.min(255) as u32) << 8
-                            | (semver.patch.min(255) as u32)
-                    }
-                    Err(_) => fallback_version,
-                },
-                Err(_) => fallback_version,
-            };
-
-            CACHE.store(result, Ordering::Relaxed);
-            result
-        } else {
-            fetch
-        };
-        ((code >> 16) as u16, (code >> 8) as u8, code as u8)
-    }
-}
-
-#[cfg(not(target_os = "linux"))]
-pub mod not_linux {
-    use std::alloc::AllocError;
-    use std::alloc::Allocator;
-    use std::alloc::Layout;
-    use std::ptr::NonNull;
-
-    use super::MmapAllocator;
-
-    unsafe impl Allocator for MmapAllocator {
-        #[inline(always)]
-        fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            self.allocator.allocate(layout)
-        }
-
-        #[inline(always)]
-        fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-            self.allocator.allocate_zeroed(layout)
-        }
-
-        #[inline(always)]
-        unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
-            self.allocator.deallocate(ptr, layout)
-        }
-
-        unsafe fn grow(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            self.allocator.grow(ptr, old_layout, new_layout)
-        }
-
-        unsafe fn grow_zeroed(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            self.allocator.grow_zeroed(ptr, old_layout, new_layout)
-        }
-
-        unsafe fn shrink(
-            &self,
-            ptr: NonNull<u8>,
-            old_layout: Layout,
-            new_layout: Layout,
-        ) -> Result<NonNull<[u8]>, AllocError> {
-            self.allocator.shrink(ptr, old_layout, new_layout)
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-
-    #[test]
-    fn test_semver() {
-        let uname_release: Vec<u8> =
-            "4.18.0-2.4.3.xyz.x86_64.fdsf.fdsfsdfsdf.fdsafdsf\0\0\0cxzcxzcxzc"
-                .as_bytes()
-                .to_vec();
-        let mut length = 0;
-        while length < uname_release.len()
-            && uname_release[length] != 0
-            && uname_release[length] != 95
-        {
-            length += 1;
-        }
-        let slice = unsafe { &*(&uname_release[..length] as *const _) };
-        let ver = std::str::from_utf8(slice).unwrap();
-        let version = semver::Version::parse(ver);
-        assert!(version.is_ok());
-        let version = version.unwrap();
-        assert_eq!(version.major, 4);
-        assert_eq!(version.minor, 18);
-        assert_eq!(version.patch, 0);
-    }
-}
diff --git a/src/common/base/src/mem_allocator/mod.rs b/src/common/base/src/mem_allocator/mod.rs
index ca2aa9fe87912..e42798f1ddce8 100644
--- a/src/common/base/src/mem_allocator/mod.rs
+++ b/src/common/base/src/mem_allocator/mod.rs
@@ -15,19 +15,20 @@
 mod global;
 #[cfg(feature = "jemalloc")]
 mod jemalloc;
-mod mmap;
 mod std_;
 
 pub use default::DefaultAllocator;
+pub use global::DefaultGlobalAllocator;
 pub use global::GlobalAllocator;
+pub use global::TrackingGlobalAllocator;
 #[cfg(feature = "jemalloc")]
 pub use jemalloc::JEAllocator;
-pub use mmap::MmapAllocator;
 pub use std_::StdAllocator;
 
 mod default;
 #[cfg(feature = "memory-profiling")]
 mod profiling;
+mod tracker;
 #[cfg(feature = "memory-profiling")]
 pub use profiling::dump_profile;
diff --git a/src/common/base/src/mem_allocator/std_.rs b/src/common/base/src/mem_allocator/std_.rs
index 6421cd6b6d274..ab7d286696815 100644
--- a/src/common/base/src/mem_allocator/std_.rs
+++ b/src/common/base/src/mem_allocator/std_.rs
@@ -18,13 +18,15 @@ use std::alloc::Layout;
 use std::alloc::System;
 use std::ptr::NonNull;
 
-use crate::runtime::ThreadTracker;
-
 /// std system allocator.
 #[derive(Debug, Clone, Copy, Default)]
 pub struct StdAllocator;
 
 impl StdAllocator {
+    pub const fn create() -> StdAllocator {
+        StdAllocator
+    }
+
     pub fn name() -> String {
         "std".to_string()
     }
@@ -37,19 +39,16 @@ impl StdAllocator {
 unsafe impl Allocator for StdAllocator {
     #[inline(always)]
     fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-        ThreadTracker::alloc(layout.size() as i64)?;
         System.allocate(layout)
     }
 
     #[inline(always)]
     fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
-        ThreadTracker::alloc(layout.size() as i64)?;
         System.allocate_zeroed(layout)
     }
 
     #[inline(always)]
     unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
-        ThreadTracker::dealloc(layout.size() as i64);
         System.deallocate(ptr, layout)
     }
 
@@ -60,9 +59,6 @@ unsafe impl Allocator for StdAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        ThreadTracker::dealloc(old_layout.size() as i64);
-        ThreadTracker::alloc(new_layout.size() as i64)?;
-
         System.grow(ptr, old_layout, new_layout)
     }
 
@@ -73,9 +69,6 @@ unsafe impl Allocator for StdAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        ThreadTracker::dealloc(old_layout.size() as i64);
-        ThreadTracker::alloc(new_layout.size() as i64)?;
-
         System.grow_zeroed(ptr, old_layout, new_layout)
     }
 
@@ -86,9 +79,6 @@ unsafe impl Allocator for StdAllocator {
         old_layout: Layout,
         new_layout: Layout,
     ) -> Result<NonNull<[u8]>, AllocError> {
-        ThreadTracker::dealloc(old_layout.size() as i64);
-        ThreadTracker::alloc(new_layout.size() as i64)?;
-
         System.shrink(ptr, old_layout, new_layout)
     }
 }
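
Note: at this point every inline ThreadTracker::alloc/dealloc call has been stripped from jemalloc.rs and std_.rs; accounting moves into the MetaTrackerAllocator wrapper introduced in tracker.rs below. A minimal sketch of that design direction, with a plain atomic counter standing in for the real MemStatBuffer/GlobalStatBuffer machinery (names here are illustrative, not the diff's API):

    #![feature(allocator_api)]

    use std::alloc::{AllocError, Allocator, Layout, System};
    use std::ptr::NonNull;
    use std::sync::atomic::{AtomicI64, Ordering};

    // Stand-in for the MemStat/GlobalStatBuffer machinery in tracker.rs.
    static COUNTER: AtomicI64 = AtomicI64::new(0);

    struct Tracking<A>(A);

    unsafe impl<A: Allocator> Allocator for Tracking<A> {
        fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
            // Account first; roll the counter back if the inner allocator
            // fails, mirroring the rollback pattern used throughout tracker.rs.
            COUNTER.fetch_add(layout.size() as i64, Ordering::Relaxed);
            self.0.allocate(layout).inspect_err(|_| {
                COUNTER.fetch_sub(layout.size() as i64, Ordering::Relaxed);
            })
        }

        unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
            COUNTER.fetch_sub(layout.size() as i64, Ordering::Relaxed);
            self.0.deallocate(ptr, layout)
        }
    }

    fn main() {
        let alloc = Tracking(System);
        let layout = Layout::from_size_align(1024, 8).unwrap();
        let ptr = alloc.allocate(layout).unwrap();
        assert_eq!(COUNTER.load(Ordering::Relaxed), 1024);
        unsafe { alloc.deallocate(ptr.cast(), layout) };
        assert_eq!(COUNTER.load(Ordering::Relaxed), 0);
    }

Centralizing the bookkeeping keeps each inner allocator (jemalloc, std, or a failing test double) accounting-free, which is exactly what the test allocators in tracker.rs exploit.
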
diff --git a/src/common/base/src/mem_allocator/tracker.rs b/src/common/base/src/mem_allocator/tracker.rs
new file mode 100644
index 0000000000000..f77b27a797291
--- /dev/null
+++ b/src/common/base/src/mem_allocator/tracker.rs
@@ -0,0 +1,1383 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::alloc::AllocError;
+use std::alloc::Allocator;
+use std::alloc::Layout;
+use std::mem::ManuallyDrop;
+use std::ptr::slice_from_raw_parts_mut;
+use std::ptr::NonNull;
+use std::sync::Arc;
+
+use crate::runtime::GlobalStatBuffer;
+use crate::runtime::MemStat;
+use crate::runtime::MemStatBuffer;
+use crate::runtime::ThreadTracker;
+
+/// Memory allocation tracker threshold (512 bytes)
+static META_TRACKER_THRESHOLD: usize = 512;
+
+/// An allocator wrapper that tracks memory usage statistics through metadata.
+#[derive(Debug, Clone, Copy)]
+pub struct MetaTrackerAllocator<T: Allocator> {
+    inner: T,
+}
+
+impl<T: Allocator> MetaTrackerAllocator<T> {
+    pub const fn create(inner: T) -> MetaTrackerAllocator<T> {
+        MetaTrackerAllocator { inner }
+    }
+}
+
+impl<T: Allocator + Default> Default for MetaTrackerAllocator<T> {
+    fn default() -> Self {
+        MetaTrackerAllocator {
+            inner: T::default(),
+        }
+    }
+}
+
+impl<T: Allocator> MetaTrackerAllocator<T> {
+    fn metadata_layout() -> Layout {
+        Layout::new::<usize>()
+    }
+
+    fn adjusted_layout(base_layout: Layout) -> Layout {
+        base_layout.extend_packed(Self::metadata_layout()).unwrap()
+    }
+
+    fn with_meta(base: NonNull<[u8]>, layout: Layout, address: usize) -> NonNull<[u8]> {
+        let mut base_ptr = base.as_non_null_ptr();
+
+        unsafe {
+            base_ptr
+                .add(layout.size())
+                .cast::<usize>()
+                .write_unaligned(address);
+
+            NonNull::new_unchecked(slice_from_raw_parts_mut(base_ptr.as_mut(), layout.size()))
+        }
+    }
+
+    fn alloc(&self, stat: &Arc<MemStat>, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+        let adjusted_layout = Self::adjusted_layout(layout);
+        MemStatBuffer::current().alloc(stat, adjusted_layout.size() as i64)?;
+
+        let Ok(allocated_ptr) = self.inner.allocate(adjusted_layout) else {
+            MemStatBuffer::current().dealloc(stat, adjusted_layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        let address = Arc::into_raw(stat.clone()) as usize;
+        Ok(Self::with_meta(allocated_ptr, layout, address))
+    }
+
+    fn alloc_zeroed(
+        &self,
+        stat: &Arc<MemStat>,
+        layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let adjusted_layout = Self::adjusted_layout(layout);
+        MemStatBuffer::current().alloc(stat, adjusted_layout.size() as i64)?;
+
+        let Ok(allocated_ptr) = self.inner.allocate_zeroed(adjusted_layout) else {
+            MemStatBuffer::current().dealloc(stat, adjusted_layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        let address = Arc::into_raw(stat.clone()) as usize;
+        Ok(Self::with_meta(allocated_ptr, layout, address))
+    }
+
+    unsafe fn dealloc(&self, ptr: NonNull<u8>, layout: Layout) -> Option<Layout> {
+        let adjusted_layout = Self::adjusted_layout(layout);
+        let mem_stat_address = ptr.add(layout.size()).cast::<usize>().read_unaligned();
+
+        if mem_stat_address == 0 {
+            return Some(adjusted_layout);
+        }
+
+        let mem_stat = Arc::from_raw(mem_stat_address as *const MemStat);
+        MemStatBuffer::current().dealloc(&mem_stat, adjusted_layout.size() as i64);
+        self.inner.deallocate(ptr, adjusted_layout);
+        None
+    }
+
+    unsafe fn move_grow(
+        &self,
+        ptr: NonNull<u8>,
+        stat: &Arc<MemStat>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let new_adjusted_layout = Self::adjusted_layout(new_layout);
+        MemStatBuffer::current().alloc(stat, new_adjusted_layout.size() as i64)?;
+        GlobalStatBuffer::current().dealloc(old_layout.size() as i64);
+
+        let Ok(grow_ptr) = self.inner.grow(ptr, old_layout, new_adjusted_layout) else {
+            GlobalStatBuffer::current().force_alloc(old_layout.size() as i64);
+            MemStatBuffer::current().dealloc(stat, new_adjusted_layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        let address = Arc::into_raw(stat.clone()) as usize;
+        Ok(Self::with_meta(grow_ptr, new_layout, address))
+    }
+
+    unsafe fn grow_impl(
+        &self,
+        ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let old_adjusted_layout = Self::adjusted_layout(old_layout);
+        let new_adjusted_layout = Self::adjusted_layout(new_layout);
+        let address = ptr.add(old_layout.size()).cast::<usize>().read_unaligned();
+
+        if address == 0 {
+            let diff = new_adjusted_layout.size() - old_adjusted_layout.size();
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+
+            let Ok(grow_ptr) = self
+                .inner
+                .grow(ptr, old_adjusted_layout, new_adjusted_layout)
+            else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            return Ok(Self::with_meta(grow_ptr, new_layout, address));
+        }
+
+        let diff = new_adjusted_layout.size() - old_adjusted_layout.size();
+        let stat = ManuallyDrop::new(Arc::from_raw(address as *const MemStat));
+
+        MemStatBuffer::current().alloc(&stat, diff as i64)?;
+
+        let Ok(grow_ptr) = self
+            .inner
+            .grow(ptr, old_adjusted_layout, new_adjusted_layout)
+        else {
+            MemStatBuffer::current().dealloc(&stat, diff as i64);
+            return Err(AllocError);
+        };
+
+        Ok(Self::with_meta(grow_ptr, new_layout, address))
+    }
+
+    unsafe fn move_grow_zeroed(
+        &self,
+        ptr: NonNull<u8>,
+        stat: &Arc<MemStat>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let new_adjusted_layout = Self::adjusted_layout(new_layout);
+
+        MemStatBuffer::current().alloc(stat, new_adjusted_layout.size() as i64)?;
+        GlobalStatBuffer::current().dealloc(old_layout.size() as i64);
+
+        let Ok(grow_ptr) = self.inner.grow_zeroed(ptr, old_layout, new_adjusted_layout) else {
+            GlobalStatBuffer::current().force_alloc(old_layout.size() as i64);
+            MemStatBuffer::current().dealloc(stat, new_adjusted_layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        let address = Arc::into_raw(stat.clone()) as usize;
+        Ok(Self::with_meta(grow_ptr, new_layout, address))
+    }
+
+    unsafe fn grow_zeroed_impl(
+        &self,
+        ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let old_adjusted_layout = Self::adjusted_layout(old_layout);
+        let new_adjusted_layout = Self::adjusted_layout(new_layout);
+
+        let address = ptr.add(old_layout.size()).cast::<usize>().read_unaligned();
+
+        if address == 0 {
+            let diff = new_adjusted_layout.size() - old_adjusted_layout.size();
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+
+            let Ok(grow_ptr) =
+                self.inner
+                    .grow_zeroed(ptr, old_adjusted_layout, new_adjusted_layout)
+            else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            return Ok(Self::with_meta(grow_ptr, new_layout, 0));
+        }
+
+        let alloc_size = new_adjusted_layout.size() - old_adjusted_layout.size();
+        let stat = ManuallyDrop::new(Arc::from_raw(address as *const MemStat));
+
+        MemStatBuffer::current().alloc(&stat, alloc_size as i64)?;
+
+        let Ok(grow_ptr) = self
+            .inner
+            .grow_zeroed(ptr, old_adjusted_layout, new_adjusted_layout)
+        else {
+            MemStatBuffer::current().dealloc(&stat, alloc_size as i64);
+            return Err(AllocError);
+        };
+
+        grow_ptr
+            .as_non_null_ptr()
+            .add(old_layout.size())
+            .cast::<usize>()
+            .write_unaligned(0);
+
+        Ok(Self::with_meta(grow_ptr, new_layout, address))
+    }
+
+    unsafe fn move_shrink(
+        &self,
+        ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let old_adjusted_layout = Self::adjusted_layout(old_layout);
+
+        let mem_stat_address = ptr.add(old_layout.size()).cast::<usize>().read_unaligned();
+
+        if mem_stat_address == 0 {
+            let Ok(reduced_ptr) = self.inner.shrink(ptr, old_adjusted_layout, new_layout) else {
+                return Err(AllocError);
+            };
+
+            let diff = old_adjusted_layout.size() - new_layout.size();
+            GlobalStatBuffer::current().dealloc(diff as i64);
+            return Ok(reduced_ptr);
+        }
+
+        let Ok(reduced_ptr) = self.inner.shrink(ptr, old_adjusted_layout, new_layout) else {
+            return Err(AllocError);
+        };
+
+        let mem_stat = Arc::from_raw(mem_stat_address as *const MemStat);
+        MemStatBuffer::current().dealloc(&mem_stat, old_adjusted_layout.size() as i64);
+        GlobalStatBuffer::current().force_alloc(new_layout.size() as i64);
+        Ok(reduced_ptr)
+    }
+
+    unsafe fn shrink_impl(
+        &self,
+        ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        let old_adjusted_layout = Self::adjusted_layout(old_layout);
+        let new_adjusted_layout = Self::adjusted_layout(new_layout);
+
+        let address = ptr.add(old_layout.size()).cast::<usize>().read_unaligned();
+
+        if address == 0 {
+            let diff = old_adjusted_layout.size() - new_adjusted_layout.size();
+            GlobalStatBuffer::current().dealloc(diff as i64);
+            let Ok(ptr) = self
+                .inner
+                .shrink(ptr, old_adjusted_layout, new_adjusted_layout)
+            else {
+                GlobalStatBuffer::current().force_alloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            return Ok(Self::with_meta(ptr, new_layout, address));
+        }
+
+        let alloc_size = old_adjusted_layout.size() - new_adjusted_layout.size();
+        let stat = ManuallyDrop::new(Arc::from_raw(address as *const MemStat));
+        MemStatBuffer::current().dealloc(&stat, alloc_size as i64);
+
+        let Ok(ptr) = self
+            .inner
+            .shrink(ptr, old_adjusted_layout, new_adjusted_layout)
+        else {
+            MemStatBuffer::current().force_alloc(&stat, alloc_size as i64);
+            return Err(AllocError);
+        };
+
+        Ok(Self::with_meta(ptr, new_layout, address))
+    }
+}
+
+unsafe impl<T: Allocator> Allocator for MetaTrackerAllocator<T> {
+    #[inline(always)]
+    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+        if layout.size() >= META_TRACKER_THRESHOLD {
+            if let Some(mem_stat) = ThreadTracker::mem_stat() {
+                return self.alloc(mem_stat, layout);
+            }
+
+            let adjusted_layout = Self::adjusted_layout(layout);
+
+            GlobalStatBuffer::current().alloc(adjusted_layout.size() as i64)?;
+            let Ok(allocated_ptr) = self.inner.allocate(adjusted_layout) else {
+                GlobalStatBuffer::current().dealloc(adjusted_layout.size() as i64);
+                return Err(AllocError);
+            };
+
+            return Ok(Self::with_meta(allocated_ptr, layout, 0));
+        }
+
+        GlobalStatBuffer::current().alloc(layout.size() as i64)?;
+        let Ok(allocated_ptr) = self.inner.allocate(layout) else {
+            GlobalStatBuffer::current().dealloc(layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        Ok(allocated_ptr)
+    }
+
+    #[inline(always)]
+    fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+        if layout.size() >= META_TRACKER_THRESHOLD {
+            if let Some(mem_stat) = ThreadTracker::mem_stat() {
+                return self.alloc_zeroed(mem_stat, layout);
+            }
+
+            let adjusted_layout = Self::adjusted_layout(layout);
+            GlobalStatBuffer::current().alloc(adjusted_layout.size() as i64)?;
+
+            let Ok(allocated_ptr) = self.inner.allocate_zeroed(adjusted_layout) else {
+                GlobalStatBuffer::current().dealloc(adjusted_layout.size() as i64);
+                return Err(AllocError);
+            };
+
+            return Ok(Self::with_meta(allocated_ptr, layout, 0));
+        }
+
+        GlobalStatBuffer::current().alloc(layout.size() as i64)?;
+        let Ok(allocated_ptr) = self.inner.allocate_zeroed(layout) else {
+            GlobalStatBuffer::current().dealloc(layout.size() as i64);
+            return Err(AllocError);
+        };
+
+        Ok(allocated_ptr)
+    }
+
+    #[inline(always)]
+    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+        if layout.size() >= META_TRACKER_THRESHOLD {
+            if let Some(layout) = self.dealloc(ptr, layout) {
+                GlobalStatBuffer::current().dealloc(layout.size() as i64);
+                self.inner.deallocate(ptr, layout);
+            }
+
+            return;
+        }
+
+        GlobalStatBuffer::current().dealloc(layout.size() as i64);
+        self.inner.deallocate(ptr, layout);
+    }
+
+    #[inline(always)]
+    unsafe fn grow(
+        &self,
+        mut ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        if old_layout.size() == new_layout.size() {
+            return Ok(NonNull::new_unchecked(slice_from_raw_parts_mut(
+                ptr.as_mut(),
+                new_layout.size(),
+            )));
+        }
+
+        if old_layout.size() >= META_TRACKER_THRESHOLD
+            && new_layout.size() >= META_TRACKER_THRESHOLD
+        {
+            self.grow_impl(ptr, old_layout, new_layout)
+        } else if old_layout.size() < META_TRACKER_THRESHOLD
+            && new_layout.size() < META_TRACKER_THRESHOLD
+        {
+            let diff = new_layout.size() - old_layout.size();
+
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+            let Ok(grow_ptr) = self.inner.grow(ptr, old_layout, new_layout) else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            Ok(grow_ptr)
+        } else {
+            if let Some(mem_stat) = ThreadTracker::mem_stat() {
+                return self.move_grow(ptr, mem_stat, old_layout, new_layout);
+            }
+
+            let new_adjusted_layout = Self::adjusted_layout(new_layout);
+
+            let diff = new_adjusted_layout.size() - old_layout.size();
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+
+            let Ok(grow_ptr) = self.inner.grow(ptr, old_layout, new_adjusted_layout) else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            Ok(Self::with_meta(grow_ptr, new_layout, 0))
+        }
+    }
+
+    #[inline(always)]
+    unsafe fn grow_zeroed(
+        &self,
+        mut ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        if old_layout.size() == new_layout.size() {
+            return Ok(NonNull::new_unchecked(slice_from_raw_parts_mut(
+                ptr.as_mut(),
+                new_layout.size(),
+            )));
+        }
+
+        if old_layout.size() >= META_TRACKER_THRESHOLD
+            && new_layout.size() >= META_TRACKER_THRESHOLD
+        {
+            self.grow_zeroed_impl(ptr, old_layout, new_layout)
+        } else if old_layout.size() < META_TRACKER_THRESHOLD
+            && new_layout.size() < META_TRACKER_THRESHOLD
+        {
+            let diff = new_layout.size() - old_layout.size();
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+            let Ok(grow_ptr) = self.inner.grow_zeroed(ptr, old_layout, new_layout) else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            Ok(grow_ptr)
+        } else {
+            if let Some(mem_stat) = ThreadTracker::mem_stat() {
+                return self.move_grow_zeroed(ptr, mem_stat, old_layout, new_layout);
+            }
+
+            let new_adjusted_layout = Self::adjusted_layout(new_layout);
+
+            let diff = new_adjusted_layout.size() - old_layout.size();
+            GlobalStatBuffer::current().alloc(diff as i64)?;
+            let Ok(grow_ptr) = self.inner.grow_zeroed(ptr, old_layout, new_adjusted_layout) else {
+                GlobalStatBuffer::current().dealloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            Ok(NonNull::new_unchecked(slice_from_raw_parts_mut(
+                grow_ptr.as_non_null_ptr().as_mut(),
+                new_layout.size(),
+            )))
+        }
+    }
+
+    #[inline(always)]
+    unsafe fn shrink(
+        &self,
+        mut ptr: NonNull<u8>,
+        old_layout: Layout,
+        new_layout: Layout,
+    ) -> Result<NonNull<[u8]>, AllocError> {
+        if old_layout.size() == new_layout.size() {
+            return Ok(NonNull::new_unchecked(slice_from_raw_parts_mut(
+                ptr.as_mut(),
+                new_layout.size(),
+            )));
+        }
+
+        if old_layout.size() >= META_TRACKER_THRESHOLD
+            && new_layout.size() >= META_TRACKER_THRESHOLD
+        {
+            self.shrink_impl(ptr, old_layout, new_layout)
+        } else if old_layout.size() < META_TRACKER_THRESHOLD
+            && new_layout.size() < META_TRACKER_THRESHOLD
+        {
+            let diff = old_layout.size() - new_layout.size();
+            GlobalStatBuffer::current().dealloc(diff as i64);
+            let Ok(shrink_ptr) = self.inner.shrink(ptr, old_layout, new_layout) else {
+                GlobalStatBuffer::current().force_alloc(diff as i64);
+                return Err(AllocError);
+            };
+
+            Ok(shrink_ptr)
+        } else {
+            self.move_shrink(ptr, old_layout, new_layout)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::alloc::AllocError;
+    use std::alloc::Allocator;
+    use std::alloc::Layout;
+    use std::ptr::NonNull;
+    use std::sync::atomic::Ordering;
+    use std::sync::Arc;
+
+    use crate::base::GlobalUniqName;
+    use crate::mem_allocator::tracker::MetaTrackerAllocator;
+    use crate::mem_allocator::tracker::META_TRACKER_THRESHOLD;
+    use crate::mem_allocator::DefaultAllocator;
+    use crate::runtime::GlobalStatBuffer;
+    use crate::runtime::MemStat;
+    use crate::runtime::MemStatBuffer;
+    use crate::runtime::Thread;
+    use crate::runtime::ThreadTracker;
+
+    fn with_mock_env<
+        T: Allocator + Send + Sync + 'static,
+        R,
+        F: Fn(Arc<MemStat>, Arc<dyn Allocator + Send + Sync + 'static>) -> R,
+    >(
+        test: F,
+        nest: T,
+        global: Arc<MemStat>,
+    ) -> R {
+        {
+            let mem_stat = MemStat::create(GlobalUniqName::unique());
+            let mut tracking_payload = ThreadTracker::new_tracking_payload();
+            tracking_payload.mem_stat = Some(mem_stat.clone());
+
+            let allocator = Arc::new(MetaTrackerAllocator::create(nest));
+            let _guard = ThreadTracker::tracking(tracking_payload);
+            let _mem_stat_guard = MemStatBuffer::mock(global.clone());
+            let _global_stat_guard = GlobalStatBuffer::mock(global.clone());
+            test(mem_stat, allocator)
+        }
+    }
+
+    #[test]
+    fn test_small_allocation() {
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layout = Layout::from_size_align(256, 8).unwrap();
+
+                let ptr = allocator.allocate(layout).unwrap();
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 256);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), layout) };
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_large_allocation_with_mem_stat() {
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layout = Layout::from_size_align(512, 8).unwrap();
+                let meta_size = std::mem::size_of::<usize>();
+
+                let ptr = allocator.allocate(layout).unwrap();
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(
+                    MemStatBuffer::current().memory_usage,
+                    (512 + meta_size) as i64
+                );
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), layout) };
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_cross_threshold_grow() {
+        let test_function =
+            |_mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let old_layout = Layout::from_size_align(256, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 256);
+
+                let new_layout = Layout::from_size_align(768, 8).unwrap();
+                let new_ptr =
+                    unsafe { allocator.grow(ptr.as_non_null_ptr(), old_layout, new_layout) }
+                        .unwrap();
+
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+                assert_eq!(
+                    MemStatBuffer::current().memory_usage,
+                    768 + std::mem::size_of::<usize>() as i64
+                );
+
+                unsafe { allocator.deallocate(new_ptr.as_non_null_ptr(), new_layout) };
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_allocation_rollback() {
+        struct FailingAllocator;
+        unsafe impl Allocator for FailingAllocator {
+            fn allocate(&self, _: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                Err(AllocError)
+            }
+            unsafe fn deallocate(&self, _: NonNull<u8>, _: Layout) {}
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layout = Layout::from_size_align(1024, 8).unwrap();
+
+                let result = allocator.allocate(layout);
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(test_function, FailingAllocator, Arc::new(MemStat::global()));
+    }
+
+    #[test]
+    fn test_shrink_memory() {
+        let test_function =
+            |_mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let old_layout = Layout::from_size_align(1024, 8).unwrap();
+                let new_layout = Layout::from_size_align(256, 8).unwrap();
+
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = MemStatBuffer::current().memory_usage;
+
+                let new_ptr =
+                    unsafe { allocator.shrink(ptr.as_non_null_ptr(), old_layout, new_layout) }
+                        .unwrap();
+
+                let expected_dealloc = (old_layout.size() + std::mem::size_of::<usize>()) as i64;
+                assert_eq!(
+                    MemStatBuffer::current().memory_usage,
+                    initial_usage - expected_dealloc
+                );
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 256);
+
+                unsafe { allocator.deallocate(new_ptr.as_non_null_ptr(), new_layout) };
+            };
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_extreme_alignment() {
+        let test_function =
+            |_mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layout = Layout::from_size_align(512, 64).unwrap();
+                let ptr = allocator.allocate(layout).unwrap();
+
+                let addr = ptr.as_non_null_ptr().as_ptr() as usize;
+                assert_eq!(addr % 64, 0);
+
+                unsafe {
+                    let meta_ptr = ptr.as_non_null_ptr().as_ptr().add(layout.size());
+                    let stat_addr = meta_ptr.cast::<usize>().read_unaligned();
+                    assert_ne!(stat_addr, 0);
+                }
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), layout) };
+            };
+
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_allocate_zeroed_failure() {
+        struct FailingAllocator;
+        unsafe impl Allocator for FailingAllocator {
+            fn allocate(&self, _layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                unreachable!()
+            }
+
+            fn allocate_zeroed(&self, _: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                Err(AllocError)
+            }
+
+            unsafe fn deallocate(&self, _ptr: NonNull<u8>, _layout: Layout) {
+                unreachable!()
+            }
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layout = Layout::from_size_align(1024, 8).unwrap();
+
+                let result = allocator.allocate_zeroed(layout);
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+
+                let small_layout = Layout::from_size_align(256, 8).unwrap();
+                let result = allocator.allocate_zeroed(small_layout);
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(test_function, FailingAllocator, Arc::new(MemStat::global()));
+    }
+
+    #[test]
+    fn test_grow_failure_rollback() {
+        struct PartialFailingAllocator<T: Allocator>(T);
+        unsafe impl<T: Allocator> Allocator for PartialFailingAllocator<T> {
+            fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                self.0.allocate(layout)
+            }
+
+            unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+                self.0.deallocate(ptr, layout)
+            }
+
+            unsafe fn grow(
+                &self,
+                _ptr: NonNull<u8>,
+                _old_layout: Layout,
+                _new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                Err(AllocError)
+            }
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let old_layout = Layout::from_size_align(256, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_global = GlobalStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(498, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, initial_global);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(512, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = MemStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(1024, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, initial_usage);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(256, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_global = GlobalStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(768, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, initial_global);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+            };
+
+        with_mock_env(
+            test_function,
+            PartialFailingAllocator(DefaultAllocator::default()),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_grow_zeroed_failure() {
+        struct GrowZeroedFailingAllocator<T: Allocator>(T);
+
+        unsafe impl<T: Allocator> Allocator for GrowZeroedFailingAllocator<T> {
+            fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                self.0.allocate(layout)
+            }
+
+            unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+                self.0.deallocate(ptr, layout)
+            }
+
+            unsafe fn grow_zeroed(
+                &self,
+                _ptr: NonNull<u8>,
+                _old_layout: Layout,
+                _new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                Err(AllocError)
+            }
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let old_layout = Layout::from_size_align(256, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_global = GlobalStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(498, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow_zeroed(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, initial_global);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(512, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = MemStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(1024, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow_zeroed(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, initial_usage);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(256, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_global = GlobalStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(768, 8).unwrap();
+                let result =
+                    unsafe { allocator.grow_zeroed(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, initial_global);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+            };
+
+        with_mock_env(
+            test_function,
+            GrowZeroedFailingAllocator(DefaultAllocator::default()),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_shrink_failure() {
+        struct ShrinkFailingAllocator<T: Allocator>(T);
+        unsafe impl<T: Allocator> Allocator for ShrinkFailingAllocator<T> {
+            fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                self.0.allocate(layout)
+            }
+
+            unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+                self.0.deallocate(ptr, layout)
+            }
+
+            unsafe fn shrink(
+                &self,
+                _ptr: NonNull<u8>,
+                _old_layout: Layout,
+                _new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                Err(AllocError)
+            }
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let old_layout = Layout::from_size_align(498, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = GlobalStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(256, 8).unwrap();
+                let result =
+                    unsafe { allocator.shrink(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, initial_usage);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(1024, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = MemStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(512, 8).unwrap();
+                let result =
+                    unsafe { allocator.shrink(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, initial_usage);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+
+                let old_layout = Layout::from_size_align(512, 8).unwrap();
+                let ptr = allocator.allocate(old_layout).unwrap();
+                let initial_usage = MemStatBuffer::current().memory_usage;
+
+                let new_layout = Layout::from_size_align(256, 8).unwrap();
+                let result =
+                    unsafe { allocator.shrink(ptr.as_non_null_ptr(), old_layout, new_layout) };
+
+                assert!(result.is_err());
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, initial_usage);
+
+                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), old_layout) };
+            };
+
+        with_mock_env(
+            test_function,
+            ShrinkFailingAllocator(DefaultAllocator::default()),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_mixed_failure_scenarios() {
+        struct ChaosAllocator<T: Allocator> {
+            inner: T,
+            failure_rate: f64,
+        }
+
+        unsafe impl<T: Allocator> Allocator for ChaosAllocator<T> {
+            fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                if rand::random::<f64>() < self.failure_rate {
+                    Err(AllocError)
+                } else {
+                    self.inner.allocate(layout)
+                }
+            }
+
+            fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
+                if rand::random::<f64>() < self.failure_rate {
+                    Err(AllocError)
+                } else {
+                    self.inner.allocate_zeroed(layout)
+                }
+            }
+
+            unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
+                self.inner.deallocate(ptr, layout)
+            }
+
+            unsafe fn grow(
+                &self,
+                ptr: NonNull<u8>,
+                old_layout: Layout,
+                new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                if rand::random::<f64>() < self.failure_rate {
+                    Err(AllocError)
+                } else {
+                    self.inner.grow(ptr, old_layout, new_layout)
+                }
+            }
+
+            unsafe fn grow_zeroed(
+                &self,
+                ptr: NonNull<u8>,
+                old_layout: Layout,
+                new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                if rand::random::<f64>() < self.failure_rate {
+                    Err(AllocError)
+                } else {
+                    self.inner.grow_zeroed(ptr, old_layout, new_layout)
+                }
+            }
+
+            unsafe fn shrink(
+                &self,
+                ptr: NonNull<u8>,
+                old_layout: Layout,
+                new_layout: Layout,
+            ) -> Result<NonNull<[u8]>, AllocError> {
+                if rand::random::<f64>() < self.failure_rate {
+                    Err(AllocError)
+                } else {
+                    self.inner.shrink(ptr, old_layout, new_layout)
+                }
+            }
+        }
+
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let mut allocations = vec![];
+
+                for _ in 0..1000000 {
+                    match rand::random::<usize>() % 6 {
+                        // allocate
+                        0 => {
+                            let size = rand::random::<usize>() % 2048 + 1;
+                            let layout = Layout::from_size_align(size, 8).unwrap();
+
+                            if let Ok(ptr) = allocator.allocate(layout) {
+                                allocations.push((ptr, layout));
+                            }
+                        }
+                        // allocate_zero
+                        1 => {
+                            let size = rand::random::<usize>() % 2048 + 1;
+                            let layout = Layout::from_size_align(size, 8).unwrap();
+
+                            if let Ok(ptr) = allocator.allocate_zeroed(layout) {
+                                allocations.push((ptr, layout));
+                            }
+                        }
+                        // deallocate
+                        2 => {
+                            if !allocations.is_empty() {
+                                let index = rand::random::<usize>() % allocations.len();
+                                let (ptr, layout) = allocations.remove(index);
+                                unsafe { allocator.deallocate(ptr.as_non_null_ptr(), layout) };
+                            }
+                        }
+                        // grow
+                        3 => {
+                            if !allocations.is_empty() {
+                                let index = rand::random::<usize>() % allocations.len();
+                                let (ptr, old_layout) = allocations[index];
+                                let new_size =
+                                    old_layout.size() + rand::random::<usize>() % 256 + 1;
+                                let new_layout = Layout::from_size_align(new_size, 8).unwrap();
+
+                                if let Ok(new_ptr) = unsafe {
+                                    allocator.grow(ptr.as_non_null_ptr(), old_layout, new_layout)
+                                } {
+                                    allocations[index] = (new_ptr, new_layout);
+                                }
+                            }
+                        }
+                        // grow_zero
+                        4 => {
+                            if !allocations.is_empty() {
+                                let index = rand::random::<usize>() % allocations.len();
+                                let (ptr, old_layout) = allocations[index];
+                                let new_size =
+                                    old_layout.size() + rand::random::<usize>() % 256 + 1;
+                                let new_layout = Layout::from_size_align(new_size, 8).unwrap();
+
+                                if let Ok(new_ptr) = unsafe {
+                                    allocator.grow_zeroed(
+                                        ptr.as_non_null_ptr(),
+                                        old_layout,
+                                        new_layout,
+                                    )
+                                } {
+                                    allocations[index] = (new_ptr, new_layout);
+                                }
+                            }
+                        }
+                        // shrink
+                        _ => {
+                            if !allocations.is_empty() {
+                                let index = rand::random::<usize>() % allocations.len();
+                                let (ptr, old_layout) = allocations[index];
+                                let new_size = old_layout
+                                    .size()
+                                    .saturating_sub(rand::random::<usize>() % 256);
+                                let new_layout =
+                                    Layout::from_size_align(std::cmp::max(1, new_size), 8).unwrap();
+
+                                if let Ok(new_ptr) = unsafe {
+                                    allocator.shrink(ptr.as_non_null_ptr(), old_layout, new_layout)
+                                } {
+                                    allocations[index] = (new_ptr, new_layout);
+                                }
+                            }
+                        }
+                    }
+                }
+
+                let actual_usage =
+                    mem_stat.used.load(Ordering::Relaxed) + MemStatBuffer::current().memory_usage;
+
+                let mem_stat_expected_usage = allocations
+                    .iter()
+                    .filter(|(_, r)| r.size() >= META_TRACKER_THRESHOLD)
+                    .map(|(_, r)| r.size() as i64 + std::mem::size_of::<usize>() as i64)
+                    .sum::<i64>();
+
+                assert_eq!(actual_usage, mem_stat_expected_usage);
+
+                let small_usage = allocations
+                    .iter()
+                    .filter(|(_, r)| r.size() < META_TRACKER_THRESHOLD)
+                    .map(|(_, r)| r.size() as i64)
+                    .sum::<i64>();
+
+                assert_eq!(
+                    small_usage + actual_usage,
+                    MemStatBuffer::current().memory_usage
+                        + GlobalStatBuffer::current().memory_usage
+                        + GlobalStatBuffer::current()
+                            .global_mem_stat
+                            .used
+                            .load(Ordering::Relaxed)
+                );
+
+                for (ptr, layout) in allocations {
+                    unsafe { allocator.deallocate(ptr.as_non_null_ptr(), layout) };
+                }
+
+                MemStatBuffer::current().flush::<false>(0).unwrap();
+                GlobalStatBuffer::current().flush::<false>(0).unwrap();
+
+                assert_eq!(0, mem_stat.used.load(Ordering::Relaxed));
+                assert_eq!(0, MemStatBuffer::current().memory_usage);
+                assert_eq!(0, GlobalStatBuffer::current().memory_usage);
+                assert_eq!(
+                    0,
+                    GlobalStatBuffer::current()
+                        .global_mem_stat
+                        .used
+                        .load(Ordering::Relaxed)
+                );
+            };
+
+        with_mock_env(
+            test_function,
+            ChaosAllocator {
+                inner: DefaultAllocator::default(),
+                failure_rate: 0.3,
+            },
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
+    fn test_out_of_order_deallocation() {
+        let test_function =
+            |mem_stat: Arc<MemStat>, allocator: Arc<dyn Allocator + Send + Sync + 'static>| {
+                let layouts = [
+                    Layout::from_size_align(256, 8).unwrap(),
+                    Layout::from_size_align(1024, 8).unwrap(),
+                    Layout::from_size_align(384, 8).unwrap(),
+                ];
+
+                let pointers: Vec<_> = layouts
+                    .iter()
+                    .map(|&l| allocator.allocate(l).unwrap())
+                    .collect();
+
+                unsafe {
+                    allocator.deallocate(pointers[1].as_non_null_ptr(), layouts[1]);
+                    allocator.deallocate(pointers[0].as_non_null_ptr(), layouts[0]);
+                }
+
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 384);
+
+                unsafe { allocator.deallocate(pointers[2].as_non_null_ptr(), layouts[2]) };
+                assert_eq!(MemStatBuffer::current().memory_usage, 0);
+                assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0);
+                assert_eq!(GlobalStatBuffer::current().memory_usage, 0);
+            };
+
+        with_mock_env(
+            test_function,
+            DefaultAllocator::default(),
+            Arc::new(MemStat::global()),
+        );
+    }
+
+    #[test]
test_dynamic_memstat_switch() { + let global = Arc::new(MemStat::global()); + let test_function = { + let global = global.clone(); + move |mem_stat: Arc, _allocator: Arc| { + let layout = Layout::from_size_align(512, 8).unwrap(); + let test_function = { + let mem_stat = mem_stat.clone(); + + move |new_mem_stat: Arc, allocator: Arc| { + let layout = Layout::from_size_align(512, 8).unwrap(); + let ptr = allocator.allocate(layout).unwrap(); + + assert_eq!(GlobalStatBuffer::current().memory_usage, 0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(new_mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!( + MemStatBuffer::current().memory_usage, + (512 + std::mem::size_of::()) as i64 + ); + + (ptr, new_mem_stat, allocator) + } + }; + + let (ptr, new_mem_stat, allocator) = + with_mock_env(test_function, DefaultAllocator::default(), global.clone()); + + assert_eq!(GlobalStatBuffer::current().memory_usage, 0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!( + new_mem_stat.used.load(Ordering::Relaxed), + (512 + std::mem::size_of::()) as i64 + ); + assert_eq!(MemStatBuffer::current().memory_usage, 0); + + unsafe { + allocator.deallocate(ptr.as_non_null_ptr(), layout); + } + + assert_eq!(GlobalStatBuffer::current().memory_usage, 0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!( + new_mem_stat.used.load(Ordering::Relaxed), + (512 + std::mem::size_of::()) as i64 + ); + assert_eq!( + MemStatBuffer::current().memory_usage, + -((512 + std::mem::size_of::()) as i64) + ); + + let _ = MemStatBuffer::current().flush::(0); + + assert_eq!(MemStatBuffer::current().memory_usage, 0); + assert_eq!(GlobalStatBuffer::current().memory_usage, 0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(new_mem_stat.used.load(Ordering::Relaxed), 0); + } + }; + + with_mock_env(test_function, DefaultAllocator::default(), global); + } + + #[test] + fn test_thread_local_stat_isolation() { + let global_mem_stat = Arc::new(MemStat::global()); + + let test_function = { + let global_mem_stat = global_mem_stat.clone(); + + move |_mem_stat: Arc, + _allocator: Arc| { + const THREADS: usize = 8; + let mut handles = vec![]; + + for i in 0..THREADS { + let global_stat = global_mem_stat.clone(); + handles.push(Thread::spawn(move || { + let test_function = |mem_stat: Arc, + allocator: Arc< + dyn Allocator + Send + Sync + 'static, + >| { + let layout = Layout::from_size_align(512 * (i + 1), 8).unwrap(); + let ptr = allocator.allocate(layout).unwrap(); + + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!( + MemStatBuffer::current().memory_usage, + (512 * (i + 1) + std::mem::size_of::()) as i64 + ); + ( + ptr.as_non_null_ptr().as_ptr() as usize, + mem_stat, + layout, + allocator, + ) + }; + + with_mock_env( + test_function, + DefaultAllocator::default(), + global_stat.clone(), + ) + })); + } + + let mut memory_usage = 0; + let mut handle_res = vec![]; + for (idx, handle) in handles.into_iter().enumerate() { + let (ptr, mem_stat, layout, allocator) = handle.join().expect("Thread panic"); + + memory_usage += layout.size() as i64 + std::mem::size_of::() as i64; + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + (512 * (idx + 1) + std::mem::size_of::()) as i64 + ); + handle_res.push((ptr, mem_stat, layout, allocator)); + } + + assert_eq!(global_mem_stat.used.load(Ordering::Relaxed), memory_usage); + for (ptr, mem_stat, layout, allocator) in handle_res.into_iter() { + unsafe { allocator.deallocate(NonNull::new_unchecked(ptr as *mut u8), 
layout) }; + let _ = MemStatBuffer::current().flush::(0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(MemStatBuffer::current().memory_usage, 0); + } + + let _ = GlobalStatBuffer::current().flush::(0); + assert_eq!(global_mem_stat.used.load(Ordering::Relaxed), 0); + } + }; + + with_mock_env(test_function, DefaultAllocator::default(), global_mem_stat); + } +} diff --git a/src/common/base/src/runtime/memory/mem_stat.rs b/src/common/base/src/runtime/memory/mem_stat.rs index 62a3e6689151b..d35223730e384 100644 --- a/src/common/base/src/runtime/memory/mem_stat.rs +++ b/src/common/base/src/runtime/memory/mem_stat.rs @@ -19,7 +19,8 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use bytesize::ByteSize; -use log::info; + +use crate::base::GlobalSequence; /// The program mem stat /// @@ -34,11 +35,12 @@ const MINIMUM_MEMORY_LIMIT: i64 = 256 * 1024 * 1024; /// - Every stat that is fed to a child is also fed to its parent. /// - A MemStat has at most one parent. pub struct MemStat { + pub id: usize, + #[allow(unused)] name: Option, pub(crate) used: AtomicI64, - - pub(crate) peak_used: AtomicI64, + pub(crate) peek_used: AtomicI64, /// The limit of max used memory for this tracker. /// @@ -51,10 +53,11 @@ pub struct MemStat { impl MemStat { pub const fn global() -> Self { Self { + id: 0, name: None, used: AtomicI64::new(0), + peek_used: AtomicI64::new(0), limit: AtomicI64::new(0), - peak_used: AtomicI64::new(0), parent_memory_stat: vec![], } } @@ -64,11 +67,17 @@ impl MemStat { } pub fn create_child(name: String, parent_memory_stat: Vec>) -> Arc { + let id = match GlobalSequence::next() { + 0 => GlobalSequence::next(), + id => id, + }; + Arc::new(MemStat { + id, name: Some(name), used: AtomicI64::new(0), + peek_used: AtomicI64::new(0), limit: AtomicI64::new(0), - peak_used: AtomicI64::new(0), parent_memory_stat, }) } @@ -98,7 +107,7 @@ impl MemStat { let mut used = self.used.fetch_add(batch_memory_used, Ordering::Relaxed); used += batch_memory_used; - let old_peak_used = self.peak_used.fetch_max(used, Ordering::Relaxed); + self.peek_used.fetch_max(used, Ordering::Relaxed); for (idx, parent_memory_stat) in self.parent_memory_stat.iter().enumerate() { if let Err(cause) = parent_memory_stat @@ -108,11 +117,6 @@ impl MemStat { // We only roll back the memory that alloc failed self.used.fetch_sub(current_memory_alloc, Ordering::Relaxed); - if used > old_peak_used { - self.peak_used - .fetch_sub(current_memory_alloc, Ordering::Relaxed); - } - for index in 0..idx { self.parent_memory_stat[index].rollback(current_memory_alloc); } @@ -124,11 +128,6 @@ impl MemStat { if let Err(cause) = self.check_limit(used) { if NEED_ROLLBACK { - if used > old_peak_used { - self.peak_used - .fetch_sub(current_memory_alloc, Ordering::Relaxed); - } - // NOTE: we cannot rollback peak_used of parent mem stat in this case // self.peak_used.store(peak_used, Ordering::Relaxed); self.rollback(current_memory_alloc); @@ -148,15 +147,6 @@ impl MemStat { } } - pub fn movein_memory(&self, size: i64) { - let used = self.used.fetch_add(size, Ordering::Relaxed); - self.peak_used.fetch_max(used + size, Ordering::Relaxed); - } - - pub fn moveout_memory(&self, size: i64) { - self.used.fetch_sub(size, Ordering::Relaxed); - } - /// Check if used memory is out of the limit. 
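For reference while reviewing the `record_memory` hunks above: the whole accounting step is a `fetch_add` for `used`, a monotonic `fetch_max` for `peek_used`, and a compensating `fetch_sub` if the limit check fails. The sketch below is a single-level simplification (the real method is generic over `NEED_ROLLBACK` and also feeds every parent stat); names mirror `MemStat` but are illustrative.

```rust
use std::sync::atomic::{AtomicI64, Ordering};

#[derive(Debug)]
struct OutOfLimit { value: i64, limit: i64 }

struct Counter {
    used: AtomicI64,
    peek_used: AtomicI64,
    limit: AtomicI64, // 0 means "unlimited"
}

impl Counter {
    fn record(&self, batch: i64) -> Result<(), OutOfLimit> {
        // fetch_add returns the previous value, so add the batch back in.
        let used = self.used.fetch_add(batch, Ordering::Relaxed) + batch;
        // Peak only ever moves up; fetch_max keeps it monotonic under races.
        self.peek_used.fetch_max(used, Ordering::Relaxed);

        let limit = self.limit.load(Ordering::Relaxed);
        if limit > 0 && used > limit {
            // Roll back so a failed allocation leaves `used` unchanged.
            // Note: `peek_used` is deliberately not rolled back, matching
            // the simplification this diff makes.
            self.used.fetch_sub(batch, Ordering::Relaxed);
            return Err(OutOfLimit { value: used, limit });
        }
        Ok(())
    }
}
```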
#[inline] fn check_limit(&self, used: i64) -> Result<(), OutOfLimit> { @@ -180,33 +170,8 @@ impl MemStat { } #[inline] - #[allow(unused)] - pub fn get_peak_memory_usage(&self) -> i64 { - self.peak_used.load(Ordering::Relaxed) - } - - #[allow(unused)] - pub fn log_memory_usage(&self) { - let name = self.name.clone().unwrap_or_else(|| String::from("global")); - let memory_usage = self.used.load(Ordering::Relaxed); - let memory_usage = std::cmp::max(0, memory_usage) as u64; - info!( - "Current memory usage({}): {}.", - name, - ByteSize::b(memory_usage) - ); - } - - #[allow(unused)] - pub fn log_peek_memory_usage(&self) { - let name = self.name.clone().unwrap_or_else(|| String::from("global")); - let peak_memory_usage = self.peak_used.load(Ordering::Relaxed); - let peak_memory_usage = std::cmp::max(0, peak_memory_usage) as u64; - info!( - "Peak memory usage({}): {}.", - name, - ByteSize::b(peak_memory_usage) - ); + pub fn get_peek_memory_usage(&self) -> i64 { + self.peek_used.load(Ordering::Relaxed) } } @@ -254,7 +219,6 @@ mod tests { mem_stat.record_memory::(-1, -1).unwrap(); assert_eq!(mem_stat.used.load(Ordering::Relaxed), 2); - assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 3); Ok(()) } @@ -272,50 +236,30 @@ mod tests { mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT - ); assert!(mem_stat.record_memory::(1, 1).is_err()); assert_eq!( mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT + 1 ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT + 1 - ); assert!(mem_stat.record_memory::(1, 1).is_err()); assert_eq!( mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT + 1 ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT + 1 - ); assert!(mem_stat.record_memory::(-1, -1).is_err()); assert_eq!( mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT + 1 ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT + 1 - ); assert!(mem_stat.record_memory::(-1, -1).is_err()); assert_eq!( mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT + 1 - ); Ok(()) } @@ -331,18 +275,14 @@ mod tests { mem_stat.record_memory::(-1, -1).unwrap(); assert_eq!(mem_stat.used.load(Ordering::Relaxed), 2); - assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 3); assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); - assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); child_mem_stat.record_memory::(1, 1).unwrap(); child_mem_stat.record_memory::(2, 2).unwrap(); child_mem_stat.record_memory::(-1, -1).unwrap(); assert_eq!(mem_stat.used.load(Ordering::Relaxed), 4); - assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 5); assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 2); - assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 3); Ok(()) } @@ -363,12 +303,7 @@ mod tests { mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT - ); assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); - assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); child_mem_stat.record_memory::(1, 1).unwrap(); assert!(child_mem_stat @@ -378,18 +313,10 @@ mod tests { mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT + 1 + 
MINIMUM_MEMORY_LIMIT ); - assert_eq!( - mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT + 1 + MINIMUM_MEMORY_LIMIT - ); assert_eq!( child_mem_stat.used.load(Ordering::Relaxed), 1 + MINIMUM_MEMORY_LIMIT ); - assert_eq!( - child_mem_stat.peak_used.load(Ordering::Relaxed), - 1 + MINIMUM_MEMORY_LIMIT - ); // parent failure let mem_stat = MemStat::create("TEST".to_string()); @@ -402,9 +329,7 @@ mod tests { .record_memory::(1 + MINIMUM_MEMORY_LIMIT, 1 + MINIMUM_MEMORY_LIMIT) .is_err()); assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); - assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 0); assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); - assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); // child failure let mem_stat = MemStat::create("TEST".to_string()); @@ -419,7 +344,6 @@ mod tests { assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); // assert_eq!(mem_stat.peak_used.load(Ordering::Relaxed), 0); assert_eq!(child_mem_stat.used.load(Ordering::Relaxed), 0); - assert_eq!(child_mem_stat.peak_used.load(Ordering::Relaxed), 0); Ok(()) } diff --git a/src/common/base/src/runtime/memory/mod.rs b/src/common/base/src/runtime/memory/mod.rs index 80be2e68348e3..2c895f8ed807c 100644 --- a/src/common/base/src/runtime/memory/mod.rs +++ b/src/common/base/src/runtime/memory/mod.rs @@ -14,10 +14,12 @@ mod alloc_error_hook; mod mem_stat; -mod stat_buffer; +mod stat_buffer_global; +mod stat_buffer_mem_stat; pub use alloc_error_hook::set_alloc_error_hook; pub use mem_stat::MemStat; pub use mem_stat::OutOfLimit; pub use mem_stat::GLOBAL_MEM_STAT; -pub use stat_buffer::StatBuffer; +pub use stat_buffer_global::GlobalStatBuffer; +pub use stat_buffer_mem_stat::MemStatBuffer; diff --git a/src/common/base/src/runtime/memory/stat_buffer.rs b/src/common/base/src/runtime/memory/stat_buffer_global.rs similarity index 65% rename from src/common/base/src/runtime/memory/stat_buffer.rs rename to src/common/base/src/runtime/memory/stat_buffer_global.rs index 45b1336a58cf4..4eb20f411f296 100644 --- a/src/common/base/src/runtime/memory/stat_buffer.rs +++ b/src/common/base/src/runtime/memory/stat_buffer_global.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::alloc::AllocError; use std::ptr::addr_of_mut; use std::sync::atomic::Ordering; +#[cfg(test)] +use std::sync::Arc; use crate::runtime::memory::mem_stat::OutOfLimit; use crate::runtime::memory::MemStat; @@ -22,23 +25,23 @@ use crate::runtime::ThreadTracker; use crate::runtime::GLOBAL_MEM_STAT; #[thread_local] -static mut STAT_BUFFER: StatBuffer = StatBuffer::empty(&GLOBAL_MEM_STAT); +static mut GLOBAL_STAT_BUFFER: GlobalStatBuffer = GlobalStatBuffer::empty(&GLOBAL_MEM_STAT); -static MEM_STAT_BUFFER_SIZE: i64 = 4 * 1024 * 1024; +pub static MEM_STAT_BUFFER_SIZE: i64 = 4 * 1024 * 1024; /// Buffering memory allocation stats. /// /// A StatBuffer buffers stats changes in local variables, and periodically flush them to other storage such as an `Arc` shared by several threads. #[derive(Clone)] -pub struct StatBuffer { - memory_usage: i64, +pub struct GlobalStatBuffer { + pub(crate) memory_usage: i64, // Whether to allow unlimited memory. Alloc memory will not panic if it is true. 
unlimited_flag: bool, - global_mem_stat: &'static MemStat, + pub(crate) global_mem_stat: &'static MemStat, destroyed_thread_local_macro: bool, } -impl StatBuffer { +impl GlobalStatBuffer { pub const fn empty(global_mem_stat: &'static MemStat) -> Self { Self { memory_usage: 0, @@ -48,8 +51,8 @@ impl StatBuffer { } } - pub fn current() -> &'static mut StatBuffer { - unsafe { &mut *addr_of_mut!(STAT_BUFFER) } + pub fn current() -> &'static mut GlobalStatBuffer { + unsafe { &mut *addr_of_mut!(GLOBAL_STAT_BUFFER) } } pub fn is_unlimited(&self) -> bool { @@ -74,44 +77,62 @@ impl StatBuffer { ) -> std::result::Result<(), OutOfLimit> { match std::mem::take(&mut self.memory_usage) { 0 => Ok(()), - usage => { - if let Err(e) = self.global_mem_stat.record_memory::(usage, alloc) { - if !ROLLBACK { - let _ = ThreadTracker::record_memory::(usage, alloc); - } + usage => self.global_mem_stat.record_memory::(usage, alloc), + } + } - return Err(e); - } + pub fn alloc(&mut self, memory_usage: i64) -> std::result::Result<(), AllocError> { + // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. + // This is the boundary of thread exit. It may be dangerous to throw mistakes here. + if self.destroyed_thread_local_macro { + let used = self + .global_mem_stat + .used + .fetch_add(memory_usage, Ordering::Relaxed); + self.global_mem_stat + .peek_used + .fetch_max(used + memory_usage, Ordering::Relaxed); + return Ok(()); + } - if let Err(e) = ThreadTracker::record_memory::(usage, alloc) { - if ROLLBACK { - self.global_mem_stat.rollback(alloc); - return Err(e); + match self.incr(memory_usage) <= MEM_STAT_BUFFER_SIZE { + true => Ok(()), + false => { + match !std::thread::panicking() && !self.unlimited_flag { + true => { + if let Err(out_of_limit) = self.flush::(memory_usage) { + let _guard = LimitMemGuard::enter_unlimited(); + ThreadTracker::replace_error_message(Some(format!( + "{:?}", + out_of_limit + ))); + return Err(AllocError); + } + } + false => { + let _ = self.flush::(0); } - } + }; Ok(()) } } } - pub fn alloc(&mut self, memory_usage: i64) -> std::result::Result<(), OutOfLimit> { - // Rust will alloc or dealloc memory after the thread local is destroyed when we using thread_local macro. - // This is the boundary of thread exit. It may be dangerous to throw mistakes here. 
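Two ideas meet in this file: per-thread batching (deltas accumulate in the `#[thread_local]` buffer and only hit the shared `MemStat` once they cross `MEM_STAT_BUFFER_SIZE`, 4 MiB) and a new error contract, where the fallible `alloc` maps an over-limit flush to `AllocError` after stashing the readable cause via `ThreadTracker::replace_error_message`, while `force_alloc`, which follows, buffers identically but swallows flush errors. A condensed, self-contained sketch of both paths, with hypothetical names and a stubbed flush (`AllocError` is the nightly allocator-api type this crate already enables):

```rust
use std::alloc::AllocError; // nightly `allocator_api`

const THRESHOLD: i64 = 4 * 1024 * 1024;

#[derive(Debug)]
struct OutOfLimit;

struct Buffer {
    pending: i64,
    unlimited: bool,
}

impl Buffer {
    fn flush(&mut self) -> Result<(), OutOfLimit> {
        // Push `pending` into the shared MemStat; may report an over-limit.
        self.pending = 0;
        Ok(())
    }

    // Fallible path: the allocator surfaces an over-limit flush as AllocError.
    fn alloc(&mut self, size: i64) -> Result<(), AllocError> {
        self.pending += size;
        if self.pending > THRESHOLD {
            if !self.unlimited && !std::thread::panicking() {
                self.flush().map_err(|_| AllocError)?;
            } else {
                let _ = self.flush(); // best effort: never fail while unwinding
            }
        }
        Ok(())
    }

    // Infallible path, for bookkeeping that must not abort the operation.
    fn force_alloc(&mut self, size: i64) {
        self.pending += size;
        if self.pending > THRESHOLD {
            let _ = self.flush();
        }
    }
}
```

The batching is what keeps tracking cheap: one shared-atomic update per ~4 MiB of allocation traffic instead of one per call.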
+ pub fn force_alloc(&mut self, memory_usage: i64) { if self.destroyed_thread_local_macro { let used = self .global_mem_stat .used .fetch_add(memory_usage, Ordering::Relaxed); self.global_mem_stat - .peak_used + .peek_used .fetch_max(used + memory_usage, Ordering::Relaxed); - return Ok(()); + return; } - match self.incr(memory_usage) <= MEM_STAT_BUFFER_SIZE { - true => Ok(()), - false => self.flush::(memory_usage), + if self.incr(memory_usage) > MEM_STAT_BUFFER_SIZE { + let _ = self.flush::(memory_usage); } } @@ -143,20 +164,62 @@ impl StatBuffer { } } +#[cfg(test)] +pub struct MockGuard { + _mem_stat: Arc, + old_global_stat_buffer: GlobalStatBuffer, +} + +#[cfg(test)] +impl MockGuard { + pub fn flush(&mut self) -> Result<(), OutOfLimit> { + GlobalStatBuffer::current().flush::(0) + } +} + +#[cfg(test)] +impl Drop for MockGuard { + fn drop(&mut self) { + let _ = self.flush(); + std::mem::swap( + GlobalStatBuffer::current(), + &mut self.old_global_stat_buffer, + ); + } +} + +#[cfg(test)] +impl GlobalStatBuffer { + pub fn mock(mem_stat: Arc) -> MockGuard { + let mut mock_global_stat_buffer = Self { + memory_usage: 0, + global_mem_stat: unsafe { std::mem::transmute::<&_, &'static _>(mem_stat.as_ref()) }, + unlimited_flag: false, + destroyed_thread_local_macro: false, + }; + + std::mem::swap(GlobalStatBuffer::current(), &mut mock_global_stat_buffer); + MockGuard { + _mem_stat: mem_stat, + old_global_stat_buffer: mock_global_stat_buffer, + } + } +} + #[cfg(test)] mod tests { use std::sync::atomic::Ordering; use databend_common_exception::Result; - use crate::runtime::memory::stat_buffer::MEM_STAT_BUFFER_SIZE; + use crate::runtime::memory::stat_buffer_global::MEM_STAT_BUFFER_SIZE; + use crate::runtime::memory::GlobalStatBuffer; use crate::runtime::memory::MemStat; - use crate::runtime::memory::StatBuffer; #[test] fn test_alloc() -> Result<()> { static TEST_MEM_STATE: MemStat = MemStat::global(); - let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + let mut buffer = GlobalStatBuffer::empty(&TEST_MEM_STATE); buffer.alloc(1).unwrap(); assert_eq!(buffer.memory_usage, 1); @@ -181,7 +244,7 @@ mod tests { #[test] fn test_dealloc() -> Result<()> { static TEST_MEM_STATE: MemStat = MemStat::global(); - let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + let mut buffer = GlobalStatBuffer::empty(&TEST_MEM_STATE); buffer.dealloc(1); assert_eq!(buffer.memory_usage, -1); @@ -207,7 +270,7 @@ mod tests { fn test_mark_destroyed() -> Result<()> { static TEST_MEM_STATE: MemStat = MemStat::global(); - let mut buffer = StatBuffer::empty(&TEST_MEM_STATE); + let mut buffer = GlobalStatBuffer::empty(&TEST_MEM_STATE); assert!(!buffer.destroyed_thread_local_macro); buffer.alloc(1).unwrap(); diff --git a/src/common/base/src/runtime/memory/stat_buffer_mem_stat.rs b/src/common/base/src/runtime/memory/stat_buffer_mem_stat.rs new file mode 100644 index 0000000000000..8c890598dd61b --- /dev/null +++ b/src/common/base/src/runtime/memory/stat_buffer_mem_stat.rs @@ -0,0 +1,384 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::alloc::AllocError; +use std::ptr::addr_of_mut; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use crate::runtime::memory::stat_buffer_global::MEM_STAT_BUFFER_SIZE; +use crate::runtime::memory::OutOfLimit; +use crate::runtime::LimitMemGuard; +use crate::runtime::MemStat; +use crate::runtime::ThreadTracker; +use crate::runtime::GLOBAL_MEM_STAT; + +#[thread_local] +static mut MEM_STAT_BUFFER: MemStatBuffer = MemStatBuffer::empty(&GLOBAL_MEM_STAT); + +pub struct MemStatBuffer { + pub(crate) cur_mem_stat_id: usize, + pub(crate) cur_mem_stat: Option>, + pub(crate) memory_usage: i64, + // Whether to allow unlimited memory. Alloc memory will not panic if it is true. + unlimited_flag: bool, + pub(crate) global_mem_stat: &'static MemStat, + destroyed_thread_local_macro: bool, +} + +impl MemStatBuffer { + pub const fn empty(global_mem_stat: &'static MemStat) -> MemStatBuffer { + MemStatBuffer { + global_mem_stat, + cur_mem_stat_id: 0, + cur_mem_stat: None, + memory_usage: 0, + unlimited_flag: false, + destroyed_thread_local_macro: false, + } + } + + pub fn current() -> &'static mut MemStatBuffer { + unsafe { &mut *addr_of_mut!(MEM_STAT_BUFFER) } + } + + pub fn set_unlimited_flag(&mut self, flag: bool) -> bool { + let old = self.unlimited_flag; + self.unlimited_flag = flag; + old + } + + pub fn incr(&mut self, bs: i64) -> i64 { + self.memory_usage += bs; + self.memory_usage + } + + pub fn flush(&mut self, alloc: i64) -> Result<(), OutOfLimit> { + let memory_usage = std::mem::take(&mut self.memory_usage); + + if memory_usage == 0 { + return Ok(()); + } + + self.cur_mem_stat_id = 0; + if let Some(mem_stat) = self.cur_mem_stat.take() { + if let Err(cause) = mem_stat.record_memory::(memory_usage, alloc) { + let memory_usage = match FALLBACK { + true => memory_usage - alloc, + false => memory_usage, + }; + + self.global_mem_stat + .record_memory::(memory_usage, 0)?; + return Err(cause); + } + } + + self.global_mem_stat + .record_memory::(memory_usage, alloc) + } + + pub fn alloc(&mut self, mem_stat: &Arc, usage: i64) -> Result<(), AllocError> { + if self.destroyed_thread_local_macro { + let used = mem_stat.used.fetch_add(usage, Ordering::Relaxed); + mem_stat + .peek_used + .fetch_max(used + usage, Ordering::Relaxed); + return Ok(()); + } + + if mem_stat.id != self.cur_mem_stat_id { + if let Err(out_of_limit) = self.flush::(0) { + if !std::thread::panicking() && !self.unlimited_flag { + let _guard = LimitMemGuard::enter_unlimited(); + ThreadTracker::replace_error_message(Some(format!("{:?}", out_of_limit))); + return Err(AllocError); + } + } + + self.cur_mem_stat = Some(mem_stat.clone()); + self.cur_mem_stat_id = mem_stat.id; + } + + if self.incr(usage) >= MEM_STAT_BUFFER_SIZE { + let alloc = usage; + match !std::thread::panicking() && !self.unlimited_flag { + true => { + if let Err(out_of_limit) = self.flush::(alloc) { + let _guard = LimitMemGuard::enter_unlimited(); + ThreadTracker::replace_error_message(Some(format!("{:?}", out_of_limit))); + return Err(AllocError); + } + } + false => { + let _ = self.flush::(0); + } + }; + } + + Ok(()) + } + + pub fn force_alloc(&mut self, mem_stat: &Arc, memory_usage: i64) { + if self.destroyed_thread_local_macro { + let used = mem_stat.used.fetch_add(memory_usage, Ordering::Relaxed); + mem_stat + .peek_used + .fetch_max(used + memory_usage, Ordering::Relaxed); + return; + } + + if mem_stat.id != self.cur_mem_stat_id { + let _ = 
self.flush::(0); + + self.cur_mem_stat = Some(mem_stat.clone()); + self.cur_mem_stat_id = mem_stat.id; + } + + if self.incr(memory_usage) >= MEM_STAT_BUFFER_SIZE { + let alloc = memory_usage; + let _ = self.flush::(alloc); + } + } + + pub fn dealloc(&mut self, mem_stat: &Arc, memory_usage: i64) { + let memory_usage = -memory_usage; + + if self.destroyed_thread_local_macro { + mem_stat.used.fetch_add(memory_usage, Ordering::Relaxed); + return; + } + + debug_assert_eq!( + Arc::weak_count(mem_stat), + 0, + "mem stat address {}", + Arc::as_ptr(mem_stat) as usize + ); + + if mem_stat.id != self.cur_mem_stat_id { + if Arc::strong_count(mem_stat) == 1 { + mem_stat.used.fetch_add(memory_usage, Ordering::Relaxed); + self.global_mem_stat + .used + .fetch_add(memory_usage, Ordering::Relaxed); + return; + } + + let _ = self.flush::(0); + + self.cur_mem_stat = Some(mem_stat.clone()); + self.cur_mem_stat_id = mem_stat.id; + } + + if self.incr(memory_usage) <= -MEM_STAT_BUFFER_SIZE || Arc::strong_count(mem_stat) == 1 { + let alloc = memory_usage; + let _ = self.flush::(alloc); + } + + // NOTE: De-allocation does not panic + // even when it's possible exceeding the limit + // due to other threads sharing the same MemStat may have allocated a lot of memory. + } + + pub fn mark_destroyed(&mut self) { + let _guard = LimitMemGuard::enter_unlimited(); + + self.destroyed_thread_local_macro = true; + let _ = self.flush::(0); + } +} + +#[cfg(test)] +pub struct MockGuard { + _mem_stat: Arc, + old_mem_stat_buffer: MemStatBuffer, +} + +#[cfg(test)] +impl MockGuard { + pub fn flush(&mut self) -> Result<(), OutOfLimit> { + MemStatBuffer::current().flush::(0) + } +} + +#[cfg(test)] +impl Drop for MockGuard { + fn drop(&mut self) { + let _ = self.flush(); + std::mem::swap(MemStatBuffer::current(), &mut self.old_mem_stat_buffer); + } +} + +#[cfg(test)] +impl MemStatBuffer { + pub fn mock(mem_stat: Arc) -> MockGuard { + let mut mem_stat_buffer = + Self::empty(unsafe { std::mem::transmute::<&_, &'static _>(mem_stat.as_ref()) }); + std::mem::swap(MemStatBuffer::current(), &mut mem_stat_buffer); + + MockGuard { + _mem_stat: mem_stat, + old_mem_stat_buffer: mem_stat_buffer, + } + } +} + +#[cfg(test)] +mod tests { + use std::alloc::AllocError; + use std::sync::atomic::Ordering; + + use crate::runtime::memory::stat_buffer_global::MEM_STAT_BUFFER_SIZE; + use crate::runtime::memory::stat_buffer_mem_stat::MemStatBuffer; + use crate::runtime::GlobalStatBuffer; + use crate::runtime::MemStat; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_alloc_with_same_allocator() -> Result<(), AllocError> { + static TEST_GLOBAL: MemStat = MemStat::global(); + + let mut buffer = MemStatBuffer::empty(&TEST_GLOBAL); + + let mem_stat = MemStat::create(String::from("test")); + buffer.alloc(&mem_stat, 1)?; + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + + buffer.alloc(&mem_stat, MEM_STAT_BUFFER_SIZE - 2)?; + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + + buffer.alloc(&mem_stat, 1)?; + assert_eq!(mem_stat.used.load(Ordering::Relaxed), MEM_STAT_BUFFER_SIZE); + assert_eq!( + TEST_GLOBAL.used.load(Ordering::Relaxed), + MEM_STAT_BUFFER_SIZE + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_alloc_with_diff_allocator() -> Result<(), AllocError> { + static TEST_GLOBAL: MemStat = MemStat::global(); + + let mut buffer = 
MemStatBuffer::empty(&TEST_GLOBAL); + + let mem_stat_1 = MemStat::create(String::from("test")); + let mem_stat_2 = MemStat::create(String::from("test")); + buffer.alloc(&mem_stat_1, 1)?; + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), 0); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + buffer.alloc(&mem_stat_2, 1)?; + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), 1); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 1); + + buffer.alloc(&mem_stat_1, 1)?; + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), 1); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), 1); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 2); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_dealloc_with_same_allocator() -> Result<(), AllocError> { + static TEST_GLOBAL: MemStat = MemStat::global(); + + let mut buffer = MemStatBuffer::empty(&TEST_GLOBAL); + + let mem_stat = MemStat::create(String::from("test")); + let _shared = mem_stat.clone(); + + buffer.dealloc(&mem_stat, 1); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + + buffer.dealloc(&mem_stat, MEM_STAT_BUFFER_SIZE - 2); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + + buffer.dealloc(&mem_stat, 1); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), -MEM_STAT_BUFFER_SIZE); + assert_eq!( + TEST_GLOBAL.used.load(Ordering::Relaxed), + -MEM_STAT_BUFFER_SIZE + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_dealloc_with_diff_allocator() -> Result<(), AllocError> { + static TEST_GLOBAL: MemStat = MemStat::global(); + + let mut buffer = MemStatBuffer::empty(&TEST_GLOBAL); + + let mem_stat_1 = MemStat::create(String::from("test")); + let mem_stat_2 = MemStat::create(String::from("test")); + let _shared = (mem_stat_1.clone(), mem_stat_2.clone()); + + buffer.dealloc(&mem_stat_1, 1); + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), 0); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), 0); + buffer.dealloc(&mem_stat_2, 1); + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), -1); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), 0); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), -1); + + buffer.dealloc(&mem_stat_1, 1); + assert_eq!(mem_stat_1.used.load(Ordering::Relaxed), -1); + assert_eq!(mem_stat_2.used.load(Ordering::Relaxed), -1); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), -2); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_dealloc_with_unique_allocator() -> Result<(), AllocError> { + static TEST_GLOBAL: MemStat = MemStat::global(); + + let mut buffer = MemStatBuffer::empty(&TEST_GLOBAL); + + let mem_stat = MemStat::create(String::from("test")); + + buffer.dealloc(&mem_stat, 1); + let _ = GlobalStatBuffer::current().flush::(0); + assert_eq!(mem_stat.used.load(Ordering::Relaxed), -1); + assert_eq!(TEST_GLOBAL.used.load(Ordering::Relaxed), -1); + + buffer.dealloc(&mem_stat, MEM_STAT_BUFFER_SIZE - 2); + assert_eq!( + mem_stat.used.load(Ordering::Relaxed), + -(MEM_STAT_BUFFER_SIZE - 1) + ); + assert_eq!( + TEST_GLOBAL.used.load(Ordering::Relaxed), + -(MEM_STAT_BUFFER_SIZE - 1) + ); + + buffer.dealloc(&mem_stat, 1); + 
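The unique-allocator case above (note there is no `_shared` clone, so `Arc::strong_count == 1`) drives the interesting branch of `MemStatBuffer::dealloc`: when the buffer sees a `MemStat` other than the one it currently tracks, it normally flushes and re-targets, but if it would be buffering against a stat that is about to die, it applies the delta directly. A minimal sketch of that decision with simplified names; the assertions that pin down this behavior follow below.

```rust
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;

struct MemStat { id: usize, used: AtomicI64 }

struct Buffer { cur_id: usize, cur: Option<Arc<MemStat>>, pending: i64 }

impl Buffer {
    fn flush(&mut self) {
        if let Some(stat) = self.cur.take() {
            stat.used.fetch_add(std::mem::take(&mut self.pending), Ordering::Relaxed);
        }
        self.cur_id = 0;
    }

    fn dealloc(&mut self, stat: &Arc<MemStat>, size: i64) {
        if stat.id != self.cur_id {
            if Arc::strong_count(stat) == 1 {
                // Last reference: the stat is about to be dropped, so credit
                // it immediately instead of buffering against a dead id.
                stat.used.fetch_sub(size, Ordering::Relaxed);
                return;
            }
            self.flush(); // switching stats: settle what we owe the old one
            self.cur = Some(stat.clone());
            self.cur_id = stat.id;
        }
        self.pending -= size;
    }
}
```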
assert_eq!(mem_stat.used.load(Ordering::Relaxed), -MEM_STAT_BUFFER_SIZE); + assert_eq!( + TEST_GLOBAL.used.load(Ordering::Relaxed), + -MEM_STAT_BUFFER_SIZE + ); + + Ok(()) + } +} diff --git a/src/common/base/src/runtime/mod.rs b/src/common/base/src/runtime/mod.rs index bd34e051ef9a9..d809ac2792525 100644 --- a/src/common/base/src/runtime/mod.rs +++ b/src/common/base/src/runtime/mod.rs @@ -35,7 +35,10 @@ pub use defer::defer; pub use global_runtime::GlobalIORuntime; pub use global_runtime::GlobalQueryRuntime; pub use memory::set_alloc_error_hook; +pub use memory::GlobalStatBuffer; pub use memory::MemStat; +pub use memory::MemStatBuffer; +pub use memory::OutOfLimit; pub use memory::GLOBAL_MEM_STAT; pub use runtime::block_on; pub use runtime::execute_futures_in_parallel; diff --git a/src/common/base/src/runtime/runtime.rs b/src/common/base/src/runtime/runtime.rs index 8b13af7673c34..65634dc7c2a1f 100644 --- a/src/common/base/src/runtime/runtime.rs +++ b/src/common/base/src/runtime/runtime.rs @@ -37,7 +37,6 @@ use tokio::sync::Semaphore; use crate::runtime::catch_unwind::CatchUnwindFuture; use crate::runtime::drop_guard; -use crate::runtime::memory::MemStat; use crate::runtime::Thread; use crate::runtime::ThreadJoinHandle; use crate::runtime::ThreadTracker; @@ -132,15 +131,12 @@ pub struct Runtime { /// Runtime handle. handle: Handle, - /// Memory tracker for this runtime - tracker: Arc, - /// Use to receive a drop signal when dropper is dropped. _dropper: Dropper, } impl Runtime { - fn create(name: Option, tracker: Arc, builder: &mut Builder) -> Result { + fn create(name: Option, builder: &mut Builder) -> Result { let runtime = builder .build() .map_err(|tokio_error| ErrorCode::TokioError(tokio_error.to_string()))?; @@ -173,7 +169,6 @@ impl Runtime { Ok(Runtime { handle, - tracker, _dropper: Dropper { name, close: Some(send_stop), @@ -182,15 +177,10 @@ impl Runtime { }) } - pub fn get_tracker(&self) -> Arc { - self.tracker.clone() - } - /// Spawns a new tokio runtime with a default thread count on a background /// thread and returns a `Handle` which can be used to spawn tasks via /// its executor. pub fn with_default_worker_threads() -> Result { - let mem_stat = MemStat::create(String::from("UnnamedRuntime")); let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); #[cfg(debug_assertions)] @@ -207,7 +197,6 @@ impl Runtime { Self::create( None, - mem_stat, runtime_builder .enable_all() .on_thread_start(ThreadTracker::init), @@ -216,13 +205,6 @@ impl Runtime { #[allow(unused_mut)] pub fn with_worker_threads(workers: usize, mut thread_name: Option) -> Result { - let mut mem_stat_name = String::from("UnnamedRuntime"); - - if let Some(thread_name) = thread_name.as_ref() { - mem_stat_name = format!("{}Runtime", thread_name); - } - - let mem_stat = MemStat::create(mem_stat_name); let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); #[cfg(debug_assertions)] @@ -243,7 +225,6 @@ impl Runtime { Self::create( thread_name, - mem_stat, runtime_builder .enable_all() .on_thread_start(ThreadTracker::init) diff --git a/src/common/base/src/runtime/runtime_tracker.rs b/src/common/base/src/runtime/runtime_tracker.rs index c9b4acba445eb..40ba0227edb53 100644 --- a/src/common/base/src/runtime/runtime_tracker.rs +++ b/src/common/base/src/runtime/runtime_tracker.rs @@ -42,7 +42,6 @@ //! When `TrackedFuture` is `poll()`ed, its `ThreadTracker` is installed to the running thread //! and will be restored when `poll()` returns. 
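The module comment above is the contract that makes per-query accounting work on a shared tokio pool: the tracker is installed for exactly the duration of each `poll` and restored on the way out. A stripped-down illustration of the pattern follows; the real `TrackedFuture` uses `pin_project` and a panic-safe guard rather than this manual save/restore, and the `usize` payload stands in for the full `TrackingPayload`.

```rust
use std::cell::Cell;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

thread_local! {
    static CURRENT_QUERY: Cell<usize> = const { Cell::new(0) };
}

struct Tracked<F> {
    inner: Pin<Box<F>>,
    query: usize,
}

impl<F: Future> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.get_mut(); // fine: Pin<Box<F>> makes Tracked Unpin
        // Install the payload for the duration of this poll only.
        let saved = CURRENT_QUERY.with(|c| c.replace(this.query));
        let out = this.inner.as_mut().poll(cx);
        // Restore, so the worker thread is clean for the next task it polls.
        CURRENT_QUERY.with(|c| c.set(saved));
        out
    }
}
```

Usage would be `Tracked { inner: Box::pin(fut), query: 42 }` handed to any executor; every allocation made while the future runs is attributed to query 42.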
-use std::alloc::AllocError; use std::cell::RefCell; use std::future::Future; use std::pin::Pin; @@ -52,11 +51,11 @@ use std::task::Poll; use pin_project_lite::pin_project; +use crate::runtime::memory::GlobalStatBuffer; use crate::runtime::memory::MemStat; -use crate::runtime::memory::OutOfLimit; -use crate::runtime::memory::StatBuffer; use crate::runtime::metrics::ScopedRegistry; use crate::runtime::profile::Profile; +use crate::runtime::MemStatBuffer; // For implemented and needs to call drop, we cannot use the attribute tag thread local. // https://play.rust-lang.org/?version=nightly&mode=debug&edition=2021&gist=ea33533387d401e86423df1a764b5609 @@ -65,30 +64,30 @@ thread_local! { } pub struct LimitMemGuard { - saved: bool, + global_saved: bool, + mem_stat_saved: bool, } impl LimitMemGuard { pub fn enter_unlimited() -> Self { Self { - saved: StatBuffer::current().set_unlimited_flag(true), + global_saved: GlobalStatBuffer::current().set_unlimited_flag(true), + mem_stat_saved: MemStatBuffer::current().set_unlimited_flag(true), } } pub fn enter_limited() -> Self { Self { - saved: StatBuffer::current().set_unlimited_flag(false), + global_saved: GlobalStatBuffer::current().set_unlimited_flag(false), + mem_stat_saved: MemStatBuffer::current().set_unlimited_flag(false), } } - - pub(crate) fn is_unlimited() -> bool { - StatBuffer::current().is_unlimited() - } } impl Drop for LimitMemGuard { fn drop(&mut self) { - StatBuffer::current().set_unlimited_flag(self.saved); + MemStatBuffer::current().set_unlimited_flag(self.mem_stat_saved); + GlobalStatBuffer::current().set_unlimited_flag(self.global_saved); } } @@ -112,7 +111,7 @@ pub struct TrackingGuard { impl Drop for TrackingGuard { fn drop(&mut self) { - let _ = StatBuffer::current().flush::(0); + let _ = GlobalStatBuffer::current().flush::(0); TRACKER.with(|x| { let mut thread_tracker = x.borrow_mut(); @@ -146,7 +145,8 @@ impl Future for TrackingFuture { impl Drop for ThreadTracker { fn drop(&mut self) { - StatBuffer::current().mark_destroyed(); + MemStatBuffer::current().mark_destroyed(); + GlobalStatBuffer::current().mark_destroyed(); } } @@ -157,7 +157,6 @@ impl Drop for ThreadTracker { impl ThreadTracker { pub(crate) const fn empty() -> Self { Self { - // mem_stat: None, out_of_limit_desc: None, payload: TrackingPayload { profile: None, @@ -185,7 +184,8 @@ impl ThreadTracker { let mut guard = TrackingGuard { saved: tracking_payload, }; - let _ = StatBuffer::current().flush::(0); + let _ = MemStatBuffer::current().flush::(0); + let _ = GlobalStatBuffer::current().flush::(0); TRACKER.with(move |x| { let mut thread_tracker = x.borrow_mut(); @@ -224,63 +224,13 @@ impl ThreadTracker { }) } - /// Accumulate stat about allocated memory. - /// - /// `size` is the positive number of allocated bytes. - #[inline] - pub fn alloc(size: i64) -> Result<(), AllocError> { - if let Err(out_of_limit) = StatBuffer::current().alloc(size) { - // https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=03d21a15e52c7c0356fca04ece283cf9 - if !std::thread::panicking() && !LimitMemGuard::is_unlimited() { - let _guard = LimitMemGuard::enter_unlimited(); - ThreadTracker::replace_error_message(Some(format!("{:?}", out_of_limit))); - return Err(AllocError); - } - } - - Ok(()) - } - - /// Accumulate deallocated memory. - /// - /// `size` is positive number of bytes of the memory to deallocate. 
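`LimitMemGuard` now has to save and restore two flags, one per buffer, since the query-scoped and global buffers can be toggled independently. The RAII shape, sketched here with plain `thread_local!` cells instead of the crate's buffer methods:

```rust
use std::cell::Cell;

thread_local! {
    static GLOBAL_UNLIMITED: Cell<bool> = const { Cell::new(false) };
    static MEM_STAT_UNLIMITED: Cell<bool> = const { Cell::new(false) };
}

struct LimitMemGuard {
    global_saved: bool,
    mem_stat_saved: bool,
}

impl LimitMemGuard {
    fn enter_unlimited() -> Self {
        Self {
            global_saved: GLOBAL_UNLIMITED.with(|f| f.replace(true)),
            mem_stat_saved: MEM_STAT_UNLIMITED.with(|f| f.replace(true)),
        }
    }
}

impl Drop for LimitMemGuard {
    fn drop(&mut self) {
        // Restore both flags even during unwinding; guards nest correctly
        // because each one re-installs exactly the value it displaced.
        MEM_STAT_UNLIMITED.with(|f| f.set(self.mem_stat_saved));
        GLOBAL_UNLIMITED.with(|f| f.set(self.global_saved));
    }
}
```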
- #[inline] - pub fn dealloc(size: i64) { - StatBuffer::current().dealloc(size) - } - - pub fn movein_memory(size: i64) { - TRACKER.with(|tracker| { - let thread_tracker = tracker.borrow(); - if let Some(mem_stat) = &thread_tracker.payload.mem_stat { - mem_stat.movein_memory(size); - } - }) - } - - pub fn moveout_memory(size: i64) { - TRACKER.with(|tracker| { - let thread_tracker = tracker.borrow(); - if let Some(mem_stat) = &thread_tracker.payload.mem_stat { - mem_stat.moveout_memory(size); - } - }) - } - - pub fn record_memory(batch: i64, cur: i64) -> Result<(), OutOfLimit> { - let has_thread_local = TRACKER.try_with(|tracker: &RefCell| { - // We need to ensure no heap memory alloc or dealloc. it will cause panic of borrow recursive call. - let tracker = tracker.borrow(); - match tracker.payload.mem_stat.as_deref() { - None => Ok(()), - Some(mem_stat) => mem_stat.record_memory::(batch, cur), - } - }); - - match has_thread_local { - Ok(Ok(_)) | Err(_) => Ok(()), - Ok(Err(oom)) => Err(oom), - } + pub fn mem_stat() -> Option<&'static Arc> { + TRACKER + .try_with(|tracker| { + let tracker = tracker.borrow(); + unsafe { std::mem::transmute(tracker.payload.mem_stat.as_ref()) } + }) + .unwrap_or(None) } pub fn query_id() -> Option<&'static String> { diff --git a/src/common/base/tests/it/main.rs b/src/common/base/tests/it/main.rs index abc26daac688c..d40345946826f 100644 --- a/src/common/base/tests/it/main.rs +++ b/src/common/base/tests/it/main.rs @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use databend_common_base::mem_allocator::GlobalAllocator; +use databend_common_base::mem_allocator::TrackingGlobalAllocator; mod ext; mod fixed_heap; -mod memory; mod metrics; mod pool; mod pool_retry; @@ -28,4 +27,4 @@ mod string; // runtime tests depends on the memory stat collector. #[global_allocator] -pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator; +pub static GLOBAL_ALLOCATOR: TrackingGlobalAllocator = TrackingGlobalAllocator::create(); diff --git a/src/common/base/tests/it/memory/mem_stat.rs b/src/common/base/tests/it/memory/mem_stat.rs deleted file mode 100644 index ede0f4acf5172..0000000000000 --- a/src/common/base/tests/it/memory/mem_stat.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
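`ThreadTracker::mem_stat()` replaces the removed `alloc`/`dealloc`/`movein`/`moveout`/`record_memory` entry points: the allocator now simply asks for the current payload's `MemStat` and does its own buffering. The accessor must not allocate or panic during TLS teardown, hence `try_with`; the diff additionally extends the borrow to `'static` via `transmute` so no `Arc` clone happens inside the allocator. A simplified (cloning, safe) version of the lookup:

```rust
use std::cell::RefCell;
use std::sync::Arc;

struct MemStat;
struct Payload { mem_stat: Option<Arc<MemStat>> }

thread_local! {
    static TRACKER: RefCell<Payload> = RefCell::new(Payload { mem_stat: None });
}

fn current_mem_stat() -> Option<Arc<MemStat>> {
    // try_with instead of with: during thread teardown the TLS slot may
    // already be destroyed, and allocation paths must never panic there.
    TRACKER
        .try_with(|t| t.borrow().mem_stat.clone())
        .unwrap_or(None)
}
```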
- -use std::time::SystemTime; - -use databend_common_base::runtime::MemStat; -use databend_common_base::runtime::ThreadTracker; - -#[test] -fn test_mem_tracker_with_primitive_type() { - fn test_primitive_type(index: T) { - let mem_stat = MemStat::create("TEST".to_string()); - let mut payload = ThreadTracker::new_tracking_payload(); - payload.mem_stat = Some(mem_stat.clone()); - - let _guard = ThreadTracker::tracking(payload); - let _test = Box::new(index); - - drop(_guard); - drop(_test); - assert_eq!( - mem_stat.get_memory_usage(), - std::mem::size_of_val(&index) as i64 - ); - } - - test_primitive_type(0_i8); - test_primitive_type(0_i16); - test_primitive_type(0_i32); - test_primitive_type(0_i64); - test_primitive_type(0_u8); - test_primitive_type(0_u16); - test_primitive_type(0_u32); - test_primitive_type(0_u64); -} - -#[test] -fn test_mem_tracker_with_string_type() { - for _index in 0..10 { - let length = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_nanos() as usize - % 1000; - - let mem_stat = MemStat::create("TEST".to_string()); - let mut payload = ThreadTracker::new_tracking_payload(); - payload.mem_stat = Some(mem_stat.clone()); - - let _guard = ThreadTracker::tracking(payload); - - let str = "".repeat(length); - drop(_guard); - assert_eq!(mem_stat.get_memory_usage(), str.len() as i64); - } -} - -#[test] -fn test_mem_tracker_with_vec_type() { - for _index in 0..10 { - let length = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_nanos() as usize - % 1000; - - let mem_stat = MemStat::create("TEST".to_string()); - let mut payload = ThreadTracker::new_tracking_payload(); - payload.mem_stat = Some(mem_stat.clone()); - - let _guard = ThreadTracker::tracking(payload); - - let vec = (0..length).collect::>(); - drop(_guard); - assert_eq!( - mem_stat.get_memory_usage(), - (vec.capacity() * std::mem::size_of::()) as i64 - ); - } -} diff --git a/src/common/base/tests/it/memory/mod.rs b/src/common/base/tests/it/memory/mod.rs deleted file mode 100644 index 9db58b6f78402..0000000000000 --- a/src/common/base/tests/it/memory/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2021 Datafuse Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -mod mem_stat; diff --git a/src/meta/binaries/meta/ee_main.rs b/src/meta/binaries/meta/ee_main.rs index 0863fc7557039..873254b72797c 100644 --- a/src/meta/binaries/meta/ee_main.rs +++ b/src/meta/binaries/meta/ee_main.rs @@ -17,11 +17,11 @@ mod entry; mod kvapi; -use databend_common_base::mem_allocator::GlobalAllocator; +use databend_common_base::mem_allocator::DefaultGlobalAllocator; use databend_meta::configs::Config; #[global_allocator] -pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator; +pub static GLOBAL_ALLOCATOR: DefaultGlobalAllocator = DefaultGlobalAllocator::create(); #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { diff --git a/src/meta/binaries/meta/oss_main.rs b/src/meta/binaries/meta/oss_main.rs index 6776304670afb..01f88eb369497 100644 --- a/src/meta/binaries/meta/oss_main.rs +++ b/src/meta/binaries/meta/oss_main.rs @@ -17,11 +17,11 @@ mod entry; mod kvapi; -use databend_common_base::mem_allocator::GlobalAllocator; +use databend_common_base::mem_allocator::DefaultGlobalAllocator; use databend_meta::configs::Config; #[global_allocator] -pub static GLOBAL_ALLOCATOR: GlobalAllocator = GlobalAllocator; +pub static GLOBAL_ALLOCATOR: DefaultGlobalAllocator = DefaultGlobalAllocator::create(); #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { diff --git a/src/query/pipeline/core/src/processors/port.rs b/src/query/pipeline/core/src/processors/port.rs index d01886f0392d3..372180ff52b43 100644 --- a/src/query/pipeline/core/src/processors/port.rs +++ b/src/query/pipeline/core/src/processors/port.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use databend_common_base::runtime::drop_guard; use databend_common_base::runtime::profile::Profile; use databend_common_base::runtime::profile::ProfileStatisticsName; -use databend_common_base::runtime::ThreadTracker; use databend_common_exception::Result; use databend_common_expression::DataBlock; @@ -189,15 +188,7 @@ impl InputPort { let unset_flags = HAS_DATA | NEED_DATA; match self.shared.swap(std::ptr::null_mut(), 0, unset_flags) { address if address.is_null() => None, - address => { - let data_block = (*Box::from_raw(address)).0; - - if let Ok(data_block) = &data_block { - ThreadTracker::movein_memory(data_block.memory_size() as i64); - } - - Some(data_block) - } + address => Some((*Box::from_raw(address)).0), } } } @@ -238,8 +229,6 @@ impl OutputPort { UpdateTrigger::update_output(&self.update_trigger); if let Ok(data_block) = &data { - ThreadTracker::moveout_memory(data_block.memory_size() as i64); - if *self.record_profile { Profile::record_usize_profile( ProfileStatisticsName::OutputRows, diff --git a/src/query/service/src/interpreters/common/query_log.rs b/src/query/service/src/interpreters/common/query_log.rs index 1cae6d42aeaee..2896858421360 100644 --- a/src/query/service/src/interpreters/common/query_log.rs +++ b/src/query/service/src/interpreters/common/query_log.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
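With the allocator split into `DefaultGlobalAllocator` and `TrackingGlobalAllocator`, each binary opts in at its entry point: databend-query pays for tracking, while databend-meta keeps the plain path, as the hunks above show. Purely for illustration, the same choice could be expressed behind a cargo feature (the `memory-tracking` feature name is hypothetical; the actual binaries hard-code their choice):

```rust
use databend_common_base::mem_allocator::DefaultGlobalAllocator;
use databend_common_base::mem_allocator::TrackingGlobalAllocator;

#[cfg(feature = "memory-tracking")]
#[global_allocator]
pub static GLOBAL_ALLOCATOR: TrackingGlobalAllocator = TrackingGlobalAllocator::create();

#[cfg(not(feature = "memory-tracking"))]
#[global_allocator]
pub static GLOBAL_ALLOCATOR: DefaultGlobalAllocator = DefaultGlobalAllocator::create();
```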
+use std::collections::HashMap; use std::fmt::Write; use std::sync::Arc; use std::time::SystemTime; @@ -224,6 +225,7 @@ impl InterpreterQueryLog { has_profiles: false, txn_state, txn_id, + peek_memory_usage: HashMap::new(), }) } @@ -335,6 +337,8 @@ impl InterpreterQueryLog { let txn_id = guard.txn_id().to_string(); drop(guard); + let peek_memory_usage = ctx.get_node_peek_memory_usage(); + Self::write_log(QueryLogElement { log_type, log_type_name, @@ -398,6 +402,7 @@ impl InterpreterQueryLog { has_profiles, txn_state, txn_id, + peek_memory_usage, }) } } diff --git a/src/query/service/src/pipelines/executor/executor_graph.rs b/src/query/service/src/pipelines/executor/executor_graph.rs index 7a56e250ee5a1..c7b2c27272aeb 100644 --- a/src/query/service/src/pipelines/executor/executor_graph.rs +++ b/src/query/service/src/pipelines/executor/executor_graph.rs @@ -28,7 +28,6 @@ use databend_common_base::base::WatchNotify; use databend_common_base::runtime::error_info::NodeErrorType; use databend_common_base::runtime::profile::Profile; use databend_common_base::runtime::profile::ProfileStatisticsName; -use databend_common_base::runtime::MemStat; use databend_common_base::runtime::ThreadTracker; use databend_common_base::runtime::TrackingPayload; use databend_common_base::runtime::TrySpawn; @@ -122,16 +121,6 @@ impl Node { scope.as_ref().map(|x| x.metrics_registry.clone()), ))); - // Node mem stat - tracking_payload.mem_stat = Some(MemStat::create_child( - unsafe { processor.name() }, - tracking_payload - .mem_stat - .as_ref() - .map(|x| vec![x.clone()]) - .unwrap_or_default(), - )); - // Node tracking metrics tracking_payload.metrics = scope.as_ref().map(|x| x.metrics_registry.clone()); @@ -739,18 +728,6 @@ impl RunningGraph { .node_weights() .map(|x| { let new_profile = x.tracking_payload.profile.as_deref().cloned(); - - // inject memory usage - if let Some((profile, mem_stat)) = new_profile - .as_ref() - .zip(x.tracking_payload.mem_stat.as_ref()) - { - profile.statistics[ProfileStatisticsName::MemoryUsage as usize].fetch_add( - std::cmp::max(0, mem_stat.get_memory_usage()) as usize, - Ordering::Relaxed, - ); - } - Arc::new(new_profile.unwrap()) }) .collect::>() diff --git a/src/query/service/src/pipelines/executor/pipeline_complete_executor.rs b/src/query/service/src/pipelines/executor/pipeline_complete_executor.rs index ed689fc856589..1c82b1c4e4682 100644 --- a/src/query/service/src/pipelines/executor/pipeline_complete_executor.rs +++ b/src/query/service/src/pipelines/executor/pipeline_complete_executor.rs @@ -15,7 +15,6 @@ use std::sync::Arc; use databend_common_base::runtime::drop_guard; -use databend_common_base::runtime::MemStat; use databend_common_base::runtime::Thread; use databend_common_base::runtime::ThreadTracker; use databend_common_base::runtime::TrackingPayload; @@ -35,13 +34,8 @@ pub struct PipelineCompleteExecutor { // Use this executor when the pipeline is complete pipeline (has source and sink) impl PipelineCompleteExecutor { - fn execution_tracking_payload(query_id: &str) -> TrackingPayload { - let mut tracking_payload = ThreadTracker::new_tracking_payload(); - tracking_payload.mem_stat = Some(MemStat::create(format!( - "QueryExecutionMemStat-{}", - query_id - ))); - tracking_payload + fn execution_tracking_payload(_query_id: &str) -> TrackingPayload { + ThreadTracker::new_tracking_payload() } pub fn try_create( diff --git a/src/query/service/src/pipelines/executor/pipeline_pulling_executor.rs b/src/query/service/src/pipelines/executor/pipeline_pulling_executor.rs index 
a0d74844e975d..ac27f6183dd15 100644 --- a/src/query/service/src/pipelines/executor/pipeline_pulling_executor.rs +++ b/src/query/service/src/pipelines/executor/pipeline_pulling_executor.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use std::time::Duration; use databend_common_base::runtime::drop_guard; -use databend_common_base::runtime::MemStat; use databend_common_base::runtime::Thread; use databend_common_base::runtime::ThreadTracker; use databend_common_base::runtime::TrackingPayload; @@ -102,33 +101,15 @@ pub struct PipelinePullingExecutor { } impl PipelinePullingExecutor { - fn execution_tracking_payload(query_id: &str) -> TrackingPayload { - let mut tracking_payload = ThreadTracker::new_tracking_payload(); - tracking_payload.mem_stat = Some(MemStat::create(format!( - "QueryExecutionMemStat-{}", - query_id - ))); - tracking_payload - } - - fn wrap_pipeline( - pipeline: &mut Pipeline, - tx: SyncSender, - mem_stat: Arc, - ) -> Result<()> { + fn wrap_pipeline(pipeline: &mut Pipeline, tx: SyncSender) -> Result<()> { if pipeline.is_pushing_pipeline()? || !pipeline.is_pulling_pipeline()? { return Err(ErrorCode::Internal( "Logical error, PipelinePullingExecutor can only work on pulling pipeline.", )); } - pipeline.add_sink(|input| { - Ok(ProcessorPtr::create(PullingSink::create( - tx.clone(), - mem_stat.clone(), - input, - ))) - })?; + pipeline + .add_sink(|input| Ok(ProcessorPtr::create(PullingSink::create(tx.clone(), input))))?; pipeline.set_on_finished(move |_info: &ExecutionInfo| { drop(tx); @@ -142,16 +123,12 @@ impl PipelinePullingExecutor { mut pipeline: Pipeline, settings: ExecutorSettings, ) -> Result { - let tracking_payload = Self::execution_tracking_payload(settings.query_id.as_ref()); + let tracking_payload = ThreadTracker::new_tracking_payload(); let _guard = ThreadTracker::tracking(tracking_payload.clone()); let (sender, receiver) = std::sync::mpsc::sync_channel(pipeline.output_len()); - Self::wrap_pipeline( - &mut pipeline, - sender, - tracking_payload.mem_stat.clone().unwrap(), - )?; + Self::wrap_pipeline(&mut pipeline, sender)?; let executor = PipelineExecutor::create(pipeline, settings)?; Ok(PipelinePullingExecutor { @@ -166,17 +143,13 @@ impl PipelinePullingExecutor { build_res: PipelineBuildResult, settings: ExecutorSettings, ) -> Result { - let tracking_payload = Self::execution_tracking_payload(settings.query_id.as_ref()); + let tracking_payload = ThreadTracker::new_tracking_payload(); let _guard = ThreadTracker::tracking(tracking_payload.clone()); let mut main_pipeline = build_res.main_pipeline; let (sender, receiver) = std::sync::mpsc::sync_channel(main_pipeline.output_len()); - Self::wrap_pipeline( - &mut main_pipeline, - sender, - tracking_payload.mem_stat.clone().unwrap(), - )?; + Self::wrap_pipeline(&mut main_pipeline, sender)?; let mut pipelines = build_res.sources_pipelines; pipelines.push(main_pipeline); @@ -286,19 +259,11 @@ impl Drop for PipelinePullingExecutor { struct PullingSink { sender: Option>, - query_execution_mem_stat: Arc, } impl PullingSink { - pub fn create( - tx: SyncSender, - mem_stat: Arc, - input: Arc, - ) -> Box { - Sinker::create(input, PullingSink { - sender: Some(tx), - query_execution_mem_stat: mem_stat, - }) + pub fn create(tx: SyncSender, input: Arc) -> Box { + Sinker::create(input, PullingSink { sender: Some(tx) }) } } @@ -311,12 +276,6 @@ impl Sink for PullingSink { } fn consume(&mut self, data_block: DataBlock) -> Result<()> { - let memory_size = data_block.memory_size() as i64; - // TODO: need moveout memory for plan tracker - 
ThreadTracker::moveout_memory(memory_size); - - self.query_execution_mem_stat.moveout_memory(memory_size); - if let Some(sender) = &self.sender { if let Err(cause) = sender.send(data_block) { return Err(ErrorCode::Internal(format!( diff --git a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs index 939dabfceda92..844a1d8316fd7 100644 --- a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs @@ -21,7 +21,6 @@ use databend_common_base::runtime::catch_unwind; use databend_common_base::runtime::drop_guard; use databend_common_base::runtime::error_info::NodeErrorType; use databend_common_base::runtime::GlobalIORuntime; -use databend_common_base::runtime::MemStat; use databend_common_base::runtime::Runtime; use databend_common_base::runtime::Thread; use databend_common_base::runtime::ThreadJoinHandle; @@ -195,14 +194,7 @@ impl QueryPipelineExecutor { let mut on_finished_chain = self.on_finished_chain.lock(); // untracking for on finished - let mut tracking_payload = ThreadTracker::new_tracking_payload(); - if let Some(mem_stat) = &tracking_payload.mem_stat { - tracking_payload.mem_stat = Some(MemStat::create_child( - String::from("Pipeline-on-finished"), - mem_stat.get_parent_memory_stat(), - )); - } - + let tracking_payload = ThreadTracker::new_tracking_payload(); let _guard = ThreadTracker::tracking(tracking_payload); on_finished_chain.apply(info) } @@ -281,14 +273,7 @@ impl QueryPipelineExecutor { drop(guard); // untracking for on finished - let mut tracking_payload = ThreadTracker::new_tracking_payload(); - if let Some(mem_stat) = &tracking_payload.mem_stat { - tracking_payload.mem_stat = Some(MemStat::create_child( - String::from("Pipeline-on-finished"), - mem_stat.get_parent_memory_stat(), - )); - } - + let tracking_payload = ThreadTracker::new_tracking_payload(); if let Err(cause) = Result::flatten(catch_unwind(move || { let _guard = ThreadTracker::tracking(tracking_payload); @@ -472,14 +457,7 @@ impl Drop for QueryPipelineExecutor { let mut on_finished_chain = self.on_finished_chain.lock(); // untracking for on finished - let mut tracking_payload = ThreadTracker::new_tracking_payload(); - if let Some(mem_stat) = &tracking_payload.mem_stat { - tracking_payload.mem_stat = Some(MemStat::create_child( - String::from("Pipeline-on-finished"), - mem_stat.get_parent_memory_stat(), - )); - } - + let tracking_payload = ThreadTracker::new_tracking_payload(); let _guard = ThreadTracker::tracking(tracking_payload); let profiling = self.fetch_plans_profile(true); let info = ExecutionInfo::create(Err(cause), profiling); diff --git a/src/query/service/src/servers/flight/v1/actions/init_query_env.rs b/src/query/service/src/servers/flight/v1/actions/init_query_env.rs index 236f192af2179..8b530a4933805 100644 --- a/src/query/service/src/servers/flight/v1/actions/init_query_env.rs +++ b/src/query/service/src/servers/flight/v1/actions/init_query_env.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use databend_common_base::runtime::MemStat; use databend_common_base::runtime::ThreadTracker; use databend_common_config::GlobalConfig; use databend_common_exception::Result; @@ -23,15 +24,23 @@ use crate::servers::flight::v1::packets::QueryEnv; pub static INIT_QUERY_ENV: &str = "/actions/init_query_env"; pub async fn init_query_env(env: QueryEnv) -> Result<()> { + let query_mem_stat = MemStat::create(format!("Query-{}", env.query_id)); + let query_max_memory_usage = env.settings.get_max_query_memory_usage()?; + + if query_max_memory_usage != 0 { + query_mem_stat.set_limit(query_max_memory_usage as i64); + } + let mut tracking_payload = ThreadTracker::new_tracking_payload(); tracking_payload.query_id = Some(env.query_id.clone()); + tracking_payload.mem_stat = Some(query_mem_stat.clone()); let _guard = ThreadTracker::tracking(tracking_payload); ThreadTracker::tracking_future(async move { debug!("init query env with {:?}", env); let ctx = match env.request_server_id == GlobalConfig::instance().query.node_id { true => None, - false => Some(env.create_query_ctx().await?), + false => Some(env.create_query_ctx(query_mem_stat).await?), }; if let Err(e) = DataExchangeManager::instance() diff --git a/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs b/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs index a90927496b97f..171f56b5f8deb 100644 --- a/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs +++ b/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs @@ -23,7 +23,10 @@ use crate::servers::flight::v1::packets::QueryFragments; pub static INIT_QUERY_FRAGMENTS: &str = "/actions/init_query_fragments"; pub async fn init_query_fragments(fragments: QueryFragments) -> Result<()> { + let ctx = DataExchangeManager::instance().get_query_ctx(&fragments.query_id)?; + let mut tracking_payload = ThreadTracker::new_tracking_payload(); + tracking_payload.mem_stat = ctx.get_query_memory_tracking(); tracking_payload.query_id = Some(fragments.query_id.clone()); let _guard = ThreadTracker::tracking(tracking_payload); @@ -31,7 +34,6 @@ pub async fn init_query_fragments(fragments: QueryFragments) -> Result<()> { // Avoid blocking runtime. 
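This is the piece that extends query limits across the cluster: every node participating in a query builds a `Query-{id}` `MemStat`, applies `max_query_memory_usage` as its limit, and hangs it on the tracking payload so all fragment work is charged to it. Condensed into one hedged sketch (the helper name is hypothetical; error handling and the query-context wiring from `init_query_env` are omitted):

```rust
use databend_common_base::runtime::MemStat;
use databend_common_base::runtime::ThreadTracker;

fn attach_query_tracking(query_id: &str, max_query_memory_usage: u64) {
    let query_mem_stat = MemStat::create(format!("Query-{}", query_id));

    // A setting of 0 means "no per-query limit on this node".
    if max_query_memory_usage != 0 {
        query_mem_stat.set_limit(max_query_memory_usage as i64);
    }

    let mut payload = ThreadTracker::new_tracking_payload();
    payload.query_id = Some(query_id.to_string());
    payload.mem_stat = Some(query_mem_stat);

    // Everything run (or spawned via `tracking_future`) under this guard is
    // charged to the query's MemStat, on whichever node it executes.
    let _guard = ThreadTracker::tracking(payload);
}
```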
diff --git a/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs b/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs
index a90927496b97f..171f56b5f8deb 100644
--- a/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs
+++ b/src/query/service/src/servers/flight/v1/actions/init_query_fragments.rs
@@ -23,7 +23,10 @@ use crate::servers::flight::v1::packets::QueryFragments;
 pub static INIT_QUERY_FRAGMENTS: &str = "/actions/init_query_fragments";

 pub async fn init_query_fragments(fragments: QueryFragments) -> Result<()> {
+    let ctx = DataExchangeManager::instance().get_query_ctx(&fragments.query_id)?;
+
     let mut tracking_payload = ThreadTracker::new_tracking_payload();
+    tracking_payload.mem_stat = ctx.get_query_memory_tracking();
     tracking_payload.query_id = Some(fragments.query_id.clone());
     let _guard = ThreadTracker::tracking(tracking_payload);

@@ -31,7 +34,6 @@ pub async fn init_query_fragments(fragments: QueryFragments) -> Result<()> {
     // Avoid blocking runtime.
     let query_id = fragments.query_id.clone();
-    let ctx = DataExchangeManager::instance().get_query_ctx(&fragments.query_id)?;
     let join_handler = ctx.spawn(ThreadTracker::tracking_future(async move {
         DataExchangeManager::instance().init_query_fragments_plan(&fragments)
     }));
diff --git a/src/query/service/src/servers/flight/v1/actions/start_prepared_query.rs b/src/query/service/src/servers/flight/v1/actions/start_prepared_query.rs
index 10ae27f56f442..fbb7e8fce55e7 100644
--- a/src/query/service/src/servers/flight/v1/actions/start_prepared_query.rs
+++ b/src/query/service/src/servers/flight/v1/actions/start_prepared_query.rs
@@ -21,8 +21,11 @@ use crate::servers::flight::v1::exchange::DataExchangeManager;
 pub static START_PREPARED_QUERY: &str = "/actions/start_prepared_query";

 pub async fn start_prepared_query(id: String) -> Result<()> {
+    let ctx = DataExchangeManager::instance().get_query_ctx(&id)?;
+
     let mut tracking_payload = ThreadTracker::new_tracking_payload();
     tracking_payload.query_id = Some(id.clone());
+    tracking_payload.mem_stat = ctx.get_query_memory_tracking();
     let _guard = ThreadTracker::tracking(tracking_payload);

     debug!("start prepared query {}", id);
diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs
index cb3107f9e3442..18f501a6a98c8 100644
--- a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs
+++ b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs
@@ -13,6 +13,7 @@
 // limitations under the License.

 use std::collections::HashMap;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;

 use databend_common_base::base::tokio::sync::broadcast::channel;
@@ -26,7 +27,9 @@ use futures_util::future::select;
 use futures_util::future::Either;

 use crate::servers::flight::v1::packets::DataPacket;
+use crate::servers::flight::v1::packets::ProgressInfo;
 use crate::servers::flight::FlightExchange;
+use crate::sessions::MemoryUpdater;
 use crate::sessions::QueryContext;

 pub struct StatisticsReceiver {
@@ -49,6 +52,7 @@ impl StatisticsReceiver {
             exchange_handler.push(runtime.spawn({
                 let ctx = ctx.clone();
                 let shutdown_rx = shutdown_tx.subscribe();
+                let node_memory_updater = ctx.get_node_memory_updater(&source_target);

                 async move {
                     let mut shutdown_rx = shutdown_rx;
@@ -65,6 +69,7 @@ impl StatisticsReceiver {
                                 match StatisticsReceiver::recv_data(
                                     &ctx,
                                     &source_target,
+                                    &node_memory_updater,
                                     recv.await,
                                 ) {
                                     Ok(true) => {
@@ -78,6 +83,7 @@ impl StatisticsReceiver {
                                 match StatisticsReceiver::recv_data(
                                     &ctx,
                                     &source_target,
+                                    &node_memory_updater,
                                     rx.recv().await,
                                 ) {
                                     Ok(true) => {
@@ -94,7 +100,12 @@ impl StatisticsReceiver {
                             }
                         }
                         Either::Right((res, left)) => {
-                            match StatisticsReceiver::recv_data(&ctx, &source_target, res) {
+                            match StatisticsReceiver::recv_data(
+                                &ctx,
+                                &source_target,
+                                &node_memory_updater,
+                                res,
+                            ) {
                                 Ok(true) => {
                                     return Ok(());
                                 }
@@ -124,6 +135,7 @@ impl StatisticsReceiver {
     fn recv_data(
         ctx: &Arc<QueryContext>,
         source_target: &str,
+        node_memory_usage: &Arc<MemoryUpdater>,
         recv_data: Result<Option<DataPacket>>,
     ) -> Result<bool> {
         match recv_data {
@@ -134,6 +146,18 @@ impl StatisticsReceiver {
             Ok(Some(DataPacket::FragmentData(_))) => unreachable!(),
             Ok(Some(DataPacket::SerializeProgress(progress))) => {
                 for progress_info in progress {
+                    if let ProgressInfo::MemoryUsage(memory_usage, peek_memory_usage) =
+                        &progress_info
+                    {
+                        node_memory_usage
+                            .memory_usage
+                            .store(*memory_usage, Ordering::Relaxed);
+                        node_memory_usage
+                            .peek_memory_usage
+                            .store(*peek_memory_usage, Ordering::Relaxed);
+                        continue;
+                    }
+
                     progress_info.inc(source_target, ctx);
                 }
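Aside, not part of the patch: `recv_data` above publishes two independent gauges with `Ordering::Relaxed`, which is sufficient because nothing synchronizes through these values; each store simply replaces the last report from that node. A self-contained sketch of the pattern, reusing the `MemoryUpdater` shape the patch defines:

```rust
// Minimal sketch of the per-node updater used by recv_data above.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

pub struct MemoryUpdater {
    pub memory_usage: AtomicUsize,
    pub peek_memory_usage: AtomicUsize,
}

fn on_memory_usage_packet(updater: &Arc<MemoryUpdater>, current: usize, peek: usize) {
    // Relaxed is enough: these are last-writer-wins gauges, not locks.
    updater.memory_usage.store(current, Ordering::Relaxed);
    updater.peek_memory_usage.store(peek, Ordering::Relaxed);
}

fn main() {
    let updater = Arc::new(MemoryUpdater {
        memory_usage: AtomicUsize::new(0),
        peek_memory_usage: AtomicUsize::new(0),
    });
    on_memory_usage_packet(&updater, 42 << 20, 64 << 20);
    assert_eq!(updater.memory_usage.load(Ordering::Relaxed), 42 << 20);
}
```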
diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs
index 33ea8c1f915b7..08c34c8618388 100644
--- a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs
+++ b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs
@@ -17,6 +17,7 @@ use std::time::Duration;

 use async_channel::Sender;
 use databend_common_base::base::tokio::time::sleep;
+use databend_common_base::runtime::MemStat;
 use databend_common_base::runtime::TrySpawn;
 use databend_common_base::JoinHandle;
 use databend_common_catalog::table_context::TableContext;
@@ -58,6 +59,8 @@ impl StatisticsSender {
         let mut sleep_future = Box::pin(sleep(Duration::from_millis(100)));
         let mut notified = Box::pin(shutdown_flag_receiver.recv());

+        let mem_stat = ctx.get_query_memory_tracking();
+
         loop {
             match futures::future::select(sleep_future, notified).await {
                 Either::Right((Err(_), _)) => {
@@ -81,7 +84,7 @@ impl StatisticsSender {
                     notified = right;
                     sleep_future = Box::pin(sleep(Duration::from_millis(100)));

-                    if let Err(cause) = Self::send_progress(&ctx, &tx).await {
+                    if let Err(cause) = Self::send_progress(&ctx, &mem_stat, &tx).await {
                         ctx.get_exchange_manager()
                             .shutdown_query(&query_id, Some(cause));
                         return;
@@ -112,7 +115,7 @@ impl StatisticsSender {
                 warn!("MutationStatus send has error, cause: {:?}.", error);
             }

-            if let Err(error) = Self::send_progress(&ctx, &tx).await {
+            if let Err(error) = Self::send_progress(&ctx, &mem_stat, &tx).await {
                 warn!("Statistics send has error, cause: {:?}.", error);
             }
         }
@@ -146,8 +149,22 @@ impl StatisticsSender {
     }

     #[async_backtrace::framed]
-    async fn send_progress(ctx: &Arc<QueryContext>, tx: &FlightSender) -> Result<()> {
-        let progress = Self::fetch_progress(ctx);
+    async fn send_progress(
+        ctx: &Arc<QueryContext>,
+        mem_stat: &Option<Arc<MemStat>>,
+        tx: &FlightSender,
+    ) -> Result<()> {
+        let mut progress = Self::fetch_progress(ctx);
+
+        if let Some(mem_stat) = mem_stat {
+            let memory_usage = std::cmp::max(0, mem_stat.get_memory_usage());
+            let peek_memory_usage = std::cmp::max(0, mem_stat.get_peek_memory_usage());
+            progress.push(ProgressInfo::MemoryUsage(
+                memory_usage as usize,
+                peek_memory_usage as usize,
+            ));
+        }
+
         let data_packet = DataPacket::SerializeProgress(progress);
         tx.send(data_packet).await
     }
diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data_progressinfo.rs b/src/query/service/src/servers/flight/v1/packets/packet_data_progressinfo.rs
index 87a64a0fe8a0e..7648d5536c8fa 100644
--- a/src/query/service/src/servers/flight/v1/packets/packet_data_progressinfo.rs
+++ b/src/query/service/src/servers/flight/v1/packets/packet_data_progressinfo.rs
@@ -31,6 +31,7 @@ use crate::sessions::TableContext;
 #[allow(clippy::enum_variant_names)]
 #[derive(Debug)]
 pub enum ProgressInfo {
+    MemoryUsage(usize, usize),
     ScanProgress(ProgressValues),
     WriteProgress(ProgressValues),
     ResultProgress(ProgressValues),
@@ -46,6 +47,7 @@ impl ProgressInfo {
             ProgressInfo::SpillTotalStats(values) => {
                 ctx.set_cluster_spill_progress(source_target, values.clone())
             }
+            ProgressInfo::MemoryUsage(_, _) => unreachable!(),
         };
     }

@@ -60,6 +62,12 @@ impl ProgressInfo {
                 bytes.write_u64::<BigEndian>(values.bytes as u64)?;
                 return Ok(());
             }
+            ProgressInfo::MemoryUsage(memory_usage, peek_memory_usage) => {
+                bytes.write_u8(5)?;
+                bytes.write_u64::<BigEndian>(memory_usage as u64)?;
+                bytes.write_u64::<BigEndian>(peek_memory_usage as u64)?;
+                return Ok(());
+            }
         };

         bytes.write_u8(info_type)?;
@@ -79,6 +87,12 @@ impl ProgressInfo {
             )));
         }

+        if info_type == 5 {
+            let memory_usage = bytes.read_u64::<BigEndian>()? as usize;
+            let peek_memory_usage = bytes.read_u64::<BigEndian>()? as usize;
+            return Ok(ProgressInfo::MemoryUsage(memory_usage, peek_memory_usage));
+        }
+
         let rows = bytes.read_u64::<BigEndian>()? as usize;
         let bytes = bytes.read_u64::<BigEndian>()? as usize;
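Aside, not part of the patch: the hunks above define a simple wire format for `ProgressInfo::MemoryUsage`: a one-byte tag (5) followed by two big-endian u64 values. A standalone encode/decode roundtrip of that format, assuming the `byteorder` crate (which the surrounding file already uses for the other variants) is available as a dependency:

```rust
// Roundtrip sketch of the MemoryUsage wire format: tag 5 + two BE u64s.
use std::io::Cursor;

use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

fn encode(memory_usage: u64, peek_memory_usage: u64) -> std::io::Result<Vec<u8>> {
    let mut bytes = Vec::new();
    bytes.write_u8(5)?; // variant tag, matching the patch
    bytes.write_u64::<BigEndian>(memory_usage)?;
    bytes.write_u64::<BigEndian>(peek_memory_usage)?;
    Ok(bytes)
}

fn decode(buf: &[u8]) -> std::io::Result<(u64, u64)> {
    let mut cursor = Cursor::new(buf);
    assert_eq!(cursor.read_u8()?, 5);
    Ok((cursor.read_u64::<BigEndian>()?, cursor.read_u64::<BigEndian>()?))
}

fn main() -> std::io::Result<()> {
    let bytes = encode(1 << 20, 8 << 20)?;
    assert_eq!(decode(&bytes)?, (1 << 20, 8 << 20));
    Ok(())
}
```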
diff --git a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs
index bc44f08c08b89..9f08ff1ee1d8d 100644
--- a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs
+++ b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs
@@ -18,6 +18,7 @@ use std::fmt::Formatter;
 use std::ops::Deref;
 use std::sync::Arc;

+use databend_common_base::runtime::MemStat;
 use databend_common_catalog::cluster_info::Cluster;
 use databend_common_catalog::query_kind::QueryKind;
 use databend_common_catalog::table_context::TableContext;
@@ -157,18 +158,21 @@ impl QueryEnv {
         Ok(())
     }

-    pub async fn create_query_ctx(&self) -> Result<Arc<QueryContext>> {
+    pub async fn create_query_ctx(&self, mem_stat: Arc<MemStat>) -> Result<Arc<QueryContext>> {
         let session_manager = SessionManager::instance();
         let session = session_manager.register_session(
             session_manager.create_with_settings(SessionType::FlightRPC, self.settings.clone())?,
         )?;

-        let query_ctx = session.create_query_context_with_cluster(Arc::new(Cluster {
-            unassign: self.cluster.unassign,
-            nodes: self.cluster.nodes.clone(),
-            local_id: GlobalConfig::instance().query.node_id.clone(),
-        }))?;
+        let query_ctx = session.create_query_context_with_cluster(
+            Arc::new(Cluster {
+                unassign: self.cluster.unassign,
+                nodes: self.cluster.nodes.clone(),
+                local_id: GlobalConfig::instance().query.node_id.clone(),
+            }),
+            Some(mem_stat),
+        )?;

         query_ctx.update_init_query_id(self.query_id.clone());
         query_ctx.attach_query_str(self.query_kind, "".to_string());
diff --git a/src/query/service/src/servers/http/middleware/session.rs b/src/query/service/src/servers/http/middleware/session.rs
index 567e95f4775f5..73d093edc518b 100644
--- a/src/query/service/src/servers/http/middleware/session.rs
+++ b/src/query/service/src/servers/http/middleware/session.rs
@@ -23,6 +23,8 @@ use databend_common_base::headers::HEADER_STICKY;
 use databend_common_base::headers::HEADER_TENANT;
 use databend_common_base::headers::HEADER_VERSION;
 use databend_common_base::headers::HEADER_WAREHOUSE;
+use databend_common_base::runtime::defer;
+use databend_common_base::runtime::MemStat;
 use databend_common_base::runtime::ThreadTracker;
 use databend_common_config::GlobalConfig;
 use databend_common_config::DATABEND_SEMVER;
@@ -599,9 +601,15 @@ impl<E: Endpoint> Endpoint for HTTPSessionEndpoint<E> {
             .map(|id| id.to_str().unwrap().to_string())
             .unwrap_or_else(|| Uuid::new_v4().to_string());

+        let query_mem_stat = MemStat::create(format!("Query-{}", query_id));
         let mut tracking_payload = ThreadTracker::new_tracking_payload();
         tracking_payload.query_id = Some(query_id.clone());
+        tracking_payload.mem_stat = Some(query_mem_stat.clone());
+
         let _guard = ThreadTracker::tracking(tracking_payload);
+        let _guard2 = defer(move || {
+            query_mem_stat.set_limit(0);
+        });

         ThreadTracker::tracking_future(async move {
             match self.auth(&req, query_id).await {
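Aside, not part of the patch: the middleware hunk above uses `defer` to clear the query memory limit no matter how the handler exits. A minimal sketch of what such a helper presumably looks like (the real `databend_common_base::runtime::defer` may differ in details; this is the standard drop-guard idiom):

```rust
// Drop-guard sketch: run a closure when the guard leaves scope, including on
// early returns and unwinding panics.
struct Defer<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> Drop for Defer<F> {
    fn drop(&mut self) {
        if let Some(f) = self.0.take() {
            f();
        }
    }
}

fn defer<F: FnOnce()>(f: F) -> Defer<F> {
    Defer(Some(f))
}

fn main() {
    let _guard = defer(|| println!("set_limit(0) would run here"));
    println!("handler body runs first");
}
```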
diff --git a/src/query/service/src/servers/http/v1/query/execute_state.rs b/src/query/service/src/servers/http/v1/query/execute_state.rs
index 14741394c1bb5..0bc277419f214 100644
--- a/src/query/service/src/servers/http/v1/query/execute_state.rs
+++ b/src/query/service/src/servers/http/v1/query/execute_state.rs
@@ -299,7 +299,7 @@ impl Executor {
             Running(r) => {
                 info!(
                     "{}: http query changing state from Running to Stopped, reason {:?}",
-                    &guard.query_id, reason
+                    &guard.query_id, reason,
                 );
                 if let Err(e) = &reason {
                     if e.code() != ErrorCode::CLOSED_QUERY {
diff --git a/src/query/service/src/servers/mysql/mysql_interactive_worker.rs b/src/query/service/src/servers/mysql/mysql_interactive_worker.rs
index bf64b9bfd9ba9..a12ec677f473b 100644
--- a/src/query/service/src/servers/mysql/mysql_interactive_worker.rs
+++ b/src/query/service/src/servers/mysql/mysql_interactive_worker.rs
@@ -19,6 +19,7 @@ use std::time::Instant;
 use databend_common_base::base::convert_byte_size;
 use databend_common_base::base::convert_number_size;
 use databend_common_base::base::tokio::io::AsyncWrite;
+use databend_common_base::runtime::MemStat;
 use databend_common_base::runtime::ThreadTracker;
 use databend_common_base::runtime::TrySpawn;
 use databend_common_config::DATABEND_COMMIT_VERSION;
@@ -196,6 +197,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for InteractiveWorker<W> {
         let mut tracking_payload = ThreadTracker::new_tracking_payload();
         tracking_payload.query_id = Some(query_id.clone());
+        tracking_payload.mem_stat = Some(MemStat::create(format!("Query-{}", query_id)));
         let _guard = ThreadTracker::tracking(tracking_payload);

         ThreadTracker::tracking_future(async {
@@ -463,6 +465,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> InteractiveWorkerBase<W> {
         let mut tracking_payload = ThreadTracker::new_tracking_payload();
         tracking_payload.query_id = Some(query_id.clone());
+        tracking_payload.mem_stat = Some(MemStat::create(format!("Query-{}", query_id)));
         let _guard = ThreadTracker::tracking(tracking_payload);

         let do_query = ThreadTracker::tracking_future(self.do_query(query_id, &init_query)).await;
diff --git a/src/query/service/src/sessions/mod.rs b/src/query/service/src/sessions/mod.rs
index 3700e5680510b..73f1153935776 100644
--- a/src/query/service/src/sessions/mod.rs
+++ b/src/query/service/src/sessions/mod.rs
@@ -30,6 +30,7 @@ pub use databend_common_catalog::table_context::TableContext;
 pub use query_affect::QueryAffect;
 pub use query_ctx::convert_query_log_timestamp;
 pub use query_ctx::QueryContext;
+pub use query_ctx_shared::MemoryUpdater;
 pub use query_ctx_shared::QueryContextShared;
 pub use queue_mgr::AcquireQueueGuard;
 pub use queue_mgr::QueriesQueueManager;
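Aside, not part of the patch: both MySQL entry points above wrap their work in `ThreadTracker::tracking_future`. The reason is that an async task can resume on a different worker thread, so a thread-local payload would be lost across an `.await`; the payload has to travel with the future and be re-applied on every poll. A simplified model of that wrapper (all names and the `String` payload are stand-ins, not the real API):

```rust
// Sketch: a future that re-installs its tracking payload on each poll.
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

struct TrackedFuture<T: Future> {
    payload: String, // stands in for the real TrackingPayload
    inner: Pin<Box<T>>,
}

impl<T: Future> Future for TrackedFuture<T> {
    type Output = T::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // TrackedFuture is Unpin (it only holds a Box and a String).
        let this = self.get_mut();
        // A real implementation would install `payload` into the thread-local
        // tracker here and restore the previous payload afterwards.
        println!("polling under payload {}", this.payload);
        this.inner.as_mut().poll(cx)
    }
}

fn tracking_future<T: Future>(payload: String, inner: T) -> TrackedFuture<T> {
    TrackedFuture { payload, inner: Box::pin(inner) }
}

fn main() {
    let fut = tracking_future("Query-xyz".to_string(), async { 1 + 1 });
    // Any executor can drive `fut`; the payload is reapplied on every poll.
    let _ = fut;
}
```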
diff --git a/src/query/service/src/sessions/query_ctx.rs b/src/query/service/src/sessions/query_ctx.rs
index 513a806fad156..87e7f2f7c30e8 100644
--- a/src/query/service/src/sessions/query_ctx.rs
+++ b/src/query/service/src/sessions/query_ctx.rs
@@ -37,6 +37,7 @@ use databend_common_base::base::SpillProgress;
 use databend_common_base::runtime::profile::Profile;
 use databend_common_base::runtime::profile::ProfileStatisticsName;
 use databend_common_base::runtime::GlobalIORuntime;
+use databend_common_base::runtime::MemStat;
 use databend_common_base::runtime::TrySpawn;
 use databend_common_base::JoinHandle;
 use databend_common_catalog::catalog::CATALOG_DEFAULT;
@@ -130,6 +131,7 @@ use crate::locks::LockManager;
 use crate::pipelines::executor::PipelineExecutor;
 use crate::servers::flight::v1::exchange::DataExchangeManager;
 use crate::sessions::query_affect::QueryAffect;
+use crate::sessions::query_ctx_shared::MemoryUpdater;
 use crate::sessions::ProcessInfo;
 use crate::sessions::QueriesQueueManager;
 use crate::sessions::QueryContextShared;
@@ -532,6 +534,22 @@ impl QueryContext {
             log::error!("create spill meta file error: {}", e);
         }
     }
+
+    pub fn get_query_memory_tracking(&self) -> Option<Arc<MemStat>> {
+        self.shared.get_query_memory_tracking()
+    }
+
+    pub fn set_query_memory_tracking(&self, mem_stat: Option<Arc<MemStat>>) {
+        self.shared.set_query_memory_tracking(mem_stat)
+    }
+
+    pub fn get_node_memory_updater(&self, node: &str) -> Arc<MemoryUpdater> {
+        self.shared.get_node_memory_updater(node)
+    }
+
+    pub fn get_node_peek_memory_usage(&self) -> HashMap<String, usize> {
+        self.shared.get_nodes_peek_memory_usage()
+    }
 }

 #[async_trait::async_trait]
diff --git a/src/query/service/src/sessions/query_ctx_shared.rs b/src/query/service/src/sessions/query_ctx_shared.rs
index 50341c7f1eea4..b38490d7e6ef3 100644
--- a/src/query/service/src/sessions/query_ctx_shared.rs
+++ b/src/query/service/src/sessions/query_ctx_shared.rs
@@ -15,6 +15,7 @@ use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicBool;
+use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use std::sync::Weak;
@@ -26,6 +27,7 @@ use databend_common_base::base::short_sql;
 use databend_common_base::base::Progress;
 use databend_common_base::base::SpillProgress;
 use databend_common_base::runtime::drop_guard;
+use databend_common_base::runtime::MemStat;
 use databend_common_base::runtime::Runtime;
 use databend_common_catalog::catalog::Catalog;
 use databend_common_catalog::catalog::CatalogManager;
@@ -66,6 +68,11 @@ use crate::sessions::query_affect::QueryAffect;
 use crate::sessions::Session;
 use crate::storages::Table;

+pub struct MemoryUpdater {
+    pub memory_usage: AtomicUsize,
+    pub peek_memory_usage: AtomicUsize,
+}
+
 type DatabaseAndTable = (String, String, String);

 /// Data that needs to be shared in a query context.
@@ -151,6 +158,8 @@ pub struct QueryContextShared {
     pub(in crate::sessions) spilled_files: Arc<RwLock<HashMap<Location, SpilledFile>>>,
     pub(in crate::sessions) unload_callbacked: AtomicBool,
+    pub(in crate::sessions) mem_stat: Arc<RwLock<Option<Arc<MemStat>>>>,
+    pub(in crate::sessions) node_memory_usage: Arc<RwLock<HashMap<String, Arc<MemoryUpdater>>>>,
 }

 impl QueryContextShared {
@@ -212,6 +221,8 @@ impl QueryContextShared {
             spilled_files: Default::default(),
             unload_callbacked: AtomicBool::new(false),
             warehouse_cache: Arc::new(RwLock::new(None)),
+            mem_stat: Arc::new(RwLock::new(None)),
+            node_memory_usage: Arc::new(RwLock::new(HashMap::new())),
         }))
     }
@@ -676,6 +687,71 @@ impl QueryContextShared {
         };
     }
+
+    pub fn set_query_memory_tracking(&self, mem_stat: Option<Arc<MemStat>>) {
+        let mut mem_stat_guard = self.mem_stat.write();
+        *mem_stat_guard = mem_stat;
+    }
+
+    pub fn get_query_memory_tracking(&self) -> Option<Arc<MemStat>> {
+        self.mem_stat.read().clone()
+    }
+
+    pub fn get_node_memory_updater(&self, node: &str) -> Arc<MemoryUpdater> {
+        {
+            if let Some(v) = self.node_memory_usage.read().get(node) {
+                return v.clone();
+            }
+        }
+
+        let key = node.to_string();
+        let node_memory_updater = Arc::new(MemoryUpdater {
+            memory_usage: AtomicUsize::new(0),
+            peek_memory_usage: AtomicUsize::new(0),
+        });
+
+        let mut guard = self.node_memory_usage.write();
+        guard.insert(key, node_memory_updater.clone());
+        node_memory_updater
+    }
+
+    pub fn get_nodes_memory_usage(&self) -> usize {
+        let mut memory_usage = {
+            match self.mem_stat.read().as_ref() {
+                None => 0,
+                Some(mem_stat) => std::cmp::max(0, mem_stat.get_memory_usage()) as usize,
+            }
+        };
+
+        for (_, node_memory_updater) in self.node_memory_usage.read().iter() {
+            memory_usage += node_memory_updater.memory_usage.load(Ordering::Relaxed);
+        }
+
+        memory_usage
+    }
+
+    pub fn get_nodes_peek_memory_usage(&self) -> HashMap<String, usize> {
+        let memory_usage = {
+            match self.mem_stat.read().as_ref() {
+                None => 0,
+                Some(mem_stat) => std::cmp::max(0, mem_stat.get_peek_memory_usage()) as usize,
+            }
+        };
+
+        let mut nodes_peek_memory_usage = HashMap::new();

+        nodes_peek_memory_usage
+            .insert(GlobalConfig::instance().query.node_id.clone(), memory_usage);
+
+        for (node, node_memory_updater) in self.node_memory_usage.read().iter() {
+            let peek_memory_usage = node_memory_updater
+                .peek_memory_usage
+                .load(Ordering::Relaxed);
+            nodes_peek_memory_usage.insert(node.clone(), peek_memory_usage);
+        }
+
+        nodes_peek_memory_usage
+    }
 }

 impl Drop for QueryContextShared {
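Aside, not part of the patch: `get_node_memory_updater` above checks under a read lock, then inserts under a write lock. Two callers can race past the read check and insert twice; the second insert replaces the first, so one caller may keep an updater that is no longer in the map. That is harmless for gauges, but an entry-based variant always hands back the map's canonical instance. A sketch using std locks (the patch presumably uses parking_lot, whose guards need no `unwrap`):

```rust
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, RwLock};

pub struct MemoryUpdater {
    pub memory_usage: AtomicUsize,
    pub peek_memory_usage: AtomicUsize,
}

fn get_or_create(
    map: &RwLock<HashMap<String, Arc<MemoryUpdater>>>,
    node: &str,
) -> Arc<MemoryUpdater> {
    // Fast path: shared lock only.
    if let Some(v) = map.read().unwrap().get(node) {
        return v.clone();
    }

    // Slow path: the entry API makes check-and-insert atomic under one lock.
    map.write()
        .unwrap()
        .entry(node.to_string())
        .or_insert_with(|| {
            Arc::new(MemoryUpdater {
                memory_usage: AtomicUsize::new(0),
                peek_memory_usage: AtomicUsize::new(0),
            })
        })
        .clone()
}

fn main() {
    let map = RwLock::new(HashMap::new());
    let a = get_or_create(&map, "node-1");
    let b = get_or_create(&map, "node-1");
    assert!(Arc::ptr_eq(&a, &b)); // both callers share the canonical instance
}
```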
diff --git a/src/query/service/src/sessions/session.rs b/src/query/service/src/sessions/session.rs
index b154c3ee91c37..29b268e4daf55 100644
--- a/src/query/service/src/sessions/session.rs
+++ b/src/query/service/src/sessions/session.rs
@@ -17,6 +17,8 @@ use std::net::SocketAddr;
 use std::sync::Arc;

 use databend_common_base::runtime::drop_guard;
+use databend_common_base::runtime::MemStat;
+use databend_common_base::runtime::ThreadTracker;
 use databend_common_catalog::cluster_info::Cluster;
 use databend_common_config::GlobalConfig;
 use databend_common_exception::ErrorCode;
@@ -149,16 +151,23 @@ impl Session {
     pub async fn create_query_context(self: &Arc<Self>) -> Result<Arc<QueryContext>> {
         let config = GlobalConfig::instance();
         let cluster = ClusterDiscovery::instance().discover(&config).await?;
-        self.create_query_context_with_cluster(cluster)
+        let mem_stat = ThreadTracker::mem_stat().cloned();
+        self.create_query_context_with_cluster(cluster, mem_stat)
     }

     pub fn create_query_context_with_cluster(
         self: &Arc<Self>,
         cluster: Arc<Cluster>,
+        mem_stat: Option<Arc<MemStat>>,
     ) -> Result<Arc<QueryContext>> {
         let session = self.clone();
         let shared = QueryContextShared::try_create(session, cluster)?;

+        if let Some(mem_stat) = mem_stat {
+            mem_stat.set_limit(self.get_settings().get_max_query_memory_usage()? as i64);
+            shared.set_query_memory_tracking(Some(mem_stat));
+        }
+
         self.session_ctx
             .set_query_context_shared(Arc::downgrade(&shared));
         Ok(QueryContext::create_from_shared(shared))
diff --git a/src/query/service/src/sessions/session_info.rs b/src/query/service/src/sessions/session_info.rs
index dde95863e5b27..a346f81f27066 100644
--- a/src/query/service/src/sessions/session_info.rs
+++ b/src/query/service/src/sessions/session_info.rs
@@ -36,10 +36,7 @@ impl Session {
         let shared_query_context = &session_ctx.get_query_context_shared();

         if let Some(shared) = shared_query_context {
-            if let Some(runtime) = shared.get_runtime() {
-                let mem_stat = runtime.get_tracker();
-                memory_usage = mem_stat.get_memory_usage();
-            }
+            memory_usage = shared.get_nodes_memory_usage();
         }

         ProcessInfo {
@@ -51,7 +48,7 @@ impl Session {
             settings: self.get_settings(),
             client_address: session_ctx.get_client_host(),
             session_extra_info: self.process_extra_info(session_ctx),
-            memory_usage,
+            memory_usage: memory_usage as i64,
             data_metrics: Self::query_data_metrics(session_ctx),
             scan_progress_value: Self::query_scan_progress_value(session_ctx),
             write_progress_value: Self::query_write_progress_value(session_ctx),
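Aside, not part of the patch: `session_info.rs` now reports cluster-wide usage via `get_nodes_memory_usage`, which (per the query_ctx_shared hunk above) sums the local MemStat reading, clamped to zero since the patch's `cmp::max(0, ...)` suggests the counter can transiently read negative, with the last-reported gauge from every remote node. A small self-contained sketch of that aggregation:

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

fn total_memory_usage(
    local_usage: i64,
    remote: &HashMap<String, AtomicUsize>,
) -> usize {
    // Negative local readings are treated as zero rather than wrapping.
    let mut total = std::cmp::max(0, local_usage) as usize;
    for gauge in remote.values() {
        total += gauge.load(Ordering::Relaxed);
    }
    total
}

fn main() {
    let mut remote = HashMap::new();
    remote.insert("node-2".to_string(), AtomicUsize::new(512));
    assert_eq!(total_memory_usage(-100, &remote), 512);
    assert_eq!(total_memory_usage(1024, &remote), 1536);
}
```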
diff --git a/src/query/service/src/sessions/session_mgr.rs b/src/query/service/src/sessions/session_mgr.rs
index d871336f432df..3f7a45f5cc533 100644
--- a/src/query/service/src/sessions/session_mgr.rs
+++ b/src/query/service/src/sessions/session_mgr.rs
@@ -25,6 +25,7 @@ use databend_common_base::base::tokio;
 use databend_common_base::base::GlobalInstance;
 use databend_common_base::base::SignalStream;
 use databend_common_base::runtime::metrics::GLOBAL_METRICS_REGISTRY;
+use databend_common_base::runtime::LimitMemGuard;
 use databend_common_catalog::table_context::ProcessInfoState;
 use databend_common_config::GlobalConfig;
 use databend_common_config::InnerConfig;
@@ -179,6 +180,7 @@ impl SessionManager {
     }

     pub fn destroy_session(&self, session_id: &String) {
+        let _guard = LimitMemGuard::enter_unlimited();
         // NOTE: order and scope of lock are very important. It will cause deadlock

         // stop tracking session
diff --git a/src/query/service/src/stream/processor_executor_stream.rs b/src/query/service/src/stream/processor_executor_stream.rs
index bfa4a01300ae3..4511ec082ca6c 100644
--- a/src/query/service/src/stream/processor_executor_stream.rs
+++ b/src/query/service/src/stream/processor_executor_stream.rs
@@ -24,7 +24,7 @@ use crate::pipelines::executor::PipelinePullingExecutor;

 pub struct PullingExecutorStream {
     end_of_stream: bool,
-    executor: PipelinePullingExecutor,
+    executor: Option<PipelinePullingExecutor>,
 }

 impl PullingExecutorStream {
@@ -32,9 +32,32 @@ impl PullingExecutorStream {
         executor.start();
         Ok(Self {
             end_of_stream: false,
-            executor,
+            executor: Some(executor),
         })
     }
+
+    fn poll_next_impl(&mut self) -> Poll<Option<Result<DataBlock>>> {
+        if let Some(mut executor) = self.executor.take() {
+            return match executor.pull_data() {
+                Err(cause) => {
+                    self.end_of_stream = true;
+                    drop(executor);
+                    Poll::Ready(Some(Err(cause)))
+                }
+                Ok(Some(data)) => {
+                    self.executor = Some(executor);
+                    Poll::Ready(Some(Ok(data)))
+                }
+                Ok(None) => {
+                    self.end_of_stream = true;
+                    drop(executor);
+                    Poll::Ready(None)
+                }
+            };
+        }
+
+        Poll::Ready(None)
+    }
 }

 impl Stream for PullingExecutorStream {
@@ -47,13 +70,6 @@ impl Stream for PullingExecutorStream {
             return Poll::Ready(None);
         }

-        match self_.executor.pull_data() {
-            Err(cause) => {
-                self_.end_of_stream = true;
-                Poll::Ready(Some(Err(cause)))
-            }
-            Ok(Some(data)) => Poll::Ready(Some(Ok(data))),
-            Ok(None) => Poll::Ready(None),
-        }
+        self_.poll_next_impl()
     }
 }
diff --git a/src/query/service/tests/it/storages/testdata/columns_table.txt b/src/query/service/tests/it/storages/testdata/columns_table.txt
index 19547e9cb5847..97e1a68a23e77 100644
--- a/src/query/service/tests/it/storages/testdata/columns_table.txt
+++ b/src/query/service/tests/it/storages/testdata/columns_table.txt
@@ -348,6 +348,7 @@ DB.Table: 'system'.'columns', Table: columns-table_id:1, ver:0, Engine: SystemColumns
 | 'parent_plan_id' | 'system' | 'queries_profiling' | 'Nullable(UInt32)' | 'INT UNSIGNED' | '' | '' | 'YES' | '' |
 | 'partitions_sha' | 'system' | 'query_cache' | 'String' | 'VARCHAR' | '' | '' | 'NO' | '' |
 | 'password_policy' | 'system' | 'users' | 'Nullable(String)' | 'VARCHAR' | '' | '' | 'YES' | '' |
+| 'peek_memory_usage' | 'system' | 'query_log' | 'Variant' | 'VARIANT' | '' | '' | 'NO' | '' |
 | 'plan_id' | 'system' | 'queries_profiling' | 'Nullable(UInt32)' | 'INT UNSIGNED' | '' | '' | 'YES' | '' |
 | 'plan_name' | 'system' | 'queries_profiling' | 'Nullable(String)' | 'VARCHAR' | '' | '' | 'YES' | '' |
 | 'port' | 'system' | 'clusters' | 'UInt16' | 'SMALLINT UNSIGNED' | '' | '' | 'NO' | '' |
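Aside, not part of the patch: `PullingExecutorStream` above moves the executor into an `Option` so it can be taken out, polled, and either put back (still producing) or dropped immediately at a terminal state, releasing what it holds instead of waiting for the stream itself to be dropped. A generic sketch of that take-then-restore, fused-at-end pattern:

```rust
// Sketch: the inner resource lives in an Option so it can be dropped the
// moment the source hits end-of-data (or an error, in the real code).
struct FusedSource<T> {
    inner: Option<T>,
}

impl<I: Iterator> FusedSource<I> {
    fn next_item(&mut self) -> Option<I::Item> {
        if let Some(mut inner) = self.inner.take() {
            return match inner.next() {
                Some(item) => {
                    self.inner = Some(inner); // still alive: put it back
                    Some(item)
                }
                None => None, // terminal: `inner` is dropped right here
            };
        }
        None // already finished; stays finished (fused)
    }
}

fn main() {
    let mut src = FusedSource { inner: Some(vec![1, 2].into_iter()) };
    assert_eq!(src.next_item(), Some(1));
    assert_eq!(src.next_item(), Some(2));
    assert_eq!(src.next_item(), None);
    assert_eq!(src.next_item(), None);
}
```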
diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs
index 51c1f93f35ef7..c16a4eebbf1b0 100644
--- a/src/query/settings/src/settings_default.rs
+++ b/src/query/settings/src/settings_default.rs
@@ -171,6 +171,13 @@ impl DefaultSettings {
                 scope: SettingScope::Both,
                 range: Some(SettingRange::Numeric(0..=u64::MAX)),
             }),
+            ("max_query_memory_usage", DefaultSettingValue {
+                value: UserSettingValue::UInt64(0),
+                desc: "The maximum memory usage for a query. If set to 0, memory usage is unlimited. This setting supersedes the older max_memory_usage setting.",
+                mode: SettingMode::Both,
+                scope: SettingScope::Both,
+                range: Some(SettingRange::Numeric(0..=u64::MAX)),
+            }),
             ("data_retention_time_in_days", DefaultSettingValue {
                 // unit of retention_period is day
                 value: UserSettingValue::UInt64(1),
@@ -1200,7 +1207,6 @@ impl DefaultSettings {
                 scope: SettingScope::Both,
                 range: Some(SettingRange::Numeric(0..=1)),
             }),
-
         ]);

         Ok(Arc::new(DefaultSettings {
diff --git a/src/query/settings/src/settings_getter_setter.rs b/src/query/settings/src/settings_getter_setter.rs
index 94a04ec1f8526..96e700fe373b6 100644
--- a/src/query/settings/src/settings_getter_setter.rs
+++ b/src/query/settings/src/settings_getter_setter.rs
@@ -882,4 +882,12 @@ impl Settings {
     pub fn get_copy_dedup_full_path_by_default(&self) -> Result<bool> {
         Ok(self.try_get_u64("copy_dedup_full_path_by_default")? == 1)
     }
+
+    pub fn get_max_query_memory_usage(&self) -> Result<u64> {
+        self.try_get_u64("max_query_memory_usage")
+    }
+
+    pub fn set_max_query_memory_usage(&self, max_memory_usage: u64) -> Result<()> {
+        self.try_set_u64("max_query_memory_usage", max_memory_usage)
+    }
 }
diff --git a/src/query/storages/system/src/metrics_table.rs b/src/query/storages/system/src/metrics_table.rs
index 4921929002785..cefadde6e3c04 100644
--- a/src/query/storages/system/src/metrics_table.rs
+++ b/src/query/storages/system/src/metrics_table.rs
@@ -138,18 +138,11 @@ impl MetricsTable {
     /// Custom metrics that are not collected by prometheus.
     fn custom_metric_samples(&self) -> Result<Vec<MetricSample>> {
-        let samples = vec![
-            MetricSample {
-                name: "query_memory_usage_bytes".to_string(),
-                value: MetricValue::Counter(GLOBAL_MEM_STAT.get_memory_usage() as f64),
-                labels: HashMap::new(),
-            },
-            MetricSample {
-                name: "query_memory_peak_usage_bytes".to_string(),
-                value: MetricValue::Counter(GLOBAL_MEM_STAT.get_peak_memory_usage() as f64),
-                labels: HashMap::new(),
-            },
-        ];
+        let samples = vec![MetricSample {
+            name: "query_memory_usage_bytes".to_string(),
+            value: MetricValue::Counter(GLOBAL_MEM_STAT.get_memory_usage() as f64),
+            labels: HashMap::new(),
+        }];

         Ok(samples)
     }
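Aside, not part of the patch: the setting and its accessors above feed the limit applied in `session.rs` earlier in this diff, where 0 disables enforcement and any other value becomes the MemStat limit via a `u64 -> i64` cast. A tiny sketch of that decision, with `apply_query_memory_limit` and its closure parameter being hypothetical names for illustration:

```rust
fn apply_query_memory_limit(max_query_memory_usage: u64, set_limit: impl Fn(i64)) {
    if max_query_memory_usage != 0 {
        // Mirrors the patch's cast; values above i64::MAX would come out
        // negative, but the setting's numeric range makes that theoretical.
        set_limit(max_query_memory_usage as i64);
    }
}

fn main() {
    apply_query_memory_limit(0, |_| unreachable!("0 means unlimited"));
    apply_query_memory_limit(4 << 30, |limit| assert_eq!(limit, 4 << 30));
}
```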
TableField::new("extra", TableDataType::String), TableField::new("has_profile", TableDataType::Boolean), + TableField::new("peek_memory_usage", TableDataType::Variant), ]) } @@ -568,6 +572,17 @@ impl SystemLogElement for QueryLogElement { .next() .unwrap() .push(Scalar::Boolean(self.has_profiles).as_ref()); + columns.next().unwrap().push( + Scalar::Variant( + jsonb::Value::from(jsonb::Object::from_iter( + self.peek_memory_usage + .iter() + .map(|(k, v)| (k.clone(), jsonb::Value::from(*v))), + )) + .to_vec(), + ) + .as_ref(), + ); Ok(()) } } diff --git a/tests/sqllogictests/suites/mode/cluster/explain_v2.test b/tests/sqllogictests/suites/mode/cluster/explain_v2.test index 257be5c94406a..8b47b800a9756 100644 --- a/tests/sqllogictests/suites/mode/cluster/explain_v2.test +++ b/tests/sqllogictests/suites/mode/cluster/explain_v2.test @@ -434,21 +434,21 @@ query T EXPLAIN SELECT a.cluster_node, b.query_node FROM (SELECT name as cluster_node FROM system.clusters) AS a LEFT JOIN (SELECT DISTINCT node_id as query_node FROM system.query_log) AS b ON a.cluster_node = b.query_node ---- Exchange -├── output columns: [clusters.name (#0), query_node (#63)] +├── output columns: [clusters.name (#0), query_node (#64)] ├── exchange type: Merge └── HashJoin - ├── output columns: [clusters.name (#0), query_node (#63)] + ├── output columns: [clusters.name (#0), query_node (#64)] ├── join type: LEFT OUTER - ├── build keys: [b.query_node (#63)] + ├── build keys: [b.query_node (#64)] ├── probe keys: [CAST(a.cluster_node (#0) AS String NULL)] ├── keys is null equal: [false] ├── filters: [] ├── estimated rows: 0.00 ├── Exchange(Build) - │ ├── output columns: [query_node (#63)] + │ ├── output columns: [query_node (#64)] │ ├── exchange type: Broadcast │ └── EvalScalar - │ ├── output columns: [query_node (#63)] + │ ├── output columns: [query_node (#64)] │ ├── expressions: [CAST(b.query_node (#10) AS String NULL)] │ ├── estimated rows: 0.00 │ └── AggregateFinal