C/C++ API Reference
Loading...
Searching...
No Matches
coro.h
1// Copyright 2024 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14#pragma once
15
16#include <concepts>
17#include <coroutine>
18#include <cstddef>
19#include <type_traits>
20#include <utility>
21
22#include "pw_allocator/allocator.h"
23#include "pw_allocator/layout.h"
24#include "pw_assert/assert.h"
25#include "pw_async2/context.h"
26#include "pw_async2/future.h"
27#include "pw_containers/internal/optional.h"
28
29namespace pw::async2 {
30namespace internal {
31
32[[noreturn]] void CrashDueToCoroutineAllocationFailure();
33
34} // namespace internal
35
36// Forward-declare `Coro` so that it can be referenced by the promise type APIs.
37template <typename T>
38class Coro;
39
40enum class ReturnValuePolicy : bool;
41
43
46 public:
52 constexpr CoroContext(Allocator& allocator) : allocator_(&allocator) {}
53
54 constexpr CoroContext(const CoroContext&) = default;
55 constexpr CoroContext& operator=(const CoroContext&) = default;
56
57 constexpr Allocator& allocator() const { return *allocator_; }
58
59 private:
60 Allocator* allocator_;
61};
62
64
65// The internal coroutine API implementation details enabling `Coro<T>`.
66//
67// Users of `Coro<T>` need not concern themselves with these details, unless
68// they think it sounds like fun ;)
69namespace internal {
70
77template <typename PromiseType>
79 public:
80 // Construct a null (`!IsValid()`) handle.
81 constexpr OwningCoroutineHandle(std::nullptr_t) : promise_handle_(nullptr) {}
82
84 OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
85 : promise_handle_(std::move(promise_handle)) {}
86
87 // Empties out `other` and transfers ownership of its `promise_handle`
88 // to `this`.
90 : promise_handle_(std::move(other.promise_handle_)) {
91 other.promise_handle_ = nullptr;
92 }
93
94 // Empty out `other` and transfers ownership of its `promise_handle`
95 // to `this`.
96 OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
97 Release();
98 promise_handle_ = std::move(other.promise_handle_);
99 other.promise_handle_ = nullptr;
100 return *this;
101 }
102
103 // `destroy()`s the underlying `promise_handle` if valid.
104 ~OwningCoroutineHandle() { Release(); }
105
106 // Return whether or not this value contains a `promise_handle`.
107 //
108 // This will return `false` if this `OwningCoroutineHandle` was
109 // `nullptr`-initialized, moved from, or if `Release` was invoked.
110 [[nodiscard]] bool IsValid() const {
111 return promise_handle_.address() != nullptr;
112 }
113
114 // Return a reference to the underlying `PromiseType`.
115 //
116 // Precondition: `IsValid()` must be `true`.
117 [[nodiscard]] PromiseType& promise() const {
118 return promise_handle_.promise();
119 }
120
121 // Whether or not the underlying coroutine has completed.
122 //
123 // Precondition: `IsValid()` must be `true`.
124 [[nodiscard]] bool done() const { return promise_handle_.done(); }
125
126 // Resume the underlying coroutine.
127 //
128 // Precondition: `IsValid()` must be `true`, and `done()` must be
129 // `false`.
130 void resume() { promise_handle_.resume(); }
131
132 // Invokes `destroy()` on the underlying promise and deallocates its
133 // associated storage.
134 void Release() {
135 // DOCSTAG: [pw_async2-coro-release]
136 void* address = promise_handle_.address();
137 if (address != nullptr) {
138 Deallocator& dealloc = promise_handle_.promise().deallocator();
139 promise_handle_.destroy();
140 promise_handle_ = nullptr;
141 dealloc.Deallocate(address);
142 }
143 // DOCSTAG: [pw_async2-coro-release]
144 }
145
146 private:
147 std::coroutine_handle<PromiseType> promise_handle_;
148};
149
150// Forward-declare the wrapper type for values passed to `co_await`.
151template <typename CoroOrFuture, typename PromiseType>
152class Awaitable;
153
154template <typename T>
155struct is_coro : std::false_type {};
156
157template <typename T>
158struct is_coro<Coro<T>> : std::true_type {};
159
160template <typename T>
161concept IsCoro = is_coro<T>::value;
162
163enum class CoroPollState : uint8_t {
164 kPending,
165 kAborted,
166 kReady,
167};
168
169template <typename T>
171
177 public:
178 // Do not begin executing the `Coro<T>` until `resume()` has been invoked
179 // for the first time.
180 std::suspend_always initial_suspend() { return {}; }
181
182 // Unconditionally suspend to prevent `destroy()` being invoked.
183 //
184 // The caller of `resume()` needs to first observe `done()` before the
185 // state can be destroyed.
186 //
187 // Setting this to suspend means that the caller is responsible for invoking
188 // `destroy()`.
189 std::suspend_always final_suspend() noexcept { return {}; }
190
191 // Trip `PW_ASSERT(false)` if an exception escapes a coroutine body.
192 //
193 // Pigweed is not designed to be used with exceptions: `Result` or a
194 // similar type should be used to propagate errors.
195 void unhandled_exception() { PW_ASSERT(false); }
196
197 // Allocate the space for both this `CoroPromise<T>` and the coroutine state.
198 //
199 // This override does not accept alignment.
200 template <typename... Args>
201 static void* operator new(std::size_t size,
202 CoroContext coro_cx,
203 const Args&...) noexcept {
204 return SharedNew(coro_cx, size, alignof(std::max_align_t));
205 }
206
207 // Allocate the space for both this `CoroPromise<T>` and the coroutine state.
208 //
209 // This override accepts alignment.
210 template <typename... Args>
211 static void* operator new(std::size_t size,
212 std::align_val_t align,
213 CoroContext coro_cx,
214 const Args&...) noexcept {
215 return SharedNew(coro_cx, size, static_cast<size_t>(align));
216 }
217
218 // Method-receiver form.
219 //
220 // This override does not accept alignment.
221 template <typename MethodReceiver, typename... Args>
222 static void* operator new(std::size_t size,
223 const MethodReceiver&,
224 CoroContext coro_cx,
225 const Args&...) noexcept {
226 return SharedNew(coro_cx, size, alignof(std::max_align_t));
227 }
228
229 // Method-receiver form.
230 //
231 // This accepts alignment.
232 template <typename MethodReceiver, typename... Args>
233 static void* operator new(std::size_t size,
234 std::align_val_t align,
235 const MethodReceiver&,
236 CoroContext coro_cx,
237 const Args&...) noexcept {
238 return SharedNew(coro_cx, size, static_cast<size_t>(align));
239 }
240
241 // Deallocate the space for both this `CoroPromise<T>` and the coroutine
242 // state.
243 //
244 // In reality, we do nothing here!!!
245 //
246 // Coroutines do not support `destroying_delete`, so we can't access
247 // `dealloc_` here, and therefore have no way to deallocate.
248 // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
249 static void operator delete(void*) {}
250
251 CoroPollState AdvanceAwaitable(Context& cx) {
252 if (pending_awaitable_ == nullptr) {
253 return CoroPollState::kReady;
254 }
255 return pending_awaitable_func_(pending_awaitable_, cx);
256 }
257
258 Deallocator& deallocator() const { return dealloc_; }
259
260 template <typename AwaitableType>
261 void SuspendAwaitable(AwaitableType& awaitable) {
262 pending_awaitable_ = &awaitable;
263 pending_awaitable_func_ = [](void* obj, Context& lambda_cx) {
264 return static_cast<AwaitableType*>(obj)->Advance(lambda_cx);
265 };
266 }
267
268 protected:
270 : dealloc_(cx.allocator()), pending_awaitable_(nullptr) {}
271
272 private:
273 static void* SharedNew(CoroContext coro_cx,
274 std::size_t size,
275 std::size_t align) noexcept;
276
277 Deallocator& dealloc_;
278
279 // Attempt to complete the current awaitable value passed to `co_await`,
280 // storing its return value inside the `Awaitable` object so that it can
281 // be retrieved by the coroutine.
282 //
283 // Each `co_await` statement creates an `Awaitable` object whose `Advance`
284 // method must report completion before the coroutine's `resume()` function
285 // can be invoked.
286 void* pending_awaitable_;
287 CoroPollState (*pending_awaitable_func_)(void*, Context&);
288};
289
290template <typename T, typename Derived>
292 public:
293 using value_type = T;
294
295 // Get the `Coro<T>` after successfully allocating the coroutine space
296 // and constructing `this`.
297 Coro<T> get_return_object();
298
299 // Create an invalid (nullptr) `Coro<T>` if `operator new` fails.
300 static Coro<T> get_return_object_on_allocation_failure();
301
302 // Indicate that allocation failed for a nested coroutine.
303 void MarkNestedCoroutineAllocationFailure() {
304 output_->reset(CoroPollState::kAborted);
305 }
306
307 // Returns a reference to the `Context` passed in.
308 Context& cx() { return *context_; }
309
310 // Sets the `Context` to use when advancing the awaitable and the pointer
311 // where to store return values.
312 void SetContextAndOutput(Context& cx, internal::CoroPoll<T>& output) {
313 context_ = &cx;
314 this->output_ = &output;
315 }
316
317 // Coroutine API functions
318
319 // Handle a `co_await` call by accepting either a future or a coroutine and
320 // returning an `Awaitable` wrapper that yields a `value_type` once complete.
321 template <typename CoroOrFuture>
322 requires(!std::is_reference_v<CoroOrFuture>)
323 Awaitable<CoroOrFuture, Derived> await_transform(
324 CoroOrFuture&& coro_or_future) {
325 return std::forward<CoroOrFuture>(coro_or_future);
326 }
327
328 template <typename CoroOrFuture>
329 Awaitable<CoroOrFuture*, Derived> await_transform(
330 CoroOrFuture& coro_or_future) {
331 return &coro_or_future;
332 }
333
334 protected:
335 using CoroPromiseBase::CoroPromiseBase;
336
337 internal::CoroPoll<T>& output() const { return *output_; }
338
339 private:
340 Context* context_;
341 internal::CoroPoll<T>* output_;
342};
343
344template <typename T>
345class CoroPromise final : public TypedCoroPromise<T, CoroPromise<T>> {
346 public:
347 // Construct the `CoroPromise` using the arguments passed to a function
348 // returning `Coro<T>`.
349 //
350 // The first argument *must* be a `CoroContext`. The other arguments are
351 // unused, but must be accepted in order for this to compile.
352 template <typename... Args>
353 CoroPromise(CoroContext cx, const Args&...)
355
356 // Method-receiver version.
357 template <typename MethodReceiver, typename... Args>
358 CoroPromise(const MethodReceiver&, CoroContext cx, const Args&...)
360
361 // Store the `co_return` arg in the memory provided by the `Pend` wrapper.
362 template <std::convertible_to<T> From>
363 void return_value(From&& value) {
364 this->output() = std::forward<From>(value);
365 }
366};
367
368// Handle void-returning coroutines.
369//
370// C++ does not allow a promise type to declare both return_value() and
371// return_void(), so use a specialization.
372template <>
373class CoroPromise<void> final
374 : public TypedCoroPromise<void, CoroPromise<void>> {
375 public:
376 template <typename... Args>
377 CoroPromise(CoroContext cx, const Args&...)
379
380 template <typename MethodReceiver, typename... Args>
381 CoroPromise(const MethodReceiver&, CoroContext cx, const Args&...)
383
384 // Mark the output as ready.
385 void return_void() { this->output().emplace(); }
386};
387
388// The object created by invoking `co_await` in a `Coro<T>` function.
389//
390// This wraps a `Coro` or future and implements the awaitable interface expected
391// by the standard coroutine API.
392template <typename CoroOrFuture, typename PromiseType>
393class Awaitable final {
394 public:
395 // The concrete type this Awaitable wraps.
396 using await_type = std::remove_pointer_t<CoroOrFuture>;
397 // The type produced by this awaitable.
398 using value_type = typename await_type::value_type;
399
400 Awaitable(CoroOrFuture&& coro_or_future)
401 : state_(std::move(coro_or_future)), is_ready_(false) {}
402
403 ~Awaitable() {
404 if (is_ready_) {
405 state_.result.~Value();
406 } else {
407 state_.coro_or_future.~CoroOrFuture();
408 }
409 }
410
411 // Confirms that `await_suspend` must be invoked.
412 bool await_ready() { return false; }
413
414 // Returns whether or not the current coroutine should be suspended.
415 //
416 // This is invoked once as part of every `co_await` call after
417 // `await_ready` returns `false`.
418 //
419 // In the process, this method attempts to complete the inner `await_type`
420 // before suspending this coroutine.
421 bool await_suspend(const std::coroutine_handle<PromiseType>& promise)
422 requires Future<await_type>
423 {
424 Context& cx = promise.promise().cx();
425 if (Advance(cx) == CoroPollState::kPending) {
426 promise.promise().SuspendAwaitable(*this);
427 return true;
428 }
429 return false;
430 }
431
432 bool await_suspend(const std::coroutine_handle<PromiseType>& promise)
433 requires IsCoro<await_type>
434 {
435 Context& cx = promise.promise().cx();
436 CoroPollState state = Advance(cx);
437 if (state == CoroPollState::kPending) {
438 promise.promise().SuspendAwaitable(*this);
439 return true;
440 }
441 if (state == CoroPollState::kAborted) {
442 promise.promise().MarkNestedCoroutineAllocationFailure();
443 return true;
444 }
445 return false;
446 }
447
448 // Returns `return_value`.
449 //
450 // This is automatically invoked by the language runtime when the promise's
451 // `resume()` method is called.
452 value_type&& await_resume()
453 requires(!std::is_void_v<value_type>)
454 {
455 // await_resume() is never called after allocation failure because the
456 // coroutine is destroyed instead of being resumed again.
457 return std::move(state_.result);
458 }
459
460 void await_resume()
461 requires(std::is_void_v<value_type>)
462 {}
463
464 // Attempts to complete the `CoroOrFuture` value, storing its return value
465 // upon completion.
466 //
467 // This method must return `kReady` before the coroutine can be safely
468 // resumed, as otherwise the return value will not be available when
469 // `await_resume` is called to produce the result of `co_await`.
470 CoroPollState Advance(Context& cx)
471 requires Future<await_type>
472 {
473 Poll<value_type> poll_res(get().Pend(cx));
474 if (poll_res.IsPending()) {
475 return CoroPollState::kPending;
476 }
477 state_.coro_or_future.~CoroOrFuture();
478 new (&state_.result) Value(std::move(*poll_res));
479 is_ready_ = true;
480 return CoroPollState::kReady;
481 }
482
483 CoroPollState Advance(Context& cx)
484 requires IsCoro<await_type>
485 {
486 if (!get().ok()) {
487 return CoroPollState::kAborted;
488 }
489 auto result = get().Pend(cx);
490 if (result.state() == CoroPollState::kReady) {
491 state_.coro_or_future.~CoroOrFuture();
492 new (&state_.result) Value(std::move(*result));
493 is_ready_ = true;
494 }
495 return result.state();
496 }
497
498 private:
499 using Value = // Use ReadyType instead of void
500 std::conditional_t<std::is_void_v<value_type>, ReadyType, value_type>;
501
502 await_type& get() {
503 if constexpr (std::is_pointer_v<CoroOrFuture>) {
504 return *state_.coro_or_future;
505 } else {
506 return state_.coro_or_future;
507 }
508 }
509
510 union State {
511 State(CoroOrFuture&& c_or_f) : coro_or_future(std::move(c_or_f)) {}
512 ~State() {}
513
514 CoroOrFuture coro_or_future;
515 Value result;
516 } state_;
517 bool is_ready_; // Tracks whether state_ contains a coro/future or its value
518};
519
520} // namespace internal
521
523
555template <typename T>
556class Coro final {
557 public:
560
562 using value_type = T;
563
565 static Coro Empty() {
567 }
568
573 [[nodiscard]] bool ok() const { return promise_handle_.IsValid(); }
574
575 private:
576 // Allow get_return_object() and get_return_object_on_allocation_failure() to
577 // use the private constructor below.
579
580 // Only allow Awaitable and Coro task wrappers to call Pend.
581 template <typename, typename>
582 friend class internal::Awaitable;
583 template <typename, ReturnValuePolicy>
584 friend class CoroTask;
585 template <typename, typename E, ReturnValuePolicy>
586 requires std::invocable<E>
587 friend class FallibleCoroTask;
588
594 using enum internal::CoroPollState;
595
596 if (!ok()) {
597 internal::CrashDueToCoroutineAllocationFailure();
598 }
599
600 // DOCSTAG: [pw_async2-coro-resume]
601 internal::CoroPoll<T> return_value(kPending);
602
603 // Stash the Context& argument for the coroutine.
604 // Reserve space for the return value and point the promise to it.
605 promise_handle_.promise().SetContextAndOutput(cx, return_value);
606
607 // If an `Awaitable` value is currently being processed, it must be
608 // allowed to complete and store its return value before we can resume
609 // the coroutine.
610 switch (promise_handle_.promise().AdvanceAwaitable(cx)) {
611 case kPending:
612 break;
613 case kAborted:
614 return_value.reset(kAborted);
615 promise_handle_.Release();
616 break;
617 case kReady:
618 // Resume the coroutine, triggering `Awaitable::await_resume()` and the
619 // returning of the resulting value from `co_await`. The promise's
620 // `return_value()` function stores the result in the `return_value`
621 // variable at this point.
622 promise_handle_.resume();
623
624 // `return_value` now reflects the results of the operation. Unless it's
625 // still pending, free the coroutine's memory.
626 if (return_value.state() != kPending) {
627 // Destroy the coroutine state: it has completed or aborted, and
628 // further calls to `resume` would result in undefined behavior.
629 promise_handle_.Release();
630 }
631 break;
632 }
633
634 return return_value;
635 // DOCSTAG: [pw_async2-coro-resume]
636 }
637
639 explicit Coro(internal::OwningCoroutineHandle<promise_type>&& promise_handle)
640 : promise_handle_(std::move(promise_handle)) {}
641
642 internal::OwningCoroutineHandle<promise_type> promise_handle_;
643};
644
645template <typename Promise>
646Coro(internal::OwningCoroutineHandle<Promise>&&)
647 -> Coro<typename Promise::value_type>;
648
650
651// Implement the remaining internal pieces that require a definition of
652// `Coro<T>`.
653namespace internal {
654
655template <typename T, typename Derived>
656Coro<T> TypedCoroPromise<T, Derived>::get_return_object() {
657 return Coro<T>(internal::OwningCoroutineHandle<Derived>(
658 std::coroutine_handle<CoroPromise<T>>::from_promise(
659 static_cast<Derived&>(*this))));
660}
661
662template <typename T, typename Derived>
663Coro<T>
664TypedCoroPromise<T, Derived>::get_return_object_on_allocation_failure() {
665 return Coro<T>(internal::OwningCoroutineHandle<Derived>(nullptr));
666}
667
668// Checks that the first argument is a by-value CoroContext.
669template <typename... Args>
670struct CoroContextIsPassedByValue : std::false_type {};
671
672// If there is only a single argument, it must be a CoroContext value.
673template <typename First>
674struct CoroContextIsPassedByValue<First> : std::is_same<First, CoroContext> {};
675
676// If there are multiple arguments, the first argument to a free function must
677// be CoroContext. For member functions, the first argument is a reference to
678// the object and the second must be CoroContext. Note that this trait cannot
679// distinguish between a member function and a free function with a reference to
680// an object as its first argument.
681template <typename First, typename Second, typename... Others>
682struct CoroContextIsPassedByValue<First, Second, Others...>
683 : std::disjunction<
684 std::is_same<First, CoroContext>,
685 std::conjunction<std::is_same<Second, CoroContext>,
686 std::is_reference<First>,
687 std::is_class<std::remove_reference_t<First>>>> {};
688
689} // namespace internal
690} // namespace pw::async2
691
692// Specialize `std::coroutine_traits` to enforce `CoroContext` semantics.
693namespace std {
694
695template <typename T, typename... Args>
696struct coroutine_traits<pw::async2::Coro<T>, Args...> {
697 using promise_type = typename pw::async2::Coro<T>::promise_type;
698
699 static_assert(
701 "CoroContext must be passed by value as the first argument to a "
702 "pw_async2 coroutine");
703
704 static_assert(
705 (static_cast<int>(
706 std::is_same_v<std::remove_cvref_t<Args>, pw::async2::CoroContext>) +
707 ...) == 1,
708 "pw_async2 coroutines must have exactly one CoroContext argument");
709};
710
711} // namespace std
Definition: allocator.h:45
Abstract interface for releasing memory.
Definition: deallocator.h:29
Definition: context.h:46
Context required for creating and executing coroutines.
Definition: coro.h:45
constexpr CoroContext(Allocator &allocator)
Definition: coro.h:52
Definition: coro.h:556
static Coro Empty()
Creates an empty, invalid coroutine object.
Definition: coro.h:565
T value_type
The type this coroutine returns from a co_return expression.
Definition: coro.h:562
bool ok() const
Definition: coro.h:573
Definition: coro_task.h:33
Definition: fallible_coro_task.h:34
Definition: poll.h:138
Definition: coro.h:393
Definition: coro.h:345
OwningCoroutineHandle(std::coroutine_handle< PromiseType > &&promise_handle)
Take ownership of promise_handle.
Definition: coro.h:84
Definition: optional.h:65
Definition: future.h:47
Definition: coro.h:161
constexpr bool IsPending() const noexcept
Returns whether or not this value is Pending.
Definition: poll.h:214
ReturnValuePolicy
Whether to store or discard the function's return value in RunOnceTask.
Definition: func_task.h:62
The Pigweed namespace.
Definition: alignment.h:27
Definition: poll.h:40
Definition: coro.h:155