C/C++ API Reference
Loading...
Searching...
No Matches
coro.h
1// Copyright 2024 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14#pragma once
15
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <new>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>

#include "pw_allocator/allocator.h"
#include "pw_allocator/layout.h"
#include "pw_async2/dispatcher.h"
#include "pw_function/function.h"
#include "pw_log/log.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
27
28namespace pw::async2 {
29
31
32// Forward-declare `Coro` so that it can be referenced by the promise type APIs.
33template <std::constructible_from<pw::Status> T>
34class Coro;
35
38 public:
41 explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
42 pw::allocator::Allocator& alloc() const { return alloc_; }
43
44 private:
46};
47
49
50// The internal coroutine API implementation details enabling `Coro<T>`.
51//
52// Users of `Coro<T>` need not concern themselves with these details, unless
53// they think it sounds like fun ;)
54namespace internal {
55
// Called by `CoroPromiseType::SharedNew` when the `CoroContext`'s allocator
// could not provide `requested_size` bytes for a coroutine's state.
// Defined out of line (not in this header); presumably emits a log
// message — confirm against the implementation file.
void LogCoroAllocationFailure(size_t requested_size);
57
// A minimal optional-like slot for a to-be-produced value of type `T`.
//
// Used as the `co_return` storage when `T` is not default-constructible.
// The slot starts out empty, is filled by assignment, and is read back by
// converting it to `T`.
template <typename T>
class OptionalWrapper final {
 public:
  // Create an empty container for a to-be-provided value.
  OptionalWrapper() : value_() {}

  // Assign a value, replacing any previously stored one.
  template <typename U>
  OptionalWrapper& operator=(U&& value) {
    value_ = std::forward<U>(value);
    return *this;
  }

  // Retrieve the inner value by moving it out of the container.
  //
  // Moving (instead of copying) allows move-only `T` and avoids a needless
  // copy. The container is intended to be read exactly once: `Coro<T>::Pend`
  // converts its local slot to the result right before the slot goes out of
  // scope, so the moved-from value is never observed.
  //
  // Asserts (via `PW_ASSERT`) if no value was assigned.
  operator T() {
    PW_ASSERT(value_.has_value());
    return std::move(*value_);
  }

 private:
  std::optional<T> value_;
};
82
83// A container for a to-be-produced value of type `T`.
84//
85// This is designed to allow avoiding the overhead of `std::optional` when
86// `T` is default-initializable.
87//
88// Values of this type begin as either:
89// - a default-initialized `T` if `T` is default-initializable or
90// - `std::nullopt`
91template <typename T>
92using OptionalOrDefault =
93 std::conditional<std::is_default_constructible<T>::value,
94 T,
95 OptionalWrapper<T>>::type;
96
97// A wrapper for `std::coroutine_handle` that assumes unique ownership of the
98// underlying `PromiseType`.
99//
100// This type will `destroy()` the underlying promise in its destructor, or
101// when `Release()` is called.
102template <typename PromiseType>
104 public:
105 // Construct a null (`!IsValid()`) handle.
106 OwningCoroutineHandle(std::nullptr_t) : promise_handle_(nullptr) {}
107
109 OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
110 : promise_handle_(std::move(promise_handle)) {}
111
112 // Empty out `other` and transfer ownership of its `promise_handle`
113 // to `this`.
115 : promise_handle_(std::move(other.promise_handle_)) {
116 other.promise_handle_ = nullptr;
117 }
118
119 // Empty out `other` and transfers ownership of its `promise_handle`
120 // to `this`.
121 OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
122 Release();
123 promise_handle_ = std::move(other.promise_handle_);
124 other.promise_handle_ = nullptr;
125 return *this;
126 }
127
128 // `destroy()`s the underlying `promise_handle` if valid.
129 ~OwningCoroutineHandle() { Release(); }
130
131 // Return whether or not this value contains a `promise_handle`.
132 //
133 // This will return `false` if this `OwningCoroutineHandle` was
134 // `nullptr`-initialized, moved from, or if `Release` was invoked.
135 [[nodiscard]] bool IsValid() const {
136 return promise_handle_.address() != nullptr;
137 }
138
139 // Return a reference to the underlying `PromiseType`.
140 //
141 // Precondition: `IsValid()` must be `true`.
142 [[nodiscard]] PromiseType& promise() const {
143 return promise_handle_.promise();
144 }
145
146 // Whether or not the underlying coroutine has completed.
147 //
148 // Precondition: `IsValid()` must be `true`.
149 [[nodiscard]] bool done() const { return promise_handle_.done(); }
150
151 // Resume the underlying coroutine.
152 //
153 // Precondition: `IsValid()` must be `true`, and `done()` must be
154 // `false`.
155 void resume() { promise_handle_.resume(); }
156
157 // Invokes `destroy()` on the underlying promise and deallocates its
158 // associated storage.
159 void Release() {
160 // DOCSTAG: [pw_async2-coro-release]
161 void* address = promise_handle_.address();
162 if (address != nullptr) {
163 pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
164 promise_handle_.destroy();
165 promise_handle_ = nullptr;
166 dealloc.Deallocate(address);
167 }
168 // DOCSTAG: [pw_async2-coro-release]
169 }
170
171 private:
172 std::coroutine_handle<PromiseType> promise_handle_;
173};
174
// Forward declaration of the wrapper built around each `co_await` operand.
// `CoroPromiseType::await_transform` returns it before its definition.
template <typename PendableT, typename Promise>
class Awaitable;
178
179// A container for values passed in and out of the promise.
180//
181// The C++20 coroutine `resume()` function cannot accept arguments no return
182// values, so instead coroutine inputs and outputs are funneled through this
183// type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
184// so that the coroutine object can access it.
185template <typename T>
186struct InOut final {
187 // The `Context` passed into the coroutine via `Pend`.
188 Context* input_cx;
189
190 // The output assigned to by the coroutine if the coroutine is `done()`.
191 OptionalOrDefault<T>* output;
192};
193
194// Attempt to complete the current pendable value passed to `co_await`,
195// storing its return value inside the `Awaitable` object so that it can
196// be retrieved by the coroutine.
197//
198// Each `co_await` statement creates an `Awaitable` object whose `Pend`
199// method must be completed before the coroutine's `resume()` function can
200// be invoked.
201//
202// `sizeof(void*)` is used as the size since only one pointer capture is
203// required in all cases.
204using PendFillReturnValueFn =
205 pw::InlineFunction<Poll<>(Context&), sizeof(void*)>;
206
207// The `promise_type` of `Coro<T>`.
208//
209// To understand this type, it may be necessary to refer to the reference
210// documentation for the C++20 coroutine API.
211template <typename T>
212class CoroPromiseType final {
213 public:
214 // Construct the `CoroPromiseType` using the arguments passed to a
215 // function returning `Coro<T>`.
216 //
217 // The first argument *must* be a `CoroContext`. The other
218 // arguments are unused, but must be accepted in order for this to compile.
219 template <typename... Args>
220 CoroPromiseType(CoroContext& cx, const Args&...)
221 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
222
223 // Method-receiver version.
224 template <typename MethodReceiver, typename... Args>
225 CoroPromiseType(const MethodReceiver&, CoroContext& cx, const Args&...)
226 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
227
228 // Get the `Coro<T>` after successfully allocating the coroutine space
229 // and constructing `this`.
230 Coro<T> get_return_object();
231
232 // Do not begin executing the `Coro<T>` until `resume()` has been invoked
233 // for the first time.
234 std::suspend_always initial_suspend() { return {}; }
235
236 // Unconditionally suspend to prevent `destroy()` being invoked.
237 //
238 // The caller of `resume()` needs to first observe `done()` before the
239 // state can be destroyed.
240 //
241 // Setting this to suspend means that the caller is responsible for invoking
242 // `destroy()`.
243 std::suspend_always final_suspend() noexcept { return {}; }
244
245 // Store the `co_return` argument in the `InOut<T>` object provided by
246 // the `Pend` wrapper.
247 template <std::convertible_to<T> From>
248 void return_value(From&& value) {
249 *in_out_->output = std::forward<From>(value);
250 }
251
252 // Ignore exceptions in coroutines.
253 //
254 // Pigweed is not designed to be used with exceptions: `Result` or a
255 // similar type should be used to propagate errors.
256 void unhandled_exception() { PW_ASSERT(false); }
257
258 // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
259 static Coro<T> get_return_object_on_allocation_failure();
260
261 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
262 // state.
263 //
264 // This override does not accept alignment.
265 template <typename... Args>
266 static void* operator new(std::size_t size,
267 CoroContext& coro_cx,
268 const Args&...) noexcept {
269 return SharedNew(coro_cx, size, alignof(std::max_align_t));
270 }
271
272 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
273 // state.
274 //
275 // This override accepts alignment.
276 template <typename... Args>
277 static void* operator new(std::size_t size,
278 std::align_val_t align,
279 CoroContext& coro_cx,
280 const Args&...) noexcept {
281 return SharedNew(coro_cx, size, static_cast<size_t>(align));
282 }
283
284 // Method-receiver form.
285 //
286 // This override does not accept alignment.
287 template <typename MethodReceiver, typename... Args>
288 static void* operator new(std::size_t size,
289 const MethodReceiver&,
290 CoroContext& coro_cx,
291 const Args&...) noexcept {
292 return SharedNew(coro_cx, size, alignof(std::max_align_t));
293 }
294
295 // Method-receiver form.
296 //
297 // This accepts alignment.
298 template <typename MethodReceiver, typename... Args>
299 static void* operator new(std::size_t size,
300 std::align_val_t align,
301 const MethodReceiver&,
302 CoroContext& coro_cx,
303 const Args&...) noexcept {
304 return SharedNew(coro_cx, size, static_cast<size_t>(align));
305 }
306
307 static void* SharedNew(CoroContext& coro_cx,
308 std::size_t size,
309 std::size_t align) noexcept {
310 auto ptr = coro_cx.alloc().Allocate(pw::allocator::Layout(size, align));
311 if (ptr == nullptr) {
312 internal::LogCoroAllocationFailure(size);
313 }
314 return ptr;
315 }
316
317 // Deallocate the space for both this `CoroPromiseType<T>` and the
318 // coroutine state.
319 //
320 // In reality, we do nothing here!!!
321 //
322 // Coroutines do not support `destroying_delete`, so we can't access
323 // `dealloc_` here, and therefore have no way to deallocate.
324 // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
325 static void operator delete(void*) {}
326
327 // Handle a `co_await` call by accepting a type with a
328 // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
329 // yield a `U` once complete.
330 template <typename Pendable>
331 requires(!std::is_reference_v<Pendable>)
332 Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
333 return pendable;
334 }
335
336 template <typename Pendable>
337 Awaitable<Pendable*, CoroPromiseType> await_transform(Pendable& pendable) {
338 return &pendable;
339 }
340
341 // Returns a reference to the `Context` passed in.
342 Context& cx() { return *in_out_->input_cx; }
343
344 pw::allocator::Deallocator& dealloc_;
345 PendFillReturnValueFn currently_pending_;
346 InOut<T>* in_out_;
347};
348
349// The object created by invoking `co_await` in a `Coro<T>` function.
350//
351// This wraps a `Pendable` type and implements the awaitable interface
352// expected by the standard coroutine API.
353template <typename Pendable, typename PromiseType>
354class Awaitable final {
355 public:
356 // The `value_type` in `Poll<value_type> Pendable::Pend(Context&)`.
357 using value_type = std::remove_cvref_t<
358 decltype(std::declval<std::remove_pointer_t<Pendable>>()
359 .Pend(std::declval<Context&>())
360 .value())>;
361
362 Awaitable(Pendable&& pendable) : state_(std::forward<Pendable>(pendable)) {}
363
// Always reports "not ready", so that `await_suspend` runs. That is where
// the operand is first polled.
bool await_ready() {
  return false;
}
366
367 // Returns whether or not the current coroutine should be suspended.
368 //
369 // This is invoked once as part of every `co_await` call after
370 // `await_ready` returns `false`.
371 //
372 // In the process, this method attempts to complete the inner `Pendable`
373 // before suspending this coroutine.
374 bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
375 Context& cx = promise.promise().cx();
376 if (PendFillReturnValue(cx).IsPending()) {
378 promise.promise().currently_pending_ = [this](Context& lambda_cx) {
379 return PendFillReturnValue(lambda_cx);
380 };
381 return true;
382 }
383 return false;
384 }
385
386 // Returns `return_value`.
387 //
388 // This is automatically invoked by the language runtime when the promise's
389 // `resume()` method is called.
390 value_type&& await_resume() {
391 return std::move(std::get<value_type>(state_));
392 }
393
394 auto& PendableNoPtr() {
395 if constexpr (std::is_pointer_v<Pendable>) {
396 return *std::get<Pendable>(state_);
397 } else {
398 return std::get<Pendable>(state_);
399 }
400 }
401
402 // Attempts to complete the `Pendable` value, storing its return value
403 // upon completion.
404 //
405 // This method must return `Ready()` before the coroutine can be safely
406 // resumed, as otherwise the return value will not be available when
407 // `await_resume` is called to produce the result of `co_await`.
408 Poll<> PendFillReturnValue(Context& cx) {
409 Poll<value_type> poll_res(PendableNoPtr().Pend(cx));
410 if (poll_res.IsPending()) {
411 return Pending();
412 }
413 state_ = std::move(*poll_res);
414 return Ready();
415 }
416
417 private:
418 std::variant<Pendable, value_type> state_;
419};
420
421} // namespace internal
422
424
468template <std::constructible_from<pw::Status> T>
469class Coro final {
470 public:
472 static Coro Empty() {
474 }
475
480 [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
481
487 if (!IsValid()) {
488 // This coroutine failed to allocate its internal state.
489 // (Or `Pend` is being erroneously invoked after previously completing.)
490 return Ready(Status::Internal());
491 }
492
493 // If an `Awaitable` value is currently being processed, it must be
494 // allowed to complete and store its return value before we can resume
495 // the coroutine.
496 if (promise_handle_.promise().currently_pending_ != nullptr &&
497 promise_handle_.promise().currently_pending_(cx).IsPending()) {
498 return Pending();
499 }
500 // DOCSTAG: [pw_async2-coro-resume]
501 // Create the arguments (and output storage) for the coroutine.
502 internal::InOut<T> in_out;
503 internal::OptionalOrDefault<T> return_value;
504 in_out.input_cx = &cx;
505 in_out.output = &return_value;
506 promise_handle_.promise().in_out_ = &in_out;
507
508 // Resume the coroutine, triggering `Awaitable::await_resume()` and the
509 // returning of the resulting value from `co_await`.
510 promise_handle_.resume();
511 if (!promise_handle_.done()) {
512 return Pending();
513 }
514
515 // Destroy the coroutine state: it has completed, and further calls to
516 // `resume` would result in undefined behavior.
517 promise_handle_.Release();
518
519 // When the coroutine completed in `resume()` above, it stored its
520 // `co_return` value into `return_value`. This retrieves that value.
521 return return_value;
522 // DOCSTAG: [pw_async2-coro-resume]
523 }
524
528
529 private:
530 // Allow `CoroPromiseType<T>::get_return_object()` and
531 // `CoroPromiseType<T>::get_return_object_on_allocation_failure()` to
532 // use the private constructor below.
533 friend promise_type;
534
537 : promise_handle_(std::move(promise_handle)) {}
538
540};
541
543
544// Implement the remaining internal pieces that require a definition of
545// `Coro<T>`.
546namespace internal {
547
548template <typename T>
549Coro<T> CoroPromiseType<T>::get_return_object() {
550 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
551 std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
552}
553
554template <typename T>
555Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
556 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
557}
558
559} // namespace internal
560} // namespace pw::async2
Definition: allocator.h:36
static constexpr Status Internal()
Internal error occurred; e.g. system invariants were violated.
Definition: status.h:182
Definition: layout.h:58
Definition: context.h:55
Context required for creating and executing coroutines.
Definition: coro.h:37
CoroContext(pw::allocator::Allocator &alloc)
Definition: coro.h:41
Definition: coro.h:469
static Coro Empty()
Creates an empty, invalid coroutine object.
Definition: coro.h:472
::pw::async2::internal::CoroPromiseType< T > promise_type
Definition: coro.h:527
bool IsValid() const
Definition: coro.h:480
Poll< T > Pend(Context &cx)
Definition: coro.h:486
Definition: poll.h:60
Definition: coro.h:354
bool await_suspend(const std::coroutine_handle< PromiseType > &promise)
Definition: coro.h:374
OwningCoroutineHandle(std::coroutine_handle< PromiseType > &&promise_handle)
Take ownership of promise_handle.
Definition: coro.h:109
constexpr PendingType Pending()
Returns a value indicating that an operation was not yet able to complete.
Definition: poll.h:271
constexpr Poll Ready()
Returns a value indicating completion.
Definition: poll.h:255
fit::inline_function< FunctionType, inline_target_size > InlineFunction
Definition: function.h:90
Definition: coro.h:186