win-wasapi: Schedule work on real-time work queue

MS claims it can schedule audio better if we use their API.
master
jpark37 2021-09-26 04:21:22 -07:00
parent 86b607154a
commit 24d82062ae
1 changed files with 302 additions and 8 deletions

View File

@ -14,6 +14,7 @@
#include <cinttypes> #include <cinttypes>
#include <avrt.h> #include <avrt.h>
#include <RTWorkQ.h>
using namespace std; using namespace std;
@ -25,6 +26,66 @@ static void GetWASAPIDefaults(obs_data_t *settings);
#define OBS_KSAUDIO_SPEAKER_4POINT1 \ #define OBS_KSAUDIO_SPEAKER_4POINT1 \
(KSAUDIO_SPEAKER_SURROUND | SPEAKER_LOW_FREQUENCY) (KSAUDIO_SPEAKER_SURROUND | SPEAKER_LOW_FREQUENCY)
typedef HRESULT(STDAPICALLTYPE *PFN_RtwqUnlockWorkQueue)(DWORD);
typedef HRESULT(STDAPICALLTYPE *PFN_RtwqLockSharedWorkQueue)(PCWSTR usageClass,
LONG basePriority,
DWORD *taskId,
DWORD *id);
typedef HRESULT(STDAPICALLTYPE *PFN_RtwqCreateAsyncResult)(IUnknown *,
IRtwqAsyncCallback *,
IUnknown *,
IRtwqAsyncResult **);
typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWorkItem)(DWORD, LONG,
IRtwqAsyncResult *);
typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWaitingWorkItem)(HANDLE, LONG,
IRtwqAsyncResult *,
RTWQWORKITEM_KEY *);
class ARtwqAsyncCallback : public IRtwqAsyncCallback {
protected:
ARtwqAsyncCallback(void *source) : source(source) {}
public:
STDMETHOD_(ULONG, AddRef)() { return ++refCount; }
STDMETHOD_(ULONG, Release)() { return --refCount; }
STDMETHOD(QueryInterface)(REFIID riid, void **ppvObject)
{
HRESULT hr = E_NOINTERFACE;
if (riid == __uuidof(IRtwqAsyncCallback) ||
riid == __uuidof(IUnknown)) {
*ppvObject = this;
AddRef();
hr = S_OK;
} else {
*ppvObject = NULL;
}
return hr;
}
STDMETHOD(GetParameters)
(DWORD *pdwFlags, DWORD *pdwQueue)
{
*pdwFlags = 0;
*pdwQueue = queue_id;
return S_OK;
}
STDMETHOD(Invoke)
(IRtwqAsyncResult *) override = 0;
DWORD GetQueueId() const { return queue_id; }
void SetQueueId(DWORD id) { queue_id = id; }
protected:
std::atomic<ULONG> refCount = 1;
void *source;
DWORD queue_id = 0;
};
class WASAPISource { class WASAPISource {
ComPtr<IMMNotificationClient> notify; ComPtr<IMMNotificationClient> notify;
ComPtr<IMMDeviceEnumerator> enumerator; ComPtr<IMMDeviceEnumerator> enumerator;
@ -35,6 +96,12 @@ class WASAPISource {
wstring default_id; wstring default_id;
string device_id; string device_id;
string device_name; string device_name;
PFN_RtwqUnlockWorkQueue rtwq_unlock_work_queue = NULL;
PFN_RtwqLockSharedWorkQueue rtwq_lock_shared_work_queue = NULL;
PFN_RtwqCreateAsyncResult rtwq_create_async_result = NULL;
PFN_RtwqPutWorkItem rtwq_put_work_item = NULL;
PFN_RtwqPutWaitingWorkItem rtwq_put_waiting_work_item = NULL;
bool rtwq_supported = false;
uint64_t lastNotifyTime = 0; uint64_t lastNotifyTime = 0;
bool isInputDevice; bool isInputDevice;
std::atomic<bool> useDeviceTiming = false; std::atomic<bool> useDeviceTiming = false;
@ -43,6 +110,55 @@ class WASAPISource {
bool previouslyFailed = false; bool previouslyFailed = false;
WinHandle reconnectThread; WinHandle reconnectThread;
class CallbackStartCapture : public ARtwqAsyncCallback {
public:
CallbackStartCapture(WASAPISource *source)
: ARtwqAsyncCallback(source)
{
}
STDMETHOD(Invoke)
(IRtwqAsyncResult *) override
{
((WASAPISource *)source)->OnStartCapture();
return S_OK;
}
} startCapture;
ComPtr<IRtwqAsyncResult> startCaptureAsyncResult;
class CallbackSampleReady : public ARtwqAsyncCallback {
public:
CallbackSampleReady(WASAPISource *source)
: ARtwqAsyncCallback(source)
{
}
STDMETHOD(Invoke)
(IRtwqAsyncResult *) override
{
((WASAPISource *)source)->OnSampleReady();
return S_OK;
}
} sampleReady;
ComPtr<IRtwqAsyncResult> sampleReadyAsyncResult;
class CallbackRestart : public ARtwqAsyncCallback {
public:
CallbackRestart(WASAPISource *source)
: ARtwqAsyncCallback(source)
{
}
STDMETHOD(Invoke)
(IRtwqAsyncResult *) override
{
((WASAPISource *)source)->OnRestart();
return S_OK;
}
} restart;
ComPtr<IRtwqAsyncResult> restartAsyncResult;
WinHandle captureThread; WinHandle captureThread;
WinHandle idleSignal; WinHandle idleSignal;
WinHandle stopSignal; WinHandle stopSignal;
@ -94,6 +210,10 @@ public:
void Update(obs_data_t *settings); void Update(obs_data_t *settings);
void SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id); void SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id);
void OnStartCapture();
void OnSampleReady();
void OnRestart();
}; };
class WASAPINotify : public IMMNotificationClient { class WASAPINotify : public IMMNotificationClient {
@ -149,7 +269,11 @@ public:
WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_, WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
bool input) bool input)
: source(source_), isInputDevice(input) : source(source_),
isInputDevice(input),
startCapture(this),
sampleReady(this),
restart(this)
{ {
UpdateSettings(settings); UpdateSettings(settings);
@ -200,11 +324,73 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
if (FAILED(hr)) if (FAILED(hr))
throw HRError("Failed to register endpoint callback", hr); throw HRError("Failed to register endpoint callback", hr);
captureThread = CreateThread(nullptr, 0, WASAPISource::CaptureThread, /* OBS will already load DLL on startup if it exists */
this, 0, nullptr); const HMODULE rtwq_module = GetModuleHandle(L"RTWorkQ.dll");
if (!captureThread.Valid()) { rtwq_supported = rtwq_module != NULL;
enumerator->UnregisterEndpointNotificationCallback(notify); if (rtwq_supported) {
throw "Failed to create capture thread"; rtwq_unlock_work_queue =
(PFN_RtwqUnlockWorkQueue)GetProcAddress(
rtwq_module, "RtwqUnlockWorkQueue");
rtwq_lock_shared_work_queue =
(PFN_RtwqLockSharedWorkQueue)GetProcAddress(
rtwq_module, "RtwqLockSharedWorkQueue");
rtwq_create_async_result =
(PFN_RtwqCreateAsyncResult)GetProcAddress(
rtwq_module, "RtwqCreateAsyncResult");
rtwq_put_work_item = (PFN_RtwqPutWorkItem)GetProcAddress(
rtwq_module, "RtwqPutWorkItem");
rtwq_put_waiting_work_item =
(PFN_RtwqPutWaitingWorkItem)GetProcAddress(
rtwq_module, "RtwqPutWaitingWorkItem");
hr = rtwq_create_async_result(nullptr, &startCapture, nullptr,
&startCaptureAsyncResult);
if (FAILED(hr)) {
enumerator->UnregisterEndpointNotificationCallback(
notify);
throw HRError(
"Could not create startCaptureAsyncResult", hr);
}
hr = rtwq_create_async_result(nullptr, &sampleReady, nullptr,
&sampleReadyAsyncResult);
if (FAILED(hr)) {
enumerator->UnregisterEndpointNotificationCallback(
notify);
throw HRError("Could not create sampleReadyAsyncResult",
hr);
}
hr = rtwq_create_async_result(nullptr, &restart, nullptr,
&restartAsyncResult);
if (FAILED(hr)) {
enumerator->UnregisterEndpointNotificationCallback(
notify);
throw HRError("Could not create restartAsyncResult",
hr);
}
DWORD taskId = 0;
DWORD id = 0;
hr = rtwq_lock_shared_work_queue(L"Capture", 0, &taskId, &id);
if (FAILED(hr)) {
enumerator->UnregisterEndpointNotificationCallback(
notify);
throw HRError("RtwqLockSharedWorkQueue failed", hr);
}
startCapture.SetQueueId(id);
sampleReady.SetQueueId(id);
restart.SetQueueId(id);
} else {
captureThread = CreateThread(nullptr, 0,
WASAPISource::CaptureThread, this,
0, nullptr);
if (!captureThread.Valid()) {
enumerator->UnregisterEndpointNotificationCallback(
notify);
throw "Failed to create capture thread";
}
} }
Start(); Start();
@ -212,7 +398,12 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
void WASAPISource::Start() void WASAPISource::Start()
{ {
SetEvent(initSignal); if (rtwq_supported) {
rtwq_put_work_item(startCapture.GetQueueId(), 0,
startCaptureAsyncResult);
} else {
SetEvent(initSignal);
}
} }
void WASAPISource::Stop() void WASAPISource::Stop()
@ -221,13 +412,19 @@ void WASAPISource::Stop()
blog(LOG_INFO, "WASAPI: Device '%s' Terminated", device_name.c_str()); blog(LOG_INFO, "WASAPI: Device '%s' Terminated", device_name.c_str());
if (rtwq_supported)
SetEvent(receiveSignal);
WaitForSingleObject(idleSignal, INFINITE); WaitForSingleObject(idleSignal, INFINITE);
SetEvent(exitSignal); SetEvent(exitSignal);
WaitForSingleObject(reconnectThread, INFINITE); WaitForSingleObject(reconnectThread, INFINITE);
WaitForSingleObject(captureThread, INFINITE); if (rtwq_supported)
rtwq_unlock_work_queue(sampleReady.GetQueueId());
else
WaitForSingleObject(captureThread, INFINITE);
} }
WASAPISource::~WASAPISource() WASAPISource::~WASAPISource()
@ -444,6 +641,24 @@ void WASAPISource::Initialize()
client = std::move(temp_client); client = std::move(temp_client);
capture = std::move(temp_capture); capture = std::move(temp_capture);
if (rtwq_supported) {
HRESULT hr = rtwq_put_waiting_work_item(
receiveSignal, 0, sampleReadyAsyncResult, nullptr);
if (FAILED(hr)) {
capture.Clear();
client.Clear();
throw HRError("RtwqPutWaitingWorkItem failed", hr);
}
hr = rtwq_put_waiting_work_item(restartSignal, 0,
restartAsyncResult, nullptr);
if (FAILED(hr)) {
capture.Clear();
client.Clear();
throw HRError("RtwqPutWaitingWorkItem failed", hr);
}
}
blog(LOG_INFO, "WASAPI: Device '%s' [%" PRIu32 " Hz] initialized", blog(LOG_INFO, "WASAPI: Device '%s' [%" PRIu32 " Hz] initialized",
device_name.c_str(), sampleRate); device_name.c_str(), sampleRate);
} }
@ -724,6 +939,85 @@ void WASAPISource::SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id)
SetEvent(restartSignal); SetEvent(restartSignal);
} }
void WASAPISource::OnStartCapture()
{
const DWORD ret = WaitForSingleObject(stopSignal, 0);
switch (ret) {
case WAIT_OBJECT_0:
SetEvent(idleSignal);
break;
default:
assert(ret == WAIT_TIMEOUT);
if (!TryInitialize()) {
blog(LOG_INFO, "WASAPI: Device '%s' failed to start",
device_id.c_str());
reconnectDuration = RECONNECT_INTERVAL;
SetEvent(reconnectSignal);
}
}
}
void WASAPISource::OnSampleReady()
{
bool stop = false;
bool reconnect = false;
if (!ProcessCaptureData()) {
stop = true;
reconnect = true;
reconnectDuration = RECONNECT_INTERVAL;
}
if (WaitForSingleObject(restartSignal, 0) == WAIT_OBJECT_0) {
stop = true;
reconnect = true;
reconnectDuration = 0;
ResetEvent(restartSignal);
rtwq_put_waiting_work_item(restartSignal, 0, restartAsyncResult,
nullptr);
}
if (WaitForSingleObject(stopSignal, 0) == WAIT_OBJECT_0) {
stop = true;
reconnect = false;
}
if (!stop) {
if (FAILED(rtwq_put_waiting_work_item(receiveSignal, 0,
sampleReadyAsyncResult,
nullptr))) {
blog(LOG_ERROR,
"Could not requeue sample receive work");
stop = true;
reconnect = true;
reconnectDuration = RECONNECT_INTERVAL;
}
}
if (stop) {
client->Stop();
capture.Clear();
client.Clear();
if (reconnect) {
blog(LOG_INFO, "Device '%s' invalidated. Retrying",
device_name.c_str());
SetEvent(reconnectSignal);
} else {
SetEvent(idleSignal);
}
}
}
void WASAPISource::OnRestart()
{
SetEvent(receiveSignal);
}
/* ------------------------------------------------------------------------- */ /* ------------------------------------------------------------------------- */
static const char *GetWASAPIInputName(void *) static const char *GetWASAPIInputName(void *)