// Concrete WASAPI loopback capture and render device.
// This file intentionally owns the real default-endpoint activation and COM
// client setup path instead of pushing another injected boundary below
// `AudioDevice`.

#include "wasapi_audio_device.hpp"

// clang-format off
// initguid.h must precede functiondiscoverykeys_devpkey.h: `DEFINE_PROPERTYKEY`
// and `DEFINE_GUID` emit definitions only when `INITGUID` is set.
#include <initguid.h>
// clang-format on

#include <audioclient.h>
#include <functiondiscoverykeys_devpkey.h>
#include <mmdeviceapi.h>

#include <cstring>

#include "endpoint_audio_format.hpp"
#include "loopback_capture_activation.hpp"

namespace {

using ::endpoint_audio_format::DecodeToStereoFloat;
using ::endpoint_audio_format::StereoPcmBuffer;
using ::endpoint_audio_format::SupportsDirectStereoFloatCopy;

template <typename T>
using ScopedComPtr = std::unique_ptr<T, WasapiAudioDevice::ComRelease>;

using ScopedWaveFormat =
    std::unique_ptr<WAVEFORMATEX, WasapiAudioDevice::CoTaskMemFreeDeleter>;

// Tells whether `failure` is one of the transient endpoint/session
// invalidation errors where tearing down and rebuilding the client pair on
// the current default device may succeed. Every other HRESULT is terminal.
[[nodiscard]] bool IsRecoverableStreamFailure(HRESULT failure) {
  switch (failure) {
    case AUDCLNT_E_DEVICE_INVALIDATED:
    case AUDCLNT_E_RESOURCES_INVALIDATED:
    case AUDCLNT_E_SERVICE_NOT_RUNNING:
      return true;
    default:
      return false;
  }
}

// Checks that the render mix format is packed float32 stereo — the only
// layout the render path writes directly. Returns `Ok` when it is; otherwise
// returns an unsupported-format error.
[[nodiscard]] AudioPipelineInterface::Status ValidateRenderMixFormat(
    const WAVEFORMATEX& render_format) {
  if (!SupportsDirectStereoFloatCopy(render_format)) {
    return AudioPipelineInterface::Status::Error(
        AUDCLNT_E_UNSUPPORTED_FORMAT,
        L"Render mix format is not packed float32 stereo; "
        L"conversion path is unavailable");
  }
  return AudioPipelineInterface::Status::Ok();
}

// Initializes `audio_client` in shared mode for rendering with a fixed 20 ms
// buffer. Returns `Ok` on success; otherwise wraps the failing HRESULT.
[[nodiscard]] AudioPipelineInterface::Status InitializeRenderClient(
    IAudioClient& audio_client, WAVEFORMATEX* render_format) {
  // 20 ms is the smallest buffer that stays glitch-free on most Windows
  // hardware yet keeps latency perceptually invisible. WASAPI expresses
  // durations in 100-ns units, so 20 ms = 200,000 units.
  constexpr REFERENCE_TIME kRenderBufferDuration20Ms = 200'000;

  const HRESULT initialize_render = audio_client.Initialize(
      AUDCLNT_SHAREMODE_SHARED, /*StreamFlags=*/0, kRenderBufferDuration20Ms,
      /*hnsPeriodicity=*/0, render_format,
      /*audioSessionGuid=*/nullptr);
  if (FAILED(initialize_render)) {
    return AudioPipelineInterface::Status::Error(
        initialize_render, L"IAudioClient::Initialize (render) failed");
  }
  return AudioPipelineInterface::Status::Ok();
}

// Obtains the `IAudioRenderClient` service from an initialized audio client,
// storing the raw pointer in `raw_audio_render_client`. Returns `Ok` on
// success; otherwise wraps the failing HRESULT.
[[nodiscard]] AudioPipelineInterface::Status AcquireRenderClientService(
    IAudioClient& audio_client, IAudioRenderClient*& raw_audio_render_client) {
  const HRESULT render_service = audio_client.GetService(
      __uuidof(IAudioRenderClient),
      reinterpret_cast<void**>(&raw_audio_render_client));
  if (FAILED(render_service)) {
    return AudioPipelineInterface::Status::Error(
        render_service, L"GetService IAudioRenderClient failed");
  }
  return AudioPipelineInterface::Status::Ok();
}

// Initializes `audio_client` in shared mode for loopback capture. The process
// loopback activation already selects which audio is captured, but WASAPI
// still requires `AUDCLNT_STREAMFLAGS_LOOPBACK` on the stream. A buffer
// duration of 0 lets WASAPI pick the optimal size for process loopback.
// Returns `Ok` on success; otherwise wraps the failing HRESULT.
[[nodiscard]] AudioPipelineInterface::Status InitializeCaptureClient(
    IAudioClient& audio_client, WAVEFORMATEX* capture_format) {
  const HRESULT initialize_capture = audio_client.Initialize(
      AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_LOOPBACK,
      /*hnsBufferDuration=*/0,
      /*hnsPeriodicity=*/0, capture_format,
      /*audioSessionGuid=*/nullptr);
  if (FAILED(initialize_capture)) {
    return AudioPipelineInterface::Status::Error(
        initialize_capture, L"IAudioClient::Initialize (loopback) failed");
  }
  return AudioPipelineInterface::Status::Ok();
}

// Obtains the `IAudioCaptureClient` service from an initialized audio client,
// storing the raw pointer in `raw_audio_capture_client`. Returns `Ok` on
// success; otherwise wraps the failing HRESULT.
[[nodiscard]] AudioPipelineInterface::Status AcquireCaptureClientService(
    IAudioClient& audio_client,
    IAudioCaptureClient*& raw_audio_capture_client) {
  const HRESULT capture_service = audio_client.GetService(
      __uuidof(IAudioCaptureClient),
      reinterpret_cast<void**>(&raw_audio_capture_client));
  if (FAILED(capture_service)) {
    return AudioPipelineInterface::Status::Error(
        capture_service, L"GetService IAudioCaptureClient failed");
  }
  return AudioPipelineInterface::Status::Ok();
}

// Convenience wrapper: initializes the capture client, then acquires its
// capture service. Returns `Ok` only when both steps succeed; otherwise
// returns the first failing status.
[[nodiscard]] AudioPipelineInterface::Status InitializeCaptureStream(
    IAudioClient& audio_client, WAVEFORMATEX* capture_format,
    IAudioCaptureClient*& raw_audio_capture_client) {
  const AudioPipelineInterface::Status capture_init =
      InitializeCaptureClient(audio_client, capture_format);
  if (!capture_init.ok()) {
    return capture_init;
  }
  return AcquireCaptureClientService(audio_client, raw_audio_capture_client);
}

// Writes the friendly name of `render_device` into `endpoint_name`. Falls
// back to "Default Render Device" when the device or its property store is
// unavailable, or when the reported name is missing or empty.
void ReadEndpointName(IMMDevice* render_device, std::wstring& endpoint_name) {
  // Start from a clean slate so a stale caller-supplied value can never mask
  // the documented fallback below.
  endpoint_name.clear();
  if (render_device == nullptr) {
    endpoint_name = L"Default Render Device";
    return;
  }

  IPropertyStore* raw_props = nullptr;
  if (FAILED(render_device->OpenPropertyStore(STGM_READ, &raw_props))) {
    endpoint_name = L"Default Render Device";
    return;
  }

  ScopedComPtr<IPropertyStore> props(raw_props);
  PROPVARIANT prop_variant;
  PropVariantInit(&prop_variant);
  // Fix: also require a non-null `pwszVal`. Constructing std::wstring from a
  // null pointer is undefined behavior, and the VT_LPWSTR tag alone does not
  // guarantee the pointer is populated.
  if (SUCCEEDED(props->GetValue(PKEY_Device_FriendlyName, &prop_variant)) &&
      prop_variant.vt == VT_LPWSTR && prop_variant.pwszVal != nullptr) {
    endpoint_name = prop_variant.pwszVal;
  }
  PropVariantClear(&prop_variant);
  if (endpoint_name.empty()) {
    endpoint_name = L"Default Render Device";
  }
}

// Result of acquiring the default render endpoint. On failure, `status` carries
// the failing HRESULT and message; the remaining fields are empty.
struct EndpointAcquisition {
  // Only assigned on failure paths; callers test `ok()` (see `AcquireEndpoint`
  // and its callers), so the default-constructed value represents success.
  AudioPipelineInterface::Status status;
  // Keeps the MMDevice enumerator alive alongside the device it produced.
  ScopedComPtr<IMMDeviceEnumerator> enumerator;
  // Current default render endpoint (eRender/eConsole role).
  ScopedComPtr<IMMDevice> render_device;
  // Friendly endpoint name, or the "Default Render Device" fallback.
  std::wstring endpoint_name;
};

// Creates a device enumerator, resolves the current default render endpoint
// (console role), and reads its friendly name. On any failure `result.status`
// holds the error and the remaining fields stay empty.
[[nodiscard]] EndpointAcquisition AcquireEndpoint() {
  EndpointAcquisition result;

  IMMDeviceEnumerator* enumerator_ptr = nullptr;
  const HRESULT create_enumerator = CoCreateInstance(
      __uuidof(MMDeviceEnumerator),
      /*pUnkOuter=*/nullptr, CLSCTX_ALL, __uuidof(IMMDeviceEnumerator),
      reinterpret_cast<void**>(&enumerator_ptr));
  if (FAILED(create_enumerator)) {
    result.status = AudioPipelineInterface::Status::Error(
        create_enumerator, L"CoCreateInstance(MMDeviceEnumerator) failed");
    return result;
  }
  result.enumerator.reset(enumerator_ptr);

  IMMDevice* device_ptr = nullptr;
  const HRESULT default_endpoint = result.enumerator->GetDefaultAudioEndpoint(
      eRender, eConsole, &device_ptr);
  if (FAILED(default_endpoint)) {
    result.status = AudioPipelineInterface::Status::Error(
        default_endpoint, L"GetDefaultAudioEndpoint failed");
    return result;
  }
  result.render_device.reset(device_ptr);

  ReadEndpointName(result.render_device.get(), result.endpoint_name);
  return result;
}

// WASAPI render client, its buffer-writing service, and the negotiated mix
// format, paired so they can be set up together (see `SetupRenderClient`) and
// moved into the device as a unit.
struct RenderClientSetup {
  // Owns the render stream configuration and lifetime.
  ScopedComPtr<IAudioClient> audio_client;
  // Writes processed frames into the render buffer owned by `audio_client`.
  ScopedComPtr<IAudioRenderClient> service;
  // Negotiated mix format (CoTaskMem-owned); also reused as the capture
  // format because process loopback captures in whatever format the render
  // endpoint uses.
  ScopedWaveFormat format;
};

// WASAPI loopback capture client and its packet-pulling service, paired so
// they can be set up together (see `SetupCaptureClient`) and moved into the
// device as a unit.
struct CaptureClientSetup {
  // Owns the loopback stream configuration and lifetime.
  ScopedComPtr<IAudioClient> audio_client;
  // Pulls captured packets from `audio_client`.
  ScopedComPtr<IAudioCaptureClient> service;
};

// Activates a shared-mode render client on `render_device`, checks that the
// mix format is packed float32 stereo, initializes the render stream, and
// acquires its render service. Populates `setup` on success; returns the
// first failing status otherwise.
[[nodiscard]] AudioPipelineInterface::Status SetupRenderClient(
    IMMDevice& render_device, RenderClientSetup& setup) {
  IAudioClient* raw_client = nullptr;
  const HRESULT activated = render_device.Activate(
      __uuidof(IAudioClient), CLSCTX_ALL,
      /*pActivationParams=*/nullptr, reinterpret_cast<void**>(&raw_client));
  if (FAILED(activated)) {
    return AudioPipelineInterface::Status::Error(
        activated, L"Activate render IAudioClient failed");
  }
  setup.audio_client.reset(raw_client);

  WAVEFORMATEX* mix_format = nullptr;
  const HRESULT got_format = setup.audio_client->GetMixFormat(&mix_format);
  if (FAILED(got_format)) {
    return AudioPipelineInterface::Status::Error(
        got_format, L"GetMixFormat (render) failed");
  }
  setup.format.reset(mix_format);

  const AudioPipelineInterface::Status format_check =
      ValidateRenderMixFormat(*setup.format);
  if (!format_check.ok()) {
    return format_check;
  }

  const AudioPipelineInterface::Status render_init =
      InitializeRenderClient(*setup.audio_client, setup.format.get());
  if (!render_init.ok()) {
    return render_init;
  }

  IAudioRenderClient* raw_service = nullptr;
  const AudioPipelineInterface::Status render_service =
      AcquireRenderClientService(*setup.audio_client, raw_service);
  if (!render_service.ok()) {
    return render_service;
  }
  setup.service.reset(raw_service);
  return AudioPipelineInterface::Status::Ok();
}

// Activates a loopback capture client via the process loopback API and
// initializes it with `render_format` (capture mirrors the render endpoint's
// format). Populates `setup` on success; returns the first failing status
// otherwise.
[[nodiscard]] AudioPipelineInterface::Status SetupCaptureClient(
    WAVEFORMATEX* render_format, CaptureClientSetup& setup) {
  IAudioClient* raw_client = nullptr;
  const AudioPipelineInterface::Status activation =
      ActivateLoopbackCaptureClient(raw_client);
  if (!activation.ok()) {
    return activation;
  }
  setup.audio_client.reset(raw_client);

  IAudioCaptureClient* raw_service = nullptr;
  const AudioPipelineInterface::Status stream_init = InitializeCaptureStream(
      *setup.audio_client, render_format, raw_service);
  if (!stream_init.ok()) {
    return stream_init;
  }
  setup.service.reset(raw_service);
  return AudioPipelineInterface::Status::Ok();
}

// Combined capture and render setup, built during startup and recovery then
// moved into the device's long-lived members.
struct StreamClientSetup {
  // Loopback capture client pair; built by `SetupCaptureClient`.
  CaptureClientSetup capture;
  // Render client pair and the shared mix format; built by
  // `SetupRenderClient`.
  RenderClientSetup render;
};

// Sets up the render client first — its negotiated mix format doubles as the
// capture format — then activates the process loopback capture client with
// that format. Returns `Ok` only when both clients are ready; otherwise
// returns the first failing status.
[[nodiscard]] AudioPipelineInterface::Status SetupStreamClients(
    IMMDevice& render_device, StreamClientSetup& clients) {
  AudioPipelineInterface::Status status =
      SetupRenderClient(render_device, clients.render);
  if (status.ok()) {
    status = SetupCaptureClient(clients.render.format.get(), clients.capture);
  }
  return status;
}

}  // namespace

// Default construction: delegates to the dependency constructor with
// default-constructed (empty) dependencies; `Open()` creates the real
// endpoint and clients.
WasapiAudioDevice::WasapiAudioDevice() : WasapiAudioDevice(Dependencies{}) {}

// Adopts externally supplied clients, services, format, and endpoint name
// (dependency injection — presumably used by tests; confirm against callers).
// Fields not provided remain empty until `Open()` populates them.
WasapiAudioDevice::WasapiAudioDevice(Dependencies dependencies)
    : capture_client_(dependencies.capture_client),
      render_client_(dependencies.render_client),
      capture_service_(dependencies.capture_service),
      render_service_(dependencies.render_service),
      format_(dependencies.format),
      endpoint_name_(std::move(dependencies.endpoint_name)) {}

// Stops any running streams and releases all COM resources via `Close()`.
WasapiAudioDevice::~WasapiAudioDevice() { Close(); }

// Acquires the default render endpoint and builds both stream clients before
// committing anything to member state, so a failed open leaves the device's
// previous state untouched. Returns `Ok` when everything is ready.
AudioPipelineInterface::Status WasapiAudioDevice::Open() {
  EndpointAcquisition acquired = AcquireEndpoint();
  if (!acquired.status.ok()) {
    return acquired.status;
  }

  StreamClientSetup setup;
  const AudioPipelineInterface::Status setup_status =
      SetupStreamClients(*acquired.render_device, setup);
  if (!setup_status.ok()) {
    return setup_status;
  }

  // Success: move everything into the long-lived members in one pass.
  enumerator_ = std::move(acquired.enumerator);
  render_device_ = std::move(acquired.render_device);
  endpoint_name_ = std::move(acquired.endpoint_name);
  capture_client_ = std::move(setup.capture.audio_client);
  capture_service_ = std::move(setup.capture.service);
  render_client_ = std::move(setup.render.audio_client);
  render_service_ = std::move(setup.render.service);
  format_ = std::move(setup.render.format);
  return AudioPipelineInterface::Status::Ok();
}

// Starts the capture stream, then the render stream. If render start fails,
// the already-started capture stream is stopped so no stream is left running
// after a failure. Requires a prior successful `Open()`.
AudioPipelineInterface::Status WasapiAudioDevice::StartStreams() {
  if (capture_client_ == nullptr || render_client_ == nullptr) {
    return AudioPipelineInterface::Status::Error(
        E_POINTER, L"Capture or render client is null; call `Open()` first");
  }

  const HRESULT capture_start = capture_client_->Start();
  if (FAILED(capture_start)) {
    return AudioPipelineInterface::Status::Error(
        capture_start, L"Capture client `Start()` failed");
  }

  const HRESULT render_start = render_client_->Start();
  if (SUCCEEDED(render_start)) {
    return AudioPipelineInterface::Status::Ok();
  }
  // Undo the capture start so a failed render start leaves nothing running.
  capture_client_->Stop();
  return AudioPipelineInterface::Status::Error(
      render_start, L"Render client `Start()` failed");
}

void WasapiAudioDevice::StopStreams() {
  if (capture_client_ != nullptr) {
    capture_client_->Stop();
  }
  if (render_client_ != nullptr) {
    render_client_->Stop();
  }
}

// Stops both streams and releases every COM resource. Release order matters:
// services before the clients that produced them, clients before the device,
// and the device before its enumerator. Safe to call repeatedly.
void WasapiAudioDevice::Close() {
  StopStreams();
  capture_service_ = nullptr;
  render_service_ = nullptr;
  capture_client_ = nullptr;
  render_client_ = nullptr;
  format_ = nullptr;
  render_device_ = nullptr;
  enumerator_ = nullptr;
}

// Pulls the next loopback packet, decoding non-silent data to stereo float32
// via `DecodeToStereoFloat`. A packet with zero frames means no data was
// ready. Any WASAPI failure — including `ReleaseBuffer` — is reported through
// `packet.status` so callers can attempt recovery.
CapturePacket WasapiAudioDevice::ReadNextPacket() {
  CapturePacket packet;
  if (capture_service_ == nullptr || format_ == nullptr) {
    packet.status = E_POINTER;
    return packet;
  }

  UINT32 packet_size = 0;
  if (const HRESULT query = capture_service_->GetNextPacketSize(&packet_size);
      FAILED(query)) {
    packet.status = query;
    return packet;
  }
  if (packet_size == 0) {
    return packet;
  }

  BYTE* capture_bytes = nullptr;
  UINT32 frames = 0;
  DWORD flags = 0;
  if (const HRESULT get_buf = capture_service_->GetBuffer(
          &capture_bytes, &frames, &flags,
          /*device_position=*/nullptr, /*qpc_position=*/nullptr);
      FAILED(get_buf)) {
    packet.status = get_buf;
    return packet;
  }

  packet.frames = frames;
  packet.silent = (flags & AUDCLNT_BUFFERFLAGS_SILENT) != 0;

  // Silent packets skip decoding; downstream only needs the frame count.
  if (!packet.silent && frames > 0) {
    StereoPcmBuffer buf = DecodeToStereoFloat(capture_bytes, frames, *format_);
    packet.samples = std::move(buf.samples);
    packet.frames = buf.frames;
  }

  // Fix: a failed ReleaseBuffer (e.g. device invalidation mid-stream) was
  // previously swallowed, hiding a recoverable failure from the caller.
  if (const HRESULT release = capture_service_->ReleaseBuffer(frames);
      FAILED(release)) {
    packet.status = release;
  }
  return packet;
}

// Writes interleaved stereo float32 samples into the render buffer. Returns
// S_OK on success (including an empty packet), S_FALSE when the render buffer
// cannot hold the whole packet yet, or the failing HRESULT otherwise.
HRESULT WasapiAudioDevice::WriteRenderPacket(std::span<const float> pcm) {
  if (render_client_ == nullptr || render_service_ == nullptr) {
    return E_POINTER;
  }

  // Two samples (left, right) per frame; a trailing odd sample is dropped.
  const UINT32 frames = static_cast<UINT32>(pcm.size() / 2);
  if (frames == 0) {
    return S_OK;
  }

  UINT32 render_buf_size = 0;
  UINT32 render_padding = 0;
  if (const HRESULT status = render_client_->GetBufferSize(&render_buf_size);
      FAILED(status)) {
    return status;
  }
  if (const HRESULT status = render_client_->GetCurrentPadding(&render_padding);
      FAILED(status)) {
    return status;
  }

  const UINT32 available = render_buf_size - render_padding;
  if (available < frames) {
    return S_FALSE;
  }

  BYTE* render_buf = nullptr;
  if (const HRESULT status = render_service_->GetBuffer(frames, &render_buf);
      FAILED(status)) {
    return status;
  }
  if (render_buf == nullptr) {
    render_service_->ReleaseBuffer(frames, AUDCLNT_BUFFERFLAGS_SILENT);
    return E_POINTER;
  }

  // Fix: copy exactly the bytes the requested frames occupy. The previous
  // `pcm.size_bytes()` overran the WASAPI buffer by one sample whenever
  // `pcm.size()` was odd, because `frames` rounds down.
  std::memcpy(render_buf, pcm.data(),
              static_cast<size_t>(frames) * 2 * sizeof(float));
  return render_service_->ReleaseBuffer(frames, /*dwFlags=*/0);
}

// Attempts a full teardown-and-rebuild of both stream clients on the current
// default render device after a recoverable stream failure. Returns true when
// both streams are running again; returns false when `failure` is terminal or
// any rebuild step fails.
bool WasapiAudioDevice::TryRecover(HRESULT failure) {
  if (!IsRecoverableStreamFailure(failure)) {
    return false;
  }

  StopStreams();

  EndpointAcquisition endpoint = AcquireEndpoint();
  if (!endpoint.status.ok()) {
    return false;
  }

  StreamClientSetup clients;
  if (!SetupStreamClients(*endpoint.render_device, clients).ok()) {
    return false;
  }

  // Commit the rebuilt clients; the old (invalidated) ones release here.
  enumerator_ = std::move(endpoint.enumerator);
  render_device_ = std::move(endpoint.render_device);
  endpoint_name_ = std::move(endpoint.endpoint_name);
  capture_client_ = std::move(clients.capture.audio_client);
  capture_service_ = std::move(clients.capture.service);
  render_client_ = std::move(clients.render.audio_client);
  render_service_ = std::move(clients.render.service);
  format_ = std::move(clients.render.format);

  if (FAILED(capture_client_->Start())) {
    return false;
  }
  if (FAILED(render_client_->Start())) {
    // Fix: mirror `StartStreams()` — never report failure while leaving the
    // capture stream running.
    capture_client_->Stop();
    return false;
  }
  return true;
}

// Sample rate of the negotiated mix format, or 0.0 before a format exists.
double WasapiAudioDevice::sample_rate() const {
  return format_ != nullptr ? static_cast<double>(format_->nSamplesPerSec)
                            : 0.0;
}

// Friendly name of the active render endpoint; empty until `Open()` succeeds
// or a name is injected via `Dependencies`.
const std::wstring& WasapiAudioDevice::endpoint_name() const {
  return endpoint_name_;
}

// (Stray OpenCppCoverage report artifact, commented out: it is not C++ and
// broke compilation.)
// Generated by OpenCppCoverage (Version: 0.9.9.0)