From 94ea17f27b41daa12c898274753c70bb8d7b8cc4 Mon Sep 17 00:00:00 2001 From: Yannis Juglaret Date: Mon, 11 Mar 2024 19:06:34 +0100 Subject: [PATCH] Stop depending on `avrt.dll` statically on Windows * Load `avrt.dll` dynamically with `LoadLibraryW` * Fail with an `AudioThreadPriorityError` See also https://bugzilla.mozilla.org/show_bug.cgi?id=1884214 --- Cargo.toml | 5 ++- src/rt_win.rs | 114 ++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d5cfd95..488421f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,9 +32,12 @@ libc = "0.2" version = "0.52" features = [ "Win32_Foundation", - "Win32_System_Threading", + "Win32_System_LibraryLoader", ] +[target.'cfg(target_os = "windows")'.dependencies.once_cell] +version = "1.19" + [target.'cfg(target_os = "linux")'.dependencies] libc = "0.2" diff --git a/src/rt_win.rs b/src/rt_win.rs index 3a2bbf2..57e5536 100644 --- a/src/rt_win.rs +++ b/src/rt_win.rs @@ -1,12 +1,16 @@ +use crate::AudioThreadPriorityError; +use once_cell::sync; +use windows_sys::core::PCSTR; use windows_sys::s; +use windows_sys::Win32::Foundation::FreeLibrary; use windows_sys::Win32::Foundation::GetLastError; +use windows_sys::Win32::Foundation::BOOL; use windows_sys::Win32::Foundation::FALSE; use windows_sys::Win32::Foundation::HANDLE; -use windows_sys::Win32::System::Threading::{ - AvRevertMmThreadCharacteristics, AvSetMmThreadCharacteristicsA, -}; - -use crate::AudioThreadPriorityError; +use windows_sys::Win32::Foundation::HMODULE; +use windows_sys::Win32::Foundation::WIN32_ERROR; +use windows_sys::Win32::System::LibraryLoader::GetProcAddress; +use windows_sys::Win32::System::LibraryLoader::LoadLibraryA; use log::info; @@ -28,7 +32,8 @@ impl RtPriorityHandleInternal { pub fn demote_current_thread_from_real_time_internal( rt_priority_handle: RtPriorityHandleInternal, ) -> Result<(), AudioThreadPriorityError> { - let rv = unsafe { AvRevertMmThreadCharacteristics(rt_priority_handle.task_handle) }; + let rv = + unsafe { (av_rt()?.av_revert_mm_thread_characteristics)(rt_priority_handle.task_handle) }; if rv == FALSE { return Err(AudioThreadPriorityError::new(&format!( "Unable to restore the thread priority ({:?})", @@ -49,13 +54,13 @@ pub fn promote_current_thread_to_real_time_internal( _audio_samplerate_hz: u32, ) -> Result { let mut task_index = 0u32; - - let handle = unsafe { AvSetMmThreadCharacteristicsA(s!("Audio"), &mut task_index) }; + let handle = + unsafe { (av_rt()?.av_set_mm_thread_characteristics_a)(s!("Audio"), &mut task_index) }; let handle = RtPriorityHandleInternal::new(task_index, handle); if handle.task_handle == 0 { return Err(AudioThreadPriorityError::new(&format!( - "Unable to restore the thread priority ({:?})", + "Unable to bump the thread priority ({:?})", unsafe { GetLastError() } ))); } @@ -67,3 +72,94 @@ pub fn promote_current_thread_to_real_time_internal( Ok(handle) } + +// We don't expect to see API failures on test machines +#[test] +fn test_successful_api_use() { + let handle = promote_current_thread_to_real_time_internal(0, 0); + println!("handle: {handle:?}"); + assert!(handle.is_ok()); + let result = demote_current_thread_from_real_time_internal(handle.unwrap()); + println!("result: {result:?}"); + assert!(result.is_ok()); +} + +fn av_rt() -> Result<&'static AvRtLibrary, AudioThreadPriorityError> { + static AV_RT_LIBRARY: sync::OnceCell> = sync::OnceCell::new(); + AV_RT_LIBRARY + .get_or_init(AvRtLibrary::try_new) + .as_ref() + .map_err(|win32_error| { + AudioThreadPriorityError::new(&format!("Unable to load avrt.dll ({win32_error})")) + }) +} + +// We don't expect to fail to load the library on test machines +#[test] +fn test_successful_avrt_library_load_as_static_ref() { + assert!(av_rt().is_ok()) +} + +type AvSetMmThreadCharacteristicsAFn = unsafe extern "system" fn(PCSTR, *mut u32) -> HANDLE; +type AvRevertMmThreadCharacteristicsFn = unsafe extern "system" fn(HANDLE) -> BOOL; + +#[derive(Debug)] +struct AvRtLibrary { + // This field is used for its Drop behavior + #[allow(dead_code)] + module: OwnedLibrary, + av_set_mm_thread_characteristics_a: AvSetMmThreadCharacteristicsAFn, + av_revert_mm_thread_characteristics: AvRevertMmThreadCharacteristicsFn, +} + +impl AvRtLibrary { + fn try_new() -> Result { + let module = unsafe { LoadLibraryA(s!("avrt.dll")) }; + if module != 0 { + let module = OwnedLibrary(module); + let set_fn = + unsafe { GetProcAddress(module.raw(), s!("AvSetMmThreadCharacteristicsA")) }; + if let Some(set_fn) = set_fn { + let revert_fn = + unsafe { GetProcAddress(module.raw(), s!("AvRevertMmThreadCharacteristics")) }; + if let Some(revert_fn) = revert_fn { + let av_set_mm_thread_characteristics_a = unsafe { + std::mem::transmute::<_, AvSetMmThreadCharacteristicsAFn>(set_fn) + }; + let av_revert_mm_thread_characteristics = unsafe { + std::mem::transmute::<_, AvRevertMmThreadCharacteristicsFn>(revert_fn) + }; + return Ok(AvRtLibrary { + module, + av_set_mm_thread_characteristics_a, + av_revert_mm_thread_characteristics, + }); + } + } + } + Err(unsafe { GetLastError() }) + } +} + +// We don't expect to fail to load the library on test machines +#[test] +fn test_successful_temporary_avrt_library_load() { + assert!(AvRtLibrary::try_new().is_ok()) +} + +#[derive(Debug)] +struct OwnedLibrary(HMODULE); + +impl OwnedLibrary { + fn raw(&self) -> HMODULE { + self.0 + } +} + +impl Drop for OwnedLibrary { + fn drop(&mut self) { + unsafe { + FreeLibrary(self.raw()); + } + } +}