With the below context (Rust Windows driver), I have highlighted the mem::swap
with a comment:
pub struct DriverMessagesWithMutex {
#[derive(Serialize, Deserialize, Default, Debug)]
process_creations: Vec<ProcessStarted>,
impl Default for DriverMessagesWithMutex {
let mut mutex = FAST_MUTEX::default();
unsafe { ExInitializeFastMutex(&mut mutex) };
let data = DriverMessages::default();
DriverMessagesWithMutex { lock: mutex, is_empty: true, data }
impl DriverMessagesWithMutex {
/// Adds a print msg to the queue.
/// This function will wait for an acquisition of the spin lock to continue and will block
pub fn add_message_to_queue(&mut self, data: String)
let irql = unsafe { KeGetCurrentIrql() };
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
self.data.messages.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
/// Adds data of type ProcessStarted to the message queue.
/// This function will wait for an acquisition of the spin lock to continue and will block
pub fn add_process_creation_to_queue(&mut self, data: ProcessStarted)
let irql = unsafe { KeGetCurrentIrql() };
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
self.data.process_creations.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
fn extract_all(&mut self) -> Option<DriverMessages> {
let irql = unsafe { KeGetCurrentIrql() };
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
unsafe { ExReleaseFastMutex(&mut self.lock) };
let mut extracted_data = DriverMessages::default();
mem::swap(&mut extracted_data, &mut self.data);
self.is_empty = true; // reset flag
unsafe { ExReleaseFastMutex(&mut self.lock) };
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProcessStarted {
pub command_line: String,
// Windows sys crate for driver development:
pub unsafe fn ExInitializeFastMutex(kmutex: *mut FAST_MUTEX) {
core::ptr::write_volatile(&mut (*kmutex).Count, FM_LOCK_BIT as i32);
(*kmutex).Owner = core::ptr::null_mut();
(*kmutex).Contention = 0;
KeInitializeEvent(&mut (*kmutex).Event, SynchronizationEvent, FALSE as _)
<code>use core::mem;
pub struct DriverMessagesWithMutex {
lock: FAST_MUTEX,
is_empty: bool,
data: DriverMessages,
}
#[derive(Serialize, Deserialize, Default, Debug)]
struct DriverMessages {
messages: Vec<String>,
process_creations: Vec<ProcessStarted>,
}
impl Default for DriverMessagesWithMutex {
fn default() -> Self {
let mut mutex = FAST_MUTEX::default();
unsafe { ExInitializeFastMutex(&mut mutex) };
let data = DriverMessages::default();
DriverMessagesWithMutex { lock: mutex, is_empty: true, data }
}
}
impl DriverMessagesWithMutex {
// ..
/// Adds a print msg to the queue.
///
/// This function will wait for an acquisition of the spin lock to continue and will block
/// until that point.
pub fn add_message_to_queue(&mut self, data: String)
{
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return;
}
self.is_empty = false;
self.data.messages.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
}
/// Adds data of type ProcessStarted to the message queue.
///
/// This function will wait for an acquisition of the spin lock to continue and will block
/// until that point.
pub fn add_process_creation_to_queue(&mut self, data: ProcessStarted)
{
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return;
}
self.is_empty = false;
self.data.process_creations.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
}
fn extract_all(&mut self) -> Option<DriverMessages> {
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return None;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return None;
}
if self.is_empty {
unsafe { ExReleaseFastMutex(&mut self.lock) };
return None;
}
// THIS PART HERE
let mut extracted_data = DriverMessages::default();
mem::swap(&mut extracted_data, &mut self.data);
self.is_empty = true; // reset flag
unsafe { ExReleaseFastMutex(&mut self.lock) };
Some(extracted_data)
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProcessStarted {
pub image_name: String,
pub command_line: String,
pub parent_pid: String,
}
// Windows sys crate for driver development:
#[allow(non_snake_case)]
pub unsafe fn ExInitializeFastMutex(kmutex: *mut FAST_MUTEX) {
core::ptr::write_volatile(&mut (*kmutex).Count, FM_LOCK_BIT as i32);
(*kmutex).Owner = core::ptr::null_mut();
(*kmutex).Contention = 0;
KeInitializeEvent(&mut (*kmutex).Event, SynchronizationEvent, FALSE as _)
}
</code>
use core::mem;
pub struct DriverMessagesWithMutex {
lock: FAST_MUTEX,
is_empty: bool,
data: DriverMessages,
}
#[derive(Serialize, Deserialize, Default, Debug)]
struct DriverMessages {
messages: Vec<String>,
process_creations: Vec<ProcessStarted>,
}
impl Default for DriverMessagesWithMutex {
fn default() -> Self {
let mut mutex = FAST_MUTEX::default();
unsafe { ExInitializeFastMutex(&mut mutex) };
let data = DriverMessages::default();
DriverMessagesWithMutex { lock: mutex, is_empty: true, data }
}
}
impl DriverMessagesWithMutex {
// ..
/// Adds a print msg to the queue.
///
/// This function will wait for an acquisition of the spin lock to continue and will block
/// until that point.
pub fn add_message_to_queue(&mut self, data: String)
{
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return;
}
self.is_empty = false;
self.data.messages.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
}
/// Adds data of type ProcessStarted to the message queue.
///
/// This function will wait for an acquisition of the spin lock to continue and will block
/// until that point.
pub fn add_process_creation_to_queue(&mut self, data: ProcessStarted)
{
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return;
}
self.is_empty = false;
self.data.process_creations.push(data);
unsafe { ExReleaseFastMutex(&mut self.lock) };
}
fn extract_all(&mut self) -> Option<DriverMessages> {
let irql = unsafe { KeGetCurrentIrql() };
if irql != 0 {
println!("[sanctum] [-] IRQL is not PASSIVE_LEVEL: {}", irql);
return None;
}
unsafe { ExAcquireFastMutex(&mut self.lock) };
let irql = unsafe { KeGetCurrentIrql() };
if irql > APC_LEVEL as u8 {
println!("[sanctum] [-] IRQL is not APIC_LEVEL: {}", irql);
unsafe { ExReleaseFastMutex(&mut self.lock) };
return None;
}
if self.is_empty {
unsafe { ExReleaseFastMutex(&mut self.lock) };
return None;
}
// THIS PART HERE
let mut extracted_data = DriverMessages::default();
mem::swap(&mut extracted_data, &mut self.data);
self.is_empty = true; // reset flag
unsafe { ExReleaseFastMutex(&mut self.lock) };
Some(extracted_data)
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProcessStarted {
pub image_name: String,
pub command_line: String,
pub parent_pid: String,
}
// Windows sys crate for driver development:
#[allow(non_snake_case)]
pub unsafe fn ExInitializeFastMutex(kmutex: *mut FAST_MUTEX) {
core::ptr::write_volatile(&mut (*kmutex).Count, FM_LOCK_BIT as i32);
(*kmutex).Owner = core::ptr::null_mut();
(*kmutex).Contention = 0;
KeInitializeEvent(&mut (*kmutex).Event, SynchronizationEvent, FALSE as _)
}
I am looking to return out the content of self.data in a memory efficient way (avoiding a clone()
) so that a caching structure in the kernel can take ownership of a cached view to do with it as it pleases.
When using mem::take()
instead of mem::swap()
, a blue screen kernel panic was happening each time; and I’m trying to figure the internals out for my own learning.
The rust docs say:
Take allows taking ownership of a struct field by replacing it with an “empty” value.
This seems to be the exact same thing as using mem::replace
. The only thing I can think of is somehow a part of the kernel is trying to access / write to the memory at self.data, at the time of the operation? Or that memory was deemed valid but has now been invalidated after the take
operation – but in both cases, should it not work the same as mem::swap()
?
I would be grateful for an explanation as to the differences between these functions, and whether the context of running in the kernel makes a difference to the assurance against panics using these functions.
To provide context as to how the struct is used, I have two static global variables (which I know is not a good rust pattern) for holding the ‘live’ and cached DriverMessagesWithMutex data:
static DRIVER_MESSAGES: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
static DRIVER_MESSAGES_CACHE: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
let messages = Box::new(DriverMessagesWithMutex::new());
let messages_cache = Box::new(DriverMessagesWithMutex::new());
DRIVER_MESSAGES.store(Box::into_raw(messages), Ordering::SeqCst);
DRIVER_MESSAGES_CACHE.store(Box::into_raw(messages_cache), Ordering::SeqCst);
<code>// delcared
static DRIVER_MESSAGES: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
static DRIVER_MESSAGES_CACHE: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
// initialised like so:
let messages = Box::new(DriverMessagesWithMutex::new());
let messages_cache = Box::new(DriverMessagesWithMutex::new());
DRIVER_MESSAGES.store(Box::into_raw(messages), Ordering::SeqCst);
DRIVER_MESSAGES_CACHE.store(Box::into_raw(messages_cache), Ordering::SeqCst);
</code>
// delcared
static DRIVER_MESSAGES: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
static DRIVER_MESSAGES_CACHE: AtomicPtr<DriverMessagesWithMutex> = AtomicPtr::new(null_mut());
// initialised like so:
let messages = Box::new(DriverMessagesWithMutex::new());
let messages_cache = Box::new(DriverMessagesWithMutex::new());
DRIVER_MESSAGES.store(Box::into_raw(messages), Ordering::SeqCst);
DRIVER_MESSAGES_CACHE.store(Box::into_raw(messages_cache), Ordering::SeqCst);
These are global, as in the driver I am using callback functions to kernel events to log data, these callbacks only accept a function pointer (i.e. I cannot pass a RC/ARC into them), for instance PsSetCreateProcessNotifyRoutineEx:
<code>PsSetCreateProcessNotifyRoutineEx(Some(core_callback_notify_ps), FALSE as u8);
<code>PsSetCreateProcessNotifyRoutineEx(Some(core_callback_notify_ps), FALSE as u8);
</code>
PsSetCreateProcessNotifyRoutineEx(Some(core_callback_notify_ps), FALSE as u8);
Within the core_callback_notify_ps
function, I access the global DRIVER_MESSAGES like so:
<code>pub unsafe extern "C" fn core_callback_notify_ps(process: PEPROCESS, pid: HANDLE, created: *mut PS_CREATE_NOTIFY_INFO) {
let image_name = unicode_to_string((*created).ImageFileName);
let command_line = unicode_to_string((*created).CommandLine);
let ppid = format!("{:?}", (*created).ParentProcessId);
if image_name.is_err() || command_line.is_err() {
let process_started = ProcessStarted {
image_name: image_name.unwrap().replace("\??\", ""),
command_line: command_line.unwrap().replace("\??\", ""),
// Attempt to dereference the DRIVER_MESSAGES global; if the dereference is successful,
// add the relevant data to the queue
if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let obj = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
obj.add_process_creation_to_queue(process_started);
println!("[sanctum] [-] Driver messages is null");
<code>pub unsafe extern "C" fn core_callback_notify_ps(process: PEPROCESS, pid: HANDLE, created: *mut PS_CREATE_NOTIFY_INFO) {
if !created.is_null() {
let image_name = unicode_to_string((*created).ImageFileName);
let command_line = unicode_to_string((*created).CommandLine);
let ppid = format!("{:?}", (*created).ParentProcessId);
if image_name.is_err() || command_line.is_err() {
return;
}
let process_started = ProcessStarted {
image_name: image_name.unwrap().replace("\??\", ""),
command_line: command_line.unwrap().replace("\??\", ""),
parent_pid: ppid,
};
// Attempt to dereference the DRIVER_MESSAGES global; if the dereference is successful,
// add the relevant data to the queue
if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let obj = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
obj.add_process_creation_to_queue(process_started);
} else {
println!("[sanctum] [-] Driver messages is null");
};
}
}
</code>
pub unsafe extern "C" fn core_callback_notify_ps(process: PEPROCESS, pid: HANDLE, created: *mut PS_CREATE_NOTIFY_INFO) {
if !created.is_null() {
let image_name = unicode_to_string((*created).ImageFileName);
let command_line = unicode_to_string((*created).CommandLine);
let ppid = format!("{:?}", (*created).ParentProcessId);
if image_name.is_err() || command_line.is_err() {
return;
}
let process_started = ProcessStarted {
image_name: image_name.unwrap().replace("\??\", ""),
command_line: command_line.unwrap().replace("\??\", ""),
parent_pid: ppid,
};
// Attempt to dereference the DRIVER_MESSAGES global; if the dereference is successful,
// add the relevant data to the queue
if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let obj = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
obj.add_process_creation_to_queue(process_started);
} else {
println!("[sanctum] [-] Driver messages is null");
};
}
}
I am then trying to read from the globals like so:
<code>if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let original_object = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
// drained would then be copied into the cache if empty, or appended if not empty
let drained = original_object.extract_all();
<code>if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let original_object = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
// drained would then be copied into the cache if empty, or appended if not empty
let drained = original_object.extract_all();
}
</code>
if !DRIVER_MESSAGES.load(Ordering::SeqCst).is_null() {
let original_object = unsafe { &mut *DRIVER_MESSAGES.load(Ordering::SeqCst) };
// drained would then be copied into the cache if empty, or appended if not empty
let drained = original_object.extract_all();
}