diff --git a/plugins/win-wasapi/win-wasapi.cpp b/plugins/win-wasapi/win-wasapi.cpp index 801fd9c10..757d9f2c5 100644 --- a/plugins/win-wasapi/win-wasapi.cpp +++ b/plugins/win-wasapi/win-wasapi.cpp @@ -14,6 +14,7 @@ #include #include +#include using namespace std; @@ -25,6 +26,66 @@ static void GetWASAPIDefaults(obs_data_t *settings); #define OBS_KSAUDIO_SPEAKER_4POINT1 \ (KSAUDIO_SPEAKER_SURROUND | SPEAKER_LOW_FREQUENCY) +typedef HRESULT(STDAPICALLTYPE *PFN_RtwqUnlockWorkQueue)(DWORD); +typedef HRESULT(STDAPICALLTYPE *PFN_RtwqLockSharedWorkQueue)(PCWSTR usageClass, + LONG basePriority, + DWORD *taskId, + DWORD *id); +typedef HRESULT(STDAPICALLTYPE *PFN_RtwqCreateAsyncResult)(IUnknown *, + IRtwqAsyncCallback *, + IUnknown *, + IRtwqAsyncResult **); +typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWorkItem)(DWORD, LONG, + IRtwqAsyncResult *); +typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWaitingWorkItem)(HANDLE, LONG, + IRtwqAsyncResult *, + RTWQWORKITEM_KEY *); + +class ARtwqAsyncCallback : public IRtwqAsyncCallback { +protected: + ARtwqAsyncCallback(void *source) : source(source) {} + +public: + STDMETHOD_(ULONG, AddRef)() { return ++refCount; } + + STDMETHOD_(ULONG, Release)() { return --refCount; } + + STDMETHOD(QueryInterface)(REFIID riid, void **ppvObject) + { + HRESULT hr = E_NOINTERFACE; + + if (riid == __uuidof(IRtwqAsyncCallback) || + riid == __uuidof(IUnknown)) { + *ppvObject = this; + AddRef(); + hr = S_OK; + } else { + *ppvObject = NULL; + } + + return hr; + } + + STDMETHOD(GetParameters) + (DWORD *pdwFlags, DWORD *pdwQueue) + { + *pdwFlags = 0; + *pdwQueue = queue_id; + return S_OK; + } + + STDMETHOD(Invoke) + (IRtwqAsyncResult *) override = 0; + + DWORD GetQueueId() const { return queue_id; } + void SetQueueId(DWORD id) { queue_id = id; } + +protected: + std::atomic refCount = 1; + void *source; + DWORD queue_id = 0; +}; + class WASAPISource { ComPtr notify; ComPtr enumerator; @@ -35,6 +96,12 @@ class WASAPISource { wstring default_id; string device_id; string device_name; + PFN_RtwqUnlockWorkQueue rtwq_unlock_work_queue = NULL; + PFN_RtwqLockSharedWorkQueue rtwq_lock_shared_work_queue = NULL; + PFN_RtwqCreateAsyncResult rtwq_create_async_result = NULL; + PFN_RtwqPutWorkItem rtwq_put_work_item = NULL; + PFN_RtwqPutWaitingWorkItem rtwq_put_waiting_work_item = NULL; + bool rtwq_supported = false; uint64_t lastNotifyTime = 0; bool isInputDevice; std::atomic useDeviceTiming = false; @@ -43,6 +110,55 @@ class WASAPISource { bool previouslyFailed = false; WinHandle reconnectThread; + class CallbackStartCapture : public ARtwqAsyncCallback { + public: + CallbackStartCapture(WASAPISource *source) + : ARtwqAsyncCallback(source) + { + } + + STDMETHOD(Invoke) + (IRtwqAsyncResult *) override + { + ((WASAPISource *)source)->OnStartCapture(); + return S_OK; + } + + } startCapture; + ComPtr startCaptureAsyncResult; + + class CallbackSampleReady : public ARtwqAsyncCallback { + public: + CallbackSampleReady(WASAPISource *source) + : ARtwqAsyncCallback(source) + { + } + + STDMETHOD(Invoke) + (IRtwqAsyncResult *) override + { + ((WASAPISource *)source)->OnSampleReady(); + return S_OK; + } + } sampleReady; + ComPtr sampleReadyAsyncResult; + + class CallbackRestart : public ARtwqAsyncCallback { + public: + CallbackRestart(WASAPISource *source) + : ARtwqAsyncCallback(source) + { + } + + STDMETHOD(Invoke) + (IRtwqAsyncResult *) override + { + ((WASAPISource *)source)->OnRestart(); + return S_OK; + } + } restart; + ComPtr restartAsyncResult; + WinHandle captureThread; WinHandle idleSignal; WinHandle stopSignal; @@ -94,6 +210,10 @@ public: void Update(obs_data_t *settings); void SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id); + + void OnStartCapture(); + void OnSampleReady(); + void OnRestart(); }; class WASAPINotify : public IMMNotificationClient { @@ -149,7 +269,11 @@ public: WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_, bool input) - : source(source_), isInputDevice(input) + : source(source_), + isInputDevice(input), + startCapture(this), + sampleReady(this), + restart(this) { UpdateSettings(settings); @@ -200,11 +324,73 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_, if (FAILED(hr)) throw HRError("Failed to register endpoint callback", hr); - captureThread = CreateThread(nullptr, 0, WASAPISource::CaptureThread, - this, 0, nullptr); - if (!captureThread.Valid()) { - enumerator->UnregisterEndpointNotificationCallback(notify); - throw "Failed to create capture thread"; + /* OBS will already load DLL on startup if it exists */ + const HMODULE rtwq_module = GetModuleHandle(L"RTWorkQ.dll"); + rtwq_supported = rtwq_module != NULL; + if (rtwq_supported) { + rtwq_unlock_work_queue = + (PFN_RtwqUnlockWorkQueue)GetProcAddress( + rtwq_module, "RtwqUnlockWorkQueue"); + rtwq_lock_shared_work_queue = + (PFN_RtwqLockSharedWorkQueue)GetProcAddress( + rtwq_module, "RtwqLockSharedWorkQueue"); + rtwq_create_async_result = + (PFN_RtwqCreateAsyncResult)GetProcAddress( + rtwq_module, "RtwqCreateAsyncResult"); + rtwq_put_work_item = (PFN_RtwqPutWorkItem)GetProcAddress( + rtwq_module, "RtwqPutWorkItem"); + rtwq_put_waiting_work_item = + (PFN_RtwqPutWaitingWorkItem)GetProcAddress( + rtwq_module, "RtwqPutWaitingWorkItem"); + + hr = rtwq_create_async_result(nullptr, &startCapture, nullptr, + &startCaptureAsyncResult); + if (FAILED(hr)) { + enumerator->UnregisterEndpointNotificationCallback( + notify); + throw HRError( + "Could not create startCaptureAsyncResult", hr); + } + + hr = rtwq_create_async_result(nullptr, &sampleReady, nullptr, + &sampleReadyAsyncResult); + if (FAILED(hr)) { + enumerator->UnregisterEndpointNotificationCallback( + notify); + throw HRError("Could not create sampleReadyAsyncResult", + hr); + } + + hr = rtwq_create_async_result(nullptr, &restart, nullptr, + &restartAsyncResult); + if (FAILED(hr)) { + enumerator->UnregisterEndpointNotificationCallback( + notify); + throw HRError("Could not create restartAsyncResult", + hr); + } + + DWORD taskId = 0; + DWORD id = 0; + hr = rtwq_lock_shared_work_queue(L"Capture", 0, &taskId, &id); + if (FAILED(hr)) { + enumerator->UnregisterEndpointNotificationCallback( + notify); + throw HRError("RtwqLockSharedWorkQueue failed", hr); + } + + startCapture.SetQueueId(id); + sampleReady.SetQueueId(id); + restart.SetQueueId(id); + } else { + captureThread = CreateThread(nullptr, 0, + WASAPISource::CaptureThread, this, + 0, nullptr); + if (!captureThread.Valid()) { + enumerator->UnregisterEndpointNotificationCallback( + notify); + throw "Failed to create capture thread"; + } } Start(); @@ -212,7 +398,12 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_, void WASAPISource::Start() { - SetEvent(initSignal); + if (rtwq_supported) { + rtwq_put_work_item(startCapture.GetQueueId(), 0, + startCaptureAsyncResult); + } else { + SetEvent(initSignal); + } } void WASAPISource::Stop() @@ -221,13 +412,19 @@ void WASAPISource::Stop() blog(LOG_INFO, "WASAPI: Device '%s' Terminated", device_name.c_str()); + if (rtwq_supported) + SetEvent(receiveSignal); + WaitForSingleObject(idleSignal, INFINITE); SetEvent(exitSignal); WaitForSingleObject(reconnectThread, INFINITE); - WaitForSingleObject(captureThread, INFINITE); + if (rtwq_supported) + rtwq_unlock_work_queue(sampleReady.GetQueueId()); + else + WaitForSingleObject(captureThread, INFINITE); } WASAPISource::~WASAPISource() @@ -444,6 +641,24 @@ void WASAPISource::Initialize() client = std::move(temp_client); capture = std::move(temp_capture); + if (rtwq_supported) { + HRESULT hr = rtwq_put_waiting_work_item( + receiveSignal, 0, sampleReadyAsyncResult, nullptr); + if (FAILED(hr)) { + capture.Clear(); + client.Clear(); + throw HRError("RtwqPutWaitingWorkItem failed", hr); + } + + hr = rtwq_put_waiting_work_item(restartSignal, 0, + restartAsyncResult, nullptr); + if (FAILED(hr)) { + capture.Clear(); + client.Clear(); + throw HRError("RtwqPutWaitingWorkItem failed", hr); + } + } + blog(LOG_INFO, "WASAPI: Device '%s' [%" PRIu32 " Hz] initialized", device_name.c_str(), sampleRate); } @@ -724,6 +939,85 @@ void WASAPISource::SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id) SetEvent(restartSignal); } +void WASAPISource::OnStartCapture() +{ + const DWORD ret = WaitForSingleObject(stopSignal, 0); + switch (ret) { + case WAIT_OBJECT_0: + SetEvent(idleSignal); + break; + + default: + assert(ret == WAIT_TIMEOUT); + + if (!TryInitialize()) { + blog(LOG_INFO, "WASAPI: Device '%s' failed to start", + device_id.c_str()); + reconnectDuration = RECONNECT_INTERVAL; + SetEvent(reconnectSignal); + } + } +} + +void WASAPISource::OnSampleReady() +{ + bool stop = false; + bool reconnect = false; + + if (!ProcessCaptureData()) { + stop = true; + reconnect = true; + reconnectDuration = RECONNECT_INTERVAL; + } + + if (WaitForSingleObject(restartSignal, 0) == WAIT_OBJECT_0) { + stop = true; + reconnect = true; + reconnectDuration = 0; + + ResetEvent(restartSignal); + rtwq_put_waiting_work_item(restartSignal, 0, restartAsyncResult, + nullptr); + } + + if (WaitForSingleObject(stopSignal, 0) == WAIT_OBJECT_0) { + stop = true; + reconnect = false; + } + + if (!stop) { + if (FAILED(rtwq_put_waiting_work_item(receiveSignal, 0, + sampleReadyAsyncResult, + nullptr))) { + blog(LOG_ERROR, + "Could not requeue sample receive work"); + stop = true; + reconnect = true; + reconnectDuration = RECONNECT_INTERVAL; + } + } + + if (stop) { + client->Stop(); + + capture.Clear(); + client.Clear(); + + if (reconnect) { + blog(LOG_INFO, "Device '%s' invalidated. Retrying", + device_name.c_str()); + SetEvent(reconnectSignal); + } else { + SetEvent(idleSignal); + } + } +} + +void WASAPISource::OnRestart() +{ + SetEvent(receiveSignal); +} + /* ------------------------------------------------------------------------- */ static const char *GetWASAPIInputName(void *)