wdk_mutex/fast_mutex.rs
1//! A Rust idiomatic Windows Kernel Driver FAST_MUTEX type which protects the inner type T
2
3use alloc::boxed::Box;
4use core::{
5 ffi::c_void, fmt::Display, mem::ManuallyDrop, ops::{Deref, DerefMut}, ptr::{self, drop_in_place}
6};
7use wdk_sys::{
8 ntddk::{
9 ExAcquireFastMutex, ExAllocatePool2, ExFreePool, ExReleaseFastMutex, KeGetCurrentIrql,
10 KeInitializeEvent,
11 },
12 APC_LEVEL, DISPATCH_LEVEL, FALSE, FAST_MUTEX, FM_LOCK_BIT, POOL_FLAG_NON_PAGED,
13 _EVENT_TYPE::SynchronizationEvent,
14};
15
16extern crate alloc;
17
18use crate::errors::DriverMutexError;
19
20/// An internal binding for the ExInitializeFastMutex routine.
21///
22/// # Safety
23///
24/// This function does not check the IRQL as the only place this function is used is in an area where the IRQL
25/// is already checked.
26#[allow(non_snake_case)]
27unsafe fn ExInitializeFastMutex(fast_mutex: *mut FAST_MUTEX) {
28 core::ptr::write_volatile(&mut (*fast_mutex).Count, FM_LOCK_BIT as i32);
29
30 (*fast_mutex).Owner = core::ptr::null_mut();
31 (*fast_mutex).Contention = 0;
32 KeInitializeEvent(&mut (*fast_mutex).Event, SynchronizationEvent, FALSE as _)
33}
34
/// A thread safe mutex implemented through acquiring a `FAST_MUTEX` in the Windows kernel.
///
/// The type `FastMutex<T>` provides mutually exclusive access to the inner type T allocated through
/// this crate in the non-paged pool. All data required to initialise the FastMutex is allocated in the
/// non-paged pool and as such is safe to pass stack data into the type as it will not go out of scope.
///
/// `FastMutex` holds an inner value which is a pointer to a `FastMutexInner` type which is the actual type
/// allocated in the non-paged pool, and this holds information relating to the mutex.
///
/// Access to the `T` within the `FastMutex` can be done through calling [`Self::lock`].
///
/// # Lifetimes
///
/// As the `FastMutex` is designed to be used in the Windows Kernel, with the Windows `wdk` crate, the lifetimes of
/// the `FastMutex` must be considered by the caller. See examples below for usage.
///
/// The `FastMutex` can exist in a locally scoped function with little additional configuration. To use the mutex across
/// thread boundaries, or to use it in callback functions, you can use the `Grt` module found in this crate. See below for
/// details.
///
/// # Deallocation
///
/// FastMutex handles the deallocation of resources at the point the FastMutex is dropped.
///
/// # Examples
///
/// ## Locally scoped mutex:
///
/// ```
/// {
///     let mtx = FastMutex::new(0u32).unwrap();
///     let lock = mtx.lock().unwrap();
///
///     // If T implements display, you do not need to dereference the lock to print.
///     println!("The value is: {}", lock);
/// } // Mutex will become unlocked as it is managed via RAII
/// ```
///
/// ## Global scope via the `Grt` module in `wdk-mutex`:
///
/// ```
/// // Initialise the mutex on DriverEntry
///
/// #[export_name = "DriverEntry"]
/// pub unsafe extern "system" fn driver_entry(
///     driver: &mut DRIVER_OBJECT,
///     registry_path: PCUNICODE_STRING,
/// ) -> NTSTATUS {
///     if let Err(e) = Grt::init() {
///         println!("Error creating Grt!: {:?}", e);
///         return STATUS_UNSUCCESSFUL;
///     }
///
///     // ...
///     my_function();
/// }
///
///
/// // Register a new Mutex in the `Grt` of value 0u32:
///
/// pub fn my_function() {
///     Grt::register_fast_mutex("my_test_mutex", 0u32);
/// }
///
/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
///     let my_mutex = Grt::get_fast_mutex::<u32>("my_test_mutex");
///     if let Err(e) = my_mutex {
///         println!("Error in thread: {:?}", e);
///         return;
///     }
///
///     let mut lock = my_mutex.unwrap().lock().unwrap();
///     *lock += 1;
/// }
///
///
/// // Destroy the Grt to prevent memory leak on DriverExit
///
/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
///     unsafe {Grt::destroy()};
/// }
/// ```
pub struct FastMutex<T> {
    // Raw pointer to the single non-paged pool allocation that holds both the
    // kernel FAST_MUTEX object and the protected `T`. Owned by this struct and
    // freed in `Drop` (or by `to_owned`/`to_owned_box`).
    inner: *mut FastMutexInner<T>,
}
120
/// The underlying data which is non-page pool allocated which is pointed to by the `FastMutex`.
struct FastMutexInner<T> {
    // The kernel FAST_MUTEX object. It must reside in non-paged memory, which is
    // guaranteed by the POOL_FLAG_NON_PAGED allocation performed in `FastMutex::new`.
    mutex: FAST_MUTEX,
    /// The data for which the mutex is protecting
    data: T,
}
127
// SAFETY: access to the inner `T` is only handed out via `FastMutexGuard`, which is
// obtained through `ExAcquireFastMutex`, so cross-thread access to the data is
// serialised by the kernel mutex.
// NOTE(review): these impls are unconditional on `T`; `std::sync::Mutex` bounds its
// equivalent impls on `T: Send` to remain sound for non-sendable types — confirm
// whether a `T: Send` bound should be added here (it would be a breaking change).
unsafe impl<T> Sync for FastMutex<T> {}
unsafe impl<T> Send for FastMutex<T> {}
130
131impl<T> FastMutex<T> {
132 /// Creates a new `FAST_MUTEX` Windows Kernel Driver Mutex.
133 ///
134 /// # IRQL
135 ///
136 /// This can be called at IRQL <= DISPATCH_LEVEL.
137 ///
138 /// # Examples
139 ///
140 /// ```
141 /// use wdk_mutex::Mutex;
142 ///
143 /// let my_mutex = wdk_mutex::FastMutex::new(0u32);
144 /// ```
145 pub fn new(data: T) -> Result<Self, DriverMutexError> {
146 // This can only be called at a level <= DISPATCH_LEVEL; check current IRQL
147 // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exinitializefastmutex
148 if unsafe { KeGetCurrentIrql() } > DISPATCH_LEVEL as u8 {
149 return Err(DriverMutexError::IrqlTooHigh);
150 }
151
152 //
153 // Non-Paged heap alloc for all struct data required for FastMutexInner
154 //
155 let total_sz_required = size_of::<FastMutexInner<T>>();
156 let inner_heap_ptr: *mut c_void = unsafe {
157 ExAllocatePool2(
158 POOL_FLAG_NON_PAGED,
159 total_sz_required as u64,
160 u32::from_be_bytes(*b"kmtx"),
161 )
162 };
163 if inner_heap_ptr.is_null() {
164 return Err(DriverMutexError::PagedPoolAllocFailed);
165 }
166
167 // Cast the memory allocation to a pointer to the inner
168 let fast_mtx_inner_ptr = inner_heap_ptr as *mut FastMutexInner<T>;
169
170 // SAFETY: This raw write is safe as the pointer validity is checked above.
171 unsafe {
172 ptr::write(
173 fast_mtx_inner_ptr,
174 FastMutexInner {
175 mutex: FAST_MUTEX::default(),
176 data,
177 },
178 );
179
180 // Initialise the FastMutex object via the kernel
181 ExInitializeFastMutex(&mut (*fast_mtx_inner_ptr).mutex);
182 }
183
184 Ok(Self {
185 inner: fast_mtx_inner_ptr,
186 })
187 }
188
189 /// Acquires the mutex, raising the IRQL to `APC_LEVEL`.
190 ///
191 /// Once the thread has acquired the mutex, it will return a `FastMutexGuard` which is a RAII scoped
192 /// guard allowing exclusive access to the inner T.
193 ///
194 /// # Errors
195 ///
196 /// If the IRQL is too high, this function will return an error and will not acquire a lock. To prevent
197 /// a kernel panic, the caller should match the return value rather than just unwrapping the value.
198 ///
199 /// # IRQL
200 ///
201 /// This function must be called at IRQL `<= APC_LEVEL`, if the IRQL is higher than this,
202 /// the function will return an error.
203 ///
204 /// It is the callers responsibility to ensure the IRQL is sufficient to call this function and it
205 /// will not alter the IRQL for the caller, as this may introduce undefined behaviour elsewhere in the
206 /// driver / kernel.
207 ///
208 /// # Examples
209 ///
210 /// ```
211 /// let mtx = FastMutex::new(0u32).unwrap();
212 /// let lock = mtx.lock().unwrap();
213 /// ```
214 pub fn lock(&self) -> Result<FastMutexGuard<'_, T>, DriverMutexError> {
215 // Check the IRQL is <= APC_LEVEL as per remarks at
216 // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exacquirefastmutex
217 let irql = unsafe { KeGetCurrentIrql() };
218 if irql > APC_LEVEL as u8 {
219 return Err(DriverMutexError::IrqlTooHigh);
220 }
221
222 // SAFETY: RAII manages pointer validity and IRQL checked.
223 unsafe { ExAcquireFastMutex(&mut (*self.inner).mutex as *mut _ as *mut _) };
224
225 Ok(FastMutexGuard { fast_mutex: self })
226 }
227
228 /// Consumes the mutex and returns an owned copy of the protected data (`T`).
229 ///
230 /// This method performs a deep copy of the data (`T`) guarded by the mutex before
231 /// deallocating the internal memory. Be cautious when using this method with large
232 /// data types, as it may lead to inefficiencies or stack overflows.
233 ///
234 /// For scenarios involving large data that you prefer not to allocate on the stack,
235 /// consider using [`Self::to_owned_box`] instead.
236 ///
237 /// # Safety
238 ///
239 /// This function moves `T` out of the mutex without running `Drop` on the original
240 /// in-place value. The returned `T` remains fully owned by the caller and will be
241 /// dropped normally.
242 ///
243 /// - **Single Ownership Guarantee:** After calling [`Self::to_owned`], ensure that
244 /// no other references (especially static or global ones) attempt to access the
245 /// underlying mutex. This is because the mutexes memory is deallocated once this
246 /// method is invoked.
247 /// - **Exclusive Access:** This function should only be called when you can guarantee
248 /// that there will be no further access to the protected `T`. Violating this can
249 /// lead to undefined behavior since the memory is freed after the call.
250 ///
251 /// # Example
252 ///
253 /// ```
254 /// unsafe {
255 /// let owned_data: T = mutex.to_owned();
256 /// // Use `owned_data` safely here
257 /// }
258 /// ```
259 pub unsafe fn to_owned(self) -> T {
260 let manually_dropped = ManuallyDrop::new(self);
261 let data_read = unsafe { ptr::read(&(*manually_dropped.inner).data) };
262
263 // Free the mutex allocation without using drop semantics which could cause an
264 // accidental double drop of the underlying `T`.
265 unsafe { ExFreePool(manually_dropped.inner as _) };
266
267 data_read
268 }
269
270 /// Consumes the mutex and returns an owned `Box<T>` containing the protected data (`T`).
271 ///
272 /// This method is an alternative to [`Self::to_owned`] and is particularly useful when
273 /// dealing with large data types. By returning a `Box<T>`, the data is pool-allocated,
274 /// avoiding potential stack overflows associated with large stack allocations.
275 ///
276 /// # Safety
277 ///
278 /// This function moves `T` out of the mutex without running `Drop` on the original
279 /// in-place value. The returned `T` remains fully owned by the caller and will be
280 /// dropped normally.
281 ///
282 /// - **Single Ownership Guarantee:** After calling [`Self::to_owned_box`], ensure that
283 /// no other references (especially static or global ones) attempt to access the
284 /// underlying mutex. This is because the mutexes memory is deallocated once this
285 /// method is invoked.
286 /// - **Exclusive Access:** This function should only be called when you can guarantee
287 /// that there will be no further access to the protected `T`. Violating this can
288 /// lead to undefined behavior since the memory is freed after the call.
289 ///
290 /// # Example
291 ///
292 /// ```rust
293 /// unsafe {
294 /// let boxed_data: Box<T> = mutex.to_owned_box();
295 /// // Use `boxed_data` safely here
296 /// }
297 /// ```
298 pub unsafe fn to_owned_box(self) -> Box<T> {
299 let manually_dropped = ManuallyDrop::new(self);
300 let data_read = unsafe { ptr::read(&(*manually_dropped.inner).data) };
301
302 // Free the mutex allocation without using drop semantics which could cause an
303 // accidental double drop of the underlying `T`.
304 unsafe { ExFreePool(manually_dropped.inner as _) };
305
306 Box::new(data_read)
307 }
308}
309
impl<T> Drop for FastMutex<T> {
    /// Runs the destructor for the protected `T` and then frees the non-paged pool
    /// allocation that backs the whole `FastMutexInner`.
    fn drop(&mut self) {
        unsafe {
            // Drop the underlying data and run destructors for the data, this would be relevant in the
            // case where Self contains other heap allocated types which have their own deallocation
            // methods.
            drop_in_place(&mut (*self.inner).data);

            // Free the memory we allocated
            ExFreePool(self.inner as *mut _);
        }
    }
}
323
/// A RAII scoped guard for the inner data protected by the mutex. Once this guard is given out, the protected data
/// may be safely mutated by the caller as we guarantee exclusive access via Windows Kernel Mutex primitives.
///
/// When this structure is dropped (falls out of scope), the lock will be unlocked.
///
/// # IRQL
///
/// Access to the data within this guard must be done at `APC_LEVEL`. It is the caller's responsibility to
/// manage IRQL whilst using the `FastMutex`. On calling [`FastMutex::lock`], the IRQL will automatically
/// be raised to `APC_LEVEL`.
///
/// If you wish to manually drop the lock with a safety check, call the function [`Self::drop_safe`].
///
/// # Kernel panic
///
/// Raising the IRQL above safe limits whilst using the mutex will cause a Kernel Panic if not appropriately handled.
///
pub struct FastMutexGuard<'a, T> {
    // Borrow of the owning mutex; keeps the underlying pool allocation alive for
    // the whole lifetime of the guard.
    fast_mutex: &'a FastMutex<T>,
}
344
345impl<T> Display for FastMutexGuard<'_, T>
346where
347 T: Display,
348{
349 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
350 // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
351 write!(f, "{}", unsafe { &(*self.fast_mutex.inner).data })
352 }
353}
354
355impl<T> Deref for FastMutexGuard<'_, T> {
356 type Target = T;
357
358 fn deref(&self) -> &Self::Target {
359 // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
360 unsafe { &(*self.fast_mutex.inner).data }
361 }
362}
363
364impl<T> DerefMut for FastMutexGuard<'_, T> {
365 fn deref_mut(&mut self) -> &mut Self::Target {
366 // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
367 // Mutable access is safe due to Self only being given out whilst a mutex is held from the
368 // kernel.
369 unsafe { &mut (*self.fast_mutex.inner).data }
370 }
371}
372
373impl<T> Drop for FastMutexGuard<'_, T> {
374 fn drop(&mut self) {
375 // NOT SAFE AT AN INVALID IRQL
376 unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
377 }
378}
379
380impl<T> FastMutexGuard<'_, T> {
381 /// Safely drop the `FastMutexGuard`, an alternative to RAII.
382 ///
383 /// This function checks the IRQL before attempting to drop the guard.
384 ///
385 /// # Errors
386 ///
387 /// If the IRQL != `APC_LEVEL`, no unlock will occur and a DriverMutexError will be returned to the
388 /// caller.
389 ///
390 /// # IRQL
391 ///
392 /// This function must be called at `APC_LEVEL`
393 pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {
394 let irql = unsafe { KeGetCurrentIrql() };
395 if irql != APC_LEVEL as u8 {
396 return Err(DriverMutexError::IrqlTooHigh);
397 }
398
399 unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
400
401 Ok(())
402 }
403}