diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index fadb425c2ad6d..45b0447d2d9be 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -304,6 +304,25 @@ class GraphCacheManager : public IBufferCacheManager { // no-op - buffers are already in buckets_ } + std::vector> ExtractCachedBuffers() override { + std::vector> result; + for (auto& pair : buckets_) { + for (auto& buffer : pair.second) { + result.emplace_back(pair.first, buffer); + } + pair.second.clear(); + } + return result; + } + + void AbsorbCachedBuffers(std::vector>&& buffers) override { + for (auto& entry : buffers) { + if (entry.second) { + ReleaseBuffer(entry.second); + } + } + } + ~GraphCacheManager() { for (auto& pair : buckets_) { for (auto& buffer : pair.second) { @@ -386,6 +405,37 @@ class GraphSimpleCacheManager : public IBufferCacheManager { } } + std::vector> ExtractCachedBuffers() override { + // Donation is expected after captured commands have been released and any + // in-flight work has completed; all three containers therefore hold buffers + // no longer referenced by the device. + std::vector> result; + for (auto& pair : buffers_) { + for (auto& buffer : pair.second) { + result.emplace_back(pair.first, buffer); + } + pair.second.clear(); + } + buffers_.clear(); + for (auto& buffer : pending_buffers_) { + result.emplace_back(static_cast(wgpuBufferGetSize(buffer)), buffer); + } + pending_buffers_.clear(); + for (auto& buffer : captured_buffers_) { + result.emplace_back(static_cast(wgpuBufferGetSize(buffer)), buffer); + } + captured_buffers_.clear(); + return result; + } + + void AbsorbCachedBuffers(std::vector>&& buffers) override { + for (auto& entry : buffers) { + if (entry.second) { + buffers_[entry.first].emplace_back(entry.second); + } + } + } + protected: std::map> buffers_; std::vector pending_buffers_; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 05a79726d71e4..9bd1f54711971 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -4,6 +4,8 @@ #pragma once #include +#include +#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -60,6 +62,24 @@ class IBufferCacheManager { // when a stream refresh is requested virtual void OnRefresh(GraphCaptureState graph_capture_state) = 0; + + // Extract all cached buffers from this manager, transferring ownership to the + // caller. The cache's internal containers are cleared (but bucket keys, if any, + // are preserved). Default returns empty; only graph-mode caches implement this. + virtual std::vector> ExtractCachedBuffers() { + return {}; + } + + // Accept buffers donated from another cache and take ownership of them. Caches + // that cannot store the buffers must release them via wgpuBufferRelease (the + // default below) to avoid leaks. + virtual void AbsorbCachedBuffers(std::vector>&& buffers) { + for (auto& entry : buffers) { + if (entry.second) { + wgpuBufferRelease(entry.second); + } + } + } }; // @@ -76,6 +96,11 @@ class BufferManager { void Download(WGPUBuffer src, void* dst, size_t size) const; void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const; + // Direct access to the underlying cache managers. Used by SessionBufferPool to + // donate/seed buffers across per-graph BufferManager lifetimes. + IBufferCacheManager& StorageCache() { return *storage_cache_; } + IBufferCacheManager& UniformCache() { return *uniform_cache_; } + private: IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; diff --git a/onnxruntime/core/providers/webgpu/session_buffer_pool.cc b/onnxruntime/core/providers/webgpu/session_buffer_pool.cc new file mode 100644 index 0000000000000..5d647f33b12d0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/session_buffer_pool.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/session_buffer_pool.h" + +#include "core/providers/webgpu/buffer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +void ReleaseSlotBuffers(std::vector>& entries) { + for (auto& entry : entries) { + if (entry.second) { + wgpuBufferRelease(entry.second); + } + } + entries.clear(); +} +} // namespace + +SessionBufferPool::SessionBufferPool(size_t max_generations) + : max_generations_{max_generations} { +} + +SessionBufferPool::~SessionBufferPool() { + Clear(); +} + +void SessionBufferPool::Donate(BufferManager& retiring_mgr) { + if (max_generations_ == 0) { + return; + } + + Slot slot; + slot.storage = retiring_mgr.StorageCache().ExtractCachedBuffers(); + slot.uniform = retiring_mgr.UniformCache().ExtractCachedBuffers(); + + if (slot.storage.empty() && slot.uniform.empty()) { + return; + } + + // Evict the oldest slot if at capacity so the freshest buffers (which most + // accurately reflect the current per-generator shape distribution) are kept. + while (slots_.size() >= max_generations_) { + auto& victim = slots_.front(); + ReleaseSlotBuffers(victim.storage); + ReleaseSlotBuffers(victim.uniform); + slots_.erase(slots_.begin()); + } + + slots_.emplace_back(std::move(slot)); +} + +void SessionBufferPool::SeedInto(BufferManager& new_mgr) { + if (slots_.empty()) { + return; + } + Slot slot = std::move(slots_.back()); + slots_.pop_back(); + new_mgr.StorageCache().AbsorbCachedBuffers(std::move(slot.storage)); + new_mgr.UniformCache().AbsorbCachedBuffers(std::move(slot.uniform)); +} + +void SessionBufferPool::Clear() { + for (auto& slot : slots_) { + ReleaseSlotBuffers(slot.storage); + ReleaseSlotBuffers(slot.uniform); + } + slots_.clear(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/session_buffer_pool.h b/onnxruntime/core/providers/webgpu/session_buffer_pool.h new file mode 100644 index 0000000000000..edec2b613dfd6 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/session_buffer_pool.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/providers/webgpu/webgpu_external_header.h" + +namespace onnxruntime { +namespace webgpu { + +class BufferManager; + +// SessionBufferPool retains buffers from retired per-graph BufferManagers so +// that subsequent generators on the same session can reuse them instead of +// allocating from the device. Scoped per WebGpuExecutionProvider (per session) +// because intermediate buffer shapes are model-dependent. +class SessionBufferPool { + public: + explicit SessionBufferPool(size_t max_generations); + + ~SessionBufferPool(); + + // Move freed buffers from a retiring per-graph BufferManager into the pool. + // When the pool is at capacity, the oldest slot is evicted (its buffers + // released to the device) so the freshest buffers are always retained. This + // lets the pool adapt when intermediate buffer shapes change between + // generators (for example when max_length differs). + void Donate(BufferManager& retiring_mgr); + + // Pre-populate a newly created per-graph BufferManager with one slot worth of + // pooled buffers (LIFO). No-op if the pool is empty. + void SeedInto(BufferManager& new_mgr); + + // Release all pooled buffers. Called on session teardown. + void Clear(); + + size_t Size() const { return slots_.size(); } + size_t Capacity() const { return max_generations_; } + + private: + struct Slot { + std::vector> storage; + std::vector> uniform; + }; + std::vector slots_; + size_t max_generations_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 7e11ddf6b13a0..a588fee464260 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -582,6 +582,10 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared( [this]() -> const webgpu::BufferManager& { return context_.InitializerBufferManager(); }, false)} { + if (enable_graph_capture_ && config.session_buffer_pool_generations > 0) { + session_buffer_pool_ = std::make_unique( + config.session_buffer_pool_generations); + } if (config.enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator @@ -748,6 +752,11 @@ WebGpuExecutionProvider::~WebGpuExecutionProvider() { // but no entries in captured_graphs_ (edge case cleanup) per_graph_buffer_mgrs_.clear(); + // Release pooled buffers before the WebGpuContext is released. + if (session_buffer_pool_) { + session_buffer_pool_->Clear(); + } + WebGpuContextFactory::ReleaseContext(context_id_); } @@ -792,6 +801,9 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op webgpu::BufferCacheMode::Graph, webgpu::BufferCacheMode::GraphSimple, webgpu::BufferCacheMode::Disabled); + if (session_buffer_pool_) { + session_buffer_pool_->SeedInto(*it->second); + } } graph_buffer_mgr_active_ = true; @@ -877,8 +889,14 @@ Status WebGpuExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) { // Remove from captured set captured_graph_ids_.erase(graph_annotation_id); - // Release per-graph buffer manager (destroys cached buffers) - per_graph_buffer_mgrs_.erase(graph_annotation_id); + // Release per-graph buffer manager (donate to session pool if enabled; otherwise destroy) + auto mgr_it = per_graph_buffer_mgrs_.find(graph_annotation_id); + if (mgr_it != per_graph_buffer_mgrs_.end()) { + if (session_buffer_pool_ && mgr_it->second) { + session_buffer_pool_->Donate(*mgr_it->second); + } + per_graph_buffer_mgrs_.erase(mgr_it); + } // Clean up run count tracking graph_id_to_run_count_.erase(graph_annotation_id); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 92d1ebfd36c79..1dc7de0c21e9f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -16,6 +16,7 @@ #include "core/graph/constants.h" #include "core/providers/providers.h" #include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/session_buffer_pool.h" #if defined(ENABLE_PIX_FOR_WEBGPU_EP) #include "core/providers/webgpu/webgpu_pix_frame_generator.h" @@ -46,6 +47,10 @@ struct WebGpuExecutionProviderConfig { bool enable_pix_capture{false}; // PIX capture is disabled by default bool enable_int64{false}; // int64 ops are not enabled by default uint32_t multi_rotary_cache_concat_offset{0}; // offset for concatenated multi rotary cache (0 = disabled) + // Number of generations of buffers to retain in the per-session pool for reuse + // across captured-graph lifetimes. 0 disables pooling. Default 1 caches one + // generator's worth of intermediate buffers. + size_t session_buffer_pool_generations{1}; std::vector force_cpu_node_names{}; }; @@ -144,6 +149,11 @@ class WebGpuExecutionProvider : public IExecutionProvider { // are isolated between different generators. std::unordered_map> per_graph_buffer_mgrs_; + // Per-session pool of buffers donated by retired per-graph BufferManagers, + // seeded into new per-graph BufferManagers to avoid device allocations for + // identically-shaped intermediate tensors across generators. + std::unique_ptr session_buffer_pool_; + // Store captured commands per graph annotation ID std::unordered_map> captured_graphs_; // Track which graph annotation IDs have completed capture diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 16899370e47f1..9986b2f4f94cd 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -61,6 +61,19 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) } } + if (std::string pool_generations_str; + config_options.TryGetConfigEntry(kSessionBufferPoolGenerations, pool_generations_str)) { + size_t pool_generations = 0; + const char* begin = pool_generations_str.data(); + const char* end = begin + pool_generations_str.size(); + auto result = std::from_chars(begin, end, pool_generations); + if (result.ec == std::errc{} && result.ptr == end) { + webgpu_ep_config.session_buffer_pool_generations = pool_generations; + } else { + ORT_THROW("Invalid sessionBufferPoolGenerations value: ", pool_generations_str, ". Must be a non-negative integer."); + } + } + std::string enable_int64_str; if (config_options.TryGetConfigEntry(kEnableInt64, enable_int64_str)) { if (enable_int64_str == kEnableInt64_ON) { @@ -122,6 +135,7 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture; LOGS_DEFAULT(VERBOSE) << "WebGPU EP enable int64: " << webgpu_ep_config.enable_int64; LOGS_DEFAULT(VERBOSE) << "WebGPU EP multi rotary cache concat offset: " << webgpu_ep_config.multi_rotary_cache_concat_offset; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP session buffer pool generations: " << webgpu_ep_config.session_buffer_pool_generations; return webgpu_ep_config; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index d2faccdb8c4a5..4cab81dafe9c7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -11,6 +11,7 @@ namespace options { constexpr const char* kPreferredLayout = "ep.webgpuexecutionprovider.preferredLayout"; constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGraphCapture"; +constexpr const char* kSessionBufferPoolGenerations = "ep.webgpuexecutionprovider.sessionBufferPoolGenerations"; constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64"; constexpr const char* kMultiRotaryCacheConcatOffset = "ep.webgpuexecutionprovider.multiRotaryCacheConcatOffset";