wdk_mutex/
fast_mutex.rs

//! An idiomatic Rust wrapper around the Windows kernel `FAST_MUTEX` that protects an inner value of type `T`.
2
3use alloc::boxed::Box;
4use core::{
5    ffi::c_void, fmt::Display, mem::ManuallyDrop, ops::{Deref, DerefMut}, ptr::{self, drop_in_place}
6};
7use wdk_sys::{
8    ntddk::{
9        ExAcquireFastMutex, ExAllocatePool2, ExFreePool, ExReleaseFastMutex, KeGetCurrentIrql,
10        KeInitializeEvent,
11    },
12    APC_LEVEL, DISPATCH_LEVEL, FALSE, FAST_MUTEX, FM_LOCK_BIT, POOL_FLAG_NON_PAGED,
13    _EVENT_TYPE::SynchronizationEvent,
14};
15
16extern crate alloc;
17
18use crate::errors::DriverMutexError;
19
20/// An internal binding for the ExInitializeFastMutex routine.
21///
22/// # Safety
23///
24/// This function does not check the IRQL as the only place this function is used is in an area where the IRQL
25/// is already checked.
26#[allow(non_snake_case)]
27unsafe fn ExInitializeFastMutex(fast_mutex: *mut FAST_MUTEX) {
28    core::ptr::write_volatile(&mut (*fast_mutex).Count, FM_LOCK_BIT as i32);
29
30    (*fast_mutex).Owner = core::ptr::null_mut();
31    (*fast_mutex).Contention = 0;
32    KeInitializeEvent(&mut (*fast_mutex).Event, SynchronizationEvent, FALSE as _)
33}
34
35/// A thread safe mutex implemented through acquiring a `FAST_MUTEX` in the Windows kernel.
36///
37/// The type `FastMutex<T>` provides mutually exclusive access to the inner type T allocated through
38/// this crate in the non-paged pool. All data required to initialise the FastMutex is allocated in the
39/// non-paged pool and as such is safe to pass stack data into the type as it will not go out of scope.
40///
41/// `FastMutex` holds an inner value which is a pointer to a `FastMutexInner` type which is the actual type
42/// allocated in the non-paged pool, and this holds information relating to the mutex.
43///
44/// Access to the `T` within the `FastMutex` can be done through calling [`Self::lock`].
45///
46/// # Lifetimes
47///
48/// As the `FastMutex` is designed to be used in the Windows Kernel, with the Windows `wdk` crate, the lifetimes of
49/// the `FastMutex` must be considered by the caller. See examples below for usage.
50///
51/// The `FastMutex` can exist in a locally scoped function with little additional configuration. To use the mutex across
52/// thread boundaries, or to use it in callback functions, you can use the `Grt` module found in this crate. See below for
53/// details.
54///
55/// # Deallocation
56///
57/// FastMutex handles the deallocation of resources at the point the FastMutex is dropped.
58///
59/// # Examples
60///
61/// ## Locally scoped mutex:
62///
63/// ```
64/// {
65///     let mtx = FastMutex::new(0u32).unwrap();
66///     let lock = mtx.lock().unwrap();
67///
68///     // If T implements display, you do not need to dereference the lock to print.
69///     println!("The value is: {}", lock);
70/// } // Mutex will become unlocked as it is managed via RAII
71/// ```
72///
73/// ## Global scope via the `Grt` module in `wdk-mutex`:
74///
75/// ```
76/// // Initialise the mutex on DriverEntry
77///
78/// #[export_name = "DriverEntry"]
79/// pub unsafe extern "system" fn driver_entry(
80///     driver: &mut DRIVER_OBJECT,
81///     registry_path: PCUNICODE_STRING,
82/// ) -> NTSTATUS {
83///     if let Err(e) = Grt::init() {
84///         println!("Error creating Grt!: {:?}", e);
85///         return STATUS_UNSUCCESSFUL;
86///     }
87///
88///     // ...
89///     my_function();
90/// }
91///
92///
93/// // Register a new Mutex in the `Grt` of value 0u32:
94///
95/// pub fn my_function() {
96///     Grt::register_fast_mutex("my_test_mutex", 0u32);
97/// }
98///
99/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
100///     let my_mutex = Grt::get_fast_mutex::<u32>("my_test_mutex");
101///     if let Err(e) = my_mutex {
102///         println!("Error in thread: {:?}", e);
103///         return;
104///     }
105///
106///     let mut lock = my_mutex.unwrap().lock().unwrap();
107///     *lock += 1;
108/// }
109///
110///
111/// // Destroy the Grt to prevent memory leak on DriverExit
112///
113/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
114///     unsafe {Grt::destroy()};
115/// }
116/// ```
117pub struct FastMutex<T> {
118    inner: *mut FastMutexInner<T>,
119}
120
121/// The underlying data which is non-page pool allocated which is pointed to by the `FastMutex`.
122struct FastMutexInner<T> {
123    mutex: FAST_MUTEX,
124    /// The data for which the mutex is protecting
125    data: T,
126}
127
128unsafe impl<T> Sync for FastMutex<T> {}
129unsafe impl<T> Send for FastMutex<T> {}
130
131impl<T> FastMutex<T> {
132    /// Creates a new `FAST_MUTEX` Windows Kernel Driver Mutex.
133    ///
134    /// # IRQL
135    ///
136    /// This can be called at IRQL <= DISPATCH_LEVEL.
137    ///
138    /// # Examples
139    ///
140    /// ```
141    /// use wdk_mutex::Mutex;
142    ///
143    /// let my_mutex = wdk_mutex::FastMutex::new(0u32);
144    /// ```
145    pub fn new(data: T) -> Result<Self, DriverMutexError> {
146        // This can only be called at a level <= DISPATCH_LEVEL; check current IRQL
147        // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exinitializefastmutex
148        if unsafe { KeGetCurrentIrql() } > DISPATCH_LEVEL as u8 {
149            return Err(DriverMutexError::IrqlTooHigh);
150        }
151
152        //
153        // Non-Paged heap alloc for all struct data required for FastMutexInner
154        //
155        let total_sz_required = size_of::<FastMutexInner<T>>();
156        let inner_heap_ptr: *mut c_void = unsafe {
157            ExAllocatePool2(
158                POOL_FLAG_NON_PAGED,
159                total_sz_required as u64,
160                u32::from_be_bytes(*b"kmtx"),
161            )
162        };
163        if inner_heap_ptr.is_null() {
164            return Err(DriverMutexError::PagedPoolAllocFailed);
165        }
166
167        // Cast the memory allocation to a pointer to the inner
168        let fast_mtx_inner_ptr = inner_heap_ptr as *mut FastMutexInner<T>;
169
170        // SAFETY: This raw write is safe as the pointer validity is checked above.
171        unsafe {
172            ptr::write(
173                fast_mtx_inner_ptr,
174                FastMutexInner {
175                    mutex: FAST_MUTEX::default(),
176                    data,
177                },
178            );
179
180            // Initialise the FastMutex object via the kernel
181            ExInitializeFastMutex(&mut (*fast_mtx_inner_ptr).mutex);
182        }
183
184        Ok(Self {
185            inner: fast_mtx_inner_ptr,
186        })
187    }
188
189    /// Acquires the mutex, raising the IRQL to `APC_LEVEL`.
190    ///
191    /// Once the thread has acquired the mutex, it will return a `FastMutexGuard` which is a RAII scoped
192    /// guard allowing exclusive access to the inner T.
193    ///
194    /// # Errors
195    ///
196    /// If the IRQL is too high, this function will return an error and will not acquire a lock. To prevent
197    /// a kernel panic, the caller should match the return value rather than just unwrapping the value.
198    ///
199    /// # IRQL
200    ///
201    /// This function must be called at IRQL `<= APC_LEVEL`, if the IRQL is higher than this,
202    /// the function will return an error.
203    ///
204    /// It is the callers responsibility to ensure the IRQL is sufficient to call this function and it
205    /// will not alter the IRQL for the caller, as this may introduce undefined behaviour elsewhere in the
206    /// driver / kernel.
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// let mtx = FastMutex::new(0u32).unwrap();
212    /// let lock = mtx.lock().unwrap();
213    /// ```
214    pub fn lock(&self) -> Result<FastMutexGuard<'_, T>, DriverMutexError> {
215        // Check the IRQL is <= APC_LEVEL as per remarks at
216        // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exacquirefastmutex
217        let irql = unsafe { KeGetCurrentIrql() };
218        if irql > APC_LEVEL as u8 {
219            return Err(DriverMutexError::IrqlTooHigh);
220        }
221
222        // SAFETY: RAII manages pointer validity and IRQL checked.
223        unsafe { ExAcquireFastMutex(&mut (*self.inner).mutex as *mut _ as *mut _) };
224
225        Ok(FastMutexGuard { fast_mutex: self })
226    }
227
228    /// Consumes the mutex and returns an owned copy of the protected data (`T`).
229    ///
230    /// This method performs a deep copy of the data (`T`) guarded by the mutex before
231    /// deallocating the internal memory. Be cautious when using this method with large
232    /// data types, as it may lead to inefficiencies or stack overflows.
233    ///
234    /// For scenarios involving large data that you prefer not to allocate on the stack,
235    /// consider using [`Self::to_owned_box`] instead.
236    ///
237    /// # Safety
238    /// 
239    /// This function moves `T` out of the mutex without running `Drop` on the original 
240    /// in-place value. The returned `T` remains fully owned by the caller and will be 
241    /// dropped normally.
242    ///
243    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned`], ensure that
244    ///   no other references (especially static or global ones) attempt to access the
245    ///   underlying mutex. This is because the mutexes memory is deallocated once this
246    ///   method is invoked.
247    /// - **Exclusive Access:** This function should only be called when you can guarantee
248    ///   that there will be no further access to the protected `T`. Violating this can
249    ///   lead to undefined behavior since the memory is freed after the call.
250    /// 
251    /// # Example
252    ///
253    /// ```
254    /// unsafe {
255    ///     let owned_data: T = mutex.to_owned();
256    ///     // Use `owned_data` safely here
257    /// }
258    /// ```
259    pub unsafe fn to_owned(self) -> T {
260        let manually_dropped = ManuallyDrop::new(self);
261        let data_read = unsafe { ptr::read(&(*manually_dropped.inner).data) };
262        
263        // Free the mutex allocation without using drop semantics which could cause an
264        // accidental double drop of the underlying `T`.
265        unsafe { ExFreePool(manually_dropped.inner as _) };
266
267        data_read
268    }
269
270    /// Consumes the mutex and returns an owned `Box<T>` containing the protected data (`T`). 
271    ///
272    /// This method is an alternative to [`Self::to_owned`] and is particularly useful when
273    /// dealing with large data types. By returning a `Box<T>`, the data is pool-allocated,
274    /// avoiding potential stack overflows associated with large stack allocations.
275    ///
276    /// # Safety
277    ///
278    /// This function moves `T` out of the mutex without running `Drop` on the original 
279    /// in-place value. The returned `T` remains fully owned by the caller and will be 
280    /// dropped normally.
281    /// 
282    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned_box`], ensure that
283    /// no other references (especially static or global ones) attempt to access the
284    /// underlying mutex. This is because the mutexes memory is deallocated once this
285    /// method is invoked.
286    /// - **Exclusive Access:** This function should only be called when you can guarantee
287    /// that there will be no further access to the protected `T`. Violating this can
288    /// lead to undefined behavior since the memory is freed after the call.
289    /// 
290    /// # Example
291    ///
292    /// ```rust
293    /// unsafe {
294    ///     let boxed_data: Box<T> = mutex.to_owned_box();
295    ///     // Use `boxed_data` safely here
296    /// }
297    /// ```
298    pub unsafe fn to_owned_box(self) -> Box<T> {
299        let manually_dropped = ManuallyDrop::new(self);
300        let data_read = unsafe { ptr::read(&(*manually_dropped.inner).data) };
301        
302        // Free the mutex allocation without using drop semantics which could cause an
303        // accidental double drop of the underlying `T`.
304        unsafe { ExFreePool(manually_dropped.inner as _) };
305
306        Box::new(data_read)
307    }
308}
309
310impl<T> Drop for FastMutex<T> {
311    fn drop(&mut self) {
312        unsafe {
313            // Drop the underlying data and run destructors for the data, this would be relevant in the
314            // case where Self contains other heap allocated types which have their own deallocation
315            // methods.
316            drop_in_place(&mut (*self.inner).data);
317
318            // Free the memory we allocated
319            ExFreePool(self.inner as *mut _);
320        }
321    }
322}
323
324/// A RAII scoped guard for the inner data protected by the mutex. Once this guard is given out, the protected data
325/// may be safely mutated by the caller as we guarantee exclusive access via Windows Kernel Mutex primitives.
326///
327/// When this structure is dropped (falls out of scope), the lock will be unlocked.
328///
329/// # IRQL
330///
331/// Access to the data within this guard must be done at `APC_LEVEL` It is the callers responsible to
332/// manage IRQL whilst using the `FastMutex`. On calling [`FastMutex::lock`], the IRQL will automatically
333/// be raised to `APC_LEVEL`.
334///
335/// If you wish to manually drop the lock with a safety check, call the function [`Self::drop_safe`].
336///
337/// # Kernel panic
338///
339/// Raising the IRQL above safe limits whilst using the mutex will cause a Kernel Panic if not appropriately handled.
340///
341pub struct FastMutexGuard<'a, T> {
342    fast_mutex: &'a FastMutex<T>,
343}
344
345impl<T> Display for FastMutexGuard<'_, T>
346where
347    T: Display,
348{
349    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
350        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
351        write!(f, "{}", unsafe { &(*self.fast_mutex.inner).data })
352    }
353}
354
355impl<T> Deref for FastMutexGuard<'_, T> {
356    type Target = T;
357
358    fn deref(&self) -> &Self::Target {
359        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
360        unsafe { &(*self.fast_mutex.inner).data }
361    }
362}
363
364impl<T> DerefMut for FastMutexGuard<'_, T> {
365    fn deref_mut(&mut self) -> &mut Self::Target {
366        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
367        // Mutable access is safe due to Self only being given out whilst a mutex is held from the
368        // kernel.
369        unsafe { &mut (*self.fast_mutex.inner).data }
370    }
371}
372
373impl<T> Drop for FastMutexGuard<'_, T> {
374    fn drop(&mut self) {
375        // NOT SAFE AT AN INVALID IRQL
376        unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
377    }
378}
379
380impl<T> FastMutexGuard<'_, T> {
381    /// Safely drop the `FastMutexGuard`, an alternative to RAII.
382    ///
383    /// This function checks the IRQL before attempting to drop the guard.
384    ///
385    /// # Errors
386    ///
387    /// If the IRQL != `APC_LEVEL`, no unlock will occur and a DriverMutexError will be returned to the
388    /// caller.
389    ///
390    /// # IRQL
391    ///
392    /// This function must be called at `APC_LEVEL`
393    pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {
394        let irql = unsafe { KeGetCurrentIrql() };
395        if irql != APC_LEVEL as u8 {
396            return Err(DriverMutexError::IrqlTooHigh);
397        }
398
399        unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
400
401        Ok(())
402    }
403}