Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,25 @@ class GraphCacheManager : public IBufferCacheManager {
// no-op - buffers are already in buckets_
}

std::vector<std::pair<size_t, WGPUBuffer>> ExtractCachedBuffers() override {
std::vector<std::pair<size_t, WGPUBuffer>> result;
for (auto& pair : buckets_) {
for (auto& buffer : pair.second) {
result.emplace_back(pair.first, buffer);
}
pair.second.clear();
}
return result;
}

void AbsorbCachedBuffers(std::vector<std::pair<size_t, WGPUBuffer>>&& buffers) override {
for (auto& entry : buffers) {
if (entry.second) {
ReleaseBuffer(entry.second);
}
}
}

~GraphCacheManager() {
for (auto& pair : buckets_) {
for (auto& buffer : pair.second) {
Expand Down Expand Up @@ -386,6 +405,37 @@ class GraphSimpleCacheManager : public IBufferCacheManager {
}
}

std::vector<std::pair<size_t, WGPUBuffer>> ExtractCachedBuffers() override {
// Donation is expected after captured commands have been released and any
// in-flight work has completed; all three containers therefore hold buffers
// no longer referenced by the device.
std::vector<std::pair<size_t, WGPUBuffer>> result;
for (auto& pair : buffers_) {
for (auto& buffer : pair.second) {
result.emplace_back(pair.first, buffer);
}
pair.second.clear();
}
buffers_.clear();
for (auto& buffer : pending_buffers_) {
result.emplace_back(static_cast<size_t>(wgpuBufferGetSize(buffer)), buffer);
}
pending_buffers_.clear();
for (auto& buffer : captured_buffers_) {
result.emplace_back(static_cast<size_t>(wgpuBufferGetSize(buffer)), buffer);
}
captured_buffers_.clear();
return result;
}

void AbsorbCachedBuffers(std::vector<std::pair<size_t, WGPUBuffer>>&& buffers) override {
for (auto& entry : buffers) {
if (entry.second) {
buffers_[entry.first].emplace_back(entry.second);
}
}
}

protected:
std::map<size_t, std::vector<WGPUBuffer>> buffers_;
std::vector<WGPUBuffer> pending_buffers_;
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/webgpu/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#pragma once

#include <iosfwd>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_external_header.h"

Expand Down Expand Up @@ -60,6 +62,24 @@ class IBufferCacheManager {

// when a stream refresh is requested
virtual void OnRefresh(GraphCaptureState graph_capture_state) = 0;

// Extract all cached buffers from this manager, transferring ownership to the
// caller. The cache's internal containers are cleared (but bucket keys, if any,
// are preserved). Default returns empty; only graph-mode caches implement this.
virtual std::vector<std::pair<size_t, WGPUBuffer>> ExtractCachedBuffers() {
return {};
}

// Accept buffers donated from another cache and take ownership of them. Caches
// that cannot store the buffers must release them via wgpuBufferRelease (the
// default below) to avoid leaks.
virtual void AbsorbCachedBuffers(std::vector<std::pair<size_t, WGPUBuffer>>&& buffers) {
for (auto& entry : buffers) {
if (entry.second) {
wgpuBufferRelease(entry.second);
}
}
}
};

//
Expand All @@ -76,6 +96,11 @@ class BufferManager {
void Download(WGPUBuffer src, void* dst, size_t size) const;
void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const;

// Direct access to the underlying cache managers. Used by SessionBufferPool to
// donate/seed buffers across per-graph BufferManager lifetimes.
IBufferCacheManager& StorageCache() { return *storage_cache_; }
IBufferCacheManager& UniformCache() { return *uniform_cache_; }

private:
IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const;
IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const;
Expand Down
74 changes: 74 additions & 0 deletions onnxruntime/core/providers/webgpu/session_buffer_pool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/session_buffer_pool.h"

#include "core/providers/webgpu/buffer_manager.h"

namespace onnxruntime {
namespace webgpu {

namespace {
void ReleaseSlotBuffers(std::vector<std::pair<size_t, WGPUBuffer>>& entries) {
for (auto& entry : entries) {
if (entry.second) {
wgpuBufferRelease(entry.second);
}
}
entries.clear();
}
} // namespace

SessionBufferPool::SessionBufferPool(size_t max_generations)
: max_generations_{max_generations} {
}
Comment thread
qjia7 marked this conversation as resolved.

SessionBufferPool::~SessionBufferPool() {
Clear();
}

void SessionBufferPool::Donate(BufferManager& retiring_mgr) {
if (max_generations_ == 0) {
return;
}

Slot slot;
slot.storage = retiring_mgr.StorageCache().ExtractCachedBuffers();
slot.uniform = retiring_mgr.UniformCache().ExtractCachedBuffers();

if (slot.storage.empty() && slot.uniform.empty()) {
return;
}

// Evict the oldest slot if at capacity so the freshest buffers (which most
// accurately reflect the current per-generator shape distribution) are kept.
while (slots_.size() >= max_generations_) {
auto& victim = slots_.front();
ReleaseSlotBuffers(victim.storage);
ReleaseSlotBuffers(victim.uniform);
slots_.erase(slots_.begin());
}

slots_.emplace_back(std::move(slot));
}

void SessionBufferPool::SeedInto(BufferManager& new_mgr) {
if (slots_.empty()) {
return;
}
Slot slot = std::move(slots_.back());
slots_.pop_back();
new_mgr.StorageCache().AbsorbCachedBuffers(std::move(slot.storage));
new_mgr.UniformCache().AbsorbCachedBuffers(std::move(slot.uniform));

Check warning on line 62 in onnxruntime/core/providers/webgpu/session_buffer_pool.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/session_buffer_pool.cc:62: Add #include <utility> for move [build/include_what_you_use] [4]
}

void SessionBufferPool::Clear() {
for (auto& slot : slots_) {
ReleaseSlotBuffers(slot.storage);
ReleaseSlotBuffers(slot.uniform);
}
slots_.clear();
}

} // namespace webgpu
} // namespace onnxruntime
54 changes: 54 additions & 0 deletions onnxruntime/core/providers/webgpu/session_buffer_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstddef>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_external_header.h"

namespace onnxruntime {
namespace webgpu {

class BufferManager;

// SessionBufferPool retains buffers from retired per-graph BufferManagers so
// that subsequent generators on the same session can reuse them instead of
// allocating from the device. Scoped per WebGpuExecutionProvider (per session)
// because intermediate buffer shapes are model-dependent.
class SessionBufferPool {
public:
explicit SessionBufferPool(size_t max_generations);

~SessionBufferPool();

// Move freed buffers from a retiring per-graph BufferManager into the pool.
// When the pool is at capacity, the oldest slot is evicted (its buffers
// released to the device) so the freshest buffers are always retained. This
// lets the pool adapt when intermediate buffer shapes change between
// generators (for example when max_length differs).
void Donate(BufferManager& retiring_mgr);

// Pre-populate a newly created per-graph BufferManager with one slot worth of
// pooled buffers (LIFO). No-op if the pool is empty.
void SeedInto(BufferManager& new_mgr);

// Release all pooled buffers. Called on session teardown.
void Clear();

size_t Size() const { return slots_.size(); }
size_t Capacity() const { return max_generations_; }

private:
struct Slot {
std::vector<std::pair<size_t, WGPUBuffer>> storage;
std::vector<std::pair<size_t, WGPUBuffer>> uniform;
};
std::vector<Slot> slots_;
size_t max_generations_;
};

} // namespace webgpu
} // namespace onnxruntime
22 changes: 20 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset},
prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(
[this]() -> const webgpu::BufferManager& { return context_.InitializerBufferManager(); }, false)} {
if (enable_graph_capture_ && config.session_buffer_pool_generations > 0) {
session_buffer_pool_ = std::make_unique<webgpu::SessionBufferPool>(
config.session_buffer_pool_generations);
}
if (config.enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
Expand Down Expand Up @@ -748,6 +752,11 @@ WebGpuExecutionProvider::~WebGpuExecutionProvider() {
// but no entries in captured_graphs_ (edge case cleanup)
per_graph_buffer_mgrs_.clear();

// Release pooled buffers before the WebGpuContext is released.
if (session_buffer_pool_) {
session_buffer_pool_->Clear();
}

WebGpuContextFactory::ReleaseContext(context_id_);
}

Expand Down Expand Up @@ -792,6 +801,9 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
webgpu::BufferCacheMode::Graph,
webgpu::BufferCacheMode::GraphSimple,
webgpu::BufferCacheMode::Disabled);
if (session_buffer_pool_) {
session_buffer_pool_->SeedInto(*it->second);
}
}
graph_buffer_mgr_active_ = true;

Expand Down Expand Up @@ -877,8 +889,14 @@ Status WebGpuExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) {
// Remove from captured set
captured_graph_ids_.erase(graph_annotation_id);

// Release per-graph buffer manager (destroys cached buffers)
per_graph_buffer_mgrs_.erase(graph_annotation_id);
// Release per-graph buffer manager (donate to session pool if enabled; otherwise destroy)
auto mgr_it = per_graph_buffer_mgrs_.find(graph_annotation_id);
if (mgr_it != per_graph_buffer_mgrs_.end()) {
if (session_buffer_pool_ && mgr_it->second) {
session_buffer_pool_->Donate(*mgr_it->second);
}
per_graph_buffer_mgrs_.erase(mgr_it);
}

// Clean up run count tracking
graph_id_to_run_count_.erase(graph_annotation_id);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "core/graph/constants.h"
#include "core/providers/providers.h"
#include "core/providers/webgpu/buffer_manager.h"
#include "core/providers/webgpu/session_buffer_pool.h"

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
Expand Down Expand Up @@ -46,6 +47,10 @@ struct WebGpuExecutionProviderConfig {
bool enable_pix_capture{false}; // PIX capture is disabled by default
bool enable_int64{false}; // int64 ops are not enabled by default
uint32_t multi_rotary_cache_concat_offset{0}; // offset for concatenated multi rotary cache (0 = disabled)
// Number of generations of buffers to retain in the per-session pool for reuse
// across captured-graph lifetimes. 0 disables pooling. Default 1 caches one
// generator's worth of intermediate buffers.
size_t session_buffer_pool_generations{1};
std::vector<std::string> force_cpu_node_names{};
};

Expand Down Expand Up @@ -144,6 +149,11 @@ class WebGpuExecutionProvider : public IExecutionProvider {
// are isolated between different generators.
std::unordered_map<int, std::unique_ptr<webgpu::BufferManager>> per_graph_buffer_mgrs_;

// Per-session pool of buffers donated by retired per-graph BufferManagers,
// seeded into new per-graph BufferManagers to avoid device allocations for
// identically-shaped intermediate tensors across generators.
std::unique_ptr<webgpu::SessionBufferPool> session_buffer_pool_;

// Store captured commands per graph annotation ID
std::unordered_map<int, std::vector<webgpu::CapturedCommandInfo>> captured_graphs_;
// Track which graph annotation IDs have completed capture
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
}
}

if (std::string pool_generations_str;
config_options.TryGetConfigEntry(kSessionBufferPoolGenerations, pool_generations_str)) {
size_t pool_generations = 0;
const char* begin = pool_generations_str.data();
const char* end = begin + pool_generations_str.size();
auto result = std::from_chars(begin, end, pool_generations);
if (result.ec == std::errc{} && result.ptr == end) {
webgpu_ep_config.session_buffer_pool_generations = pool_generations;
} else {
ORT_THROW("Invalid sessionBufferPoolGenerations value: ", pool_generations_str, ". Must be a non-negative integer.");
}
}

std::string enable_int64_str;
if (config_options.TryGetConfigEntry(kEnableInt64, enable_int64_str)) {
if (enable_int64_str == kEnableInt64_ON) {
Expand Down Expand Up @@ -122,6 +135,7 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP enable int64: " << webgpu_ep_config.enable_int64;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP multi rotary cache concat offset: " << webgpu_ep_config.multi_rotary_cache_concat_offset;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP session buffer pool generations: " << webgpu_ep_config.session_buffer_pool_generations;

return webgpu_ep_config;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace options {

constexpr const char* kPreferredLayout = "ep.webgpuexecutionprovider.preferredLayout";
constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGraphCapture";
constexpr const char* kSessionBufferPoolGenerations = "ep.webgpuexecutionprovider.sessionBufferPoolGenerations";
constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64";
constexpr const char* kMultiRotaryCacheConcatOffset = "ep.webgpuexecutionprovider.multiRotaryCacheConcatOffset";

Expand Down
Loading