Pigweed
 
Loading...
Searching...
No Matches
coro.h
1// Copyright 2024 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14#pragma once
15
16#include <concepts>
17#include <coroutine>
18#include <variant>
19
20#include "pw_allocator/allocator.h"
21#include "pw_allocator/layout.h"
22#include "pw_async2/dispatcher.h"
23#include "pw_function/function.h"
24#include "pw_log/log.h"
25#include "pw_status/status.h"
26#include "pw_status/try.h"
27
28namespace pw::async2 {
29
30// Forward-declare `Coro` so that it can be referenced by the promise type APIs.
31template <std::constructible_from<pw::Status> T>
32class Coro;
33
36 public:
39 explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
40 pw::allocator::Allocator& alloc() const { return alloc_; }
41
42 private:
44};
45
46// The internal coroutine API implementation details enabling `Coro<T>`.
47//
48// Users of `Coro<T>` need not concern themselves with these details, unless
49// they think it sounds like fun ;)
50namespace internal {
51
52void LogCoroAllocationFailure(size_t requested_size);
53
// Storage for a `T` that is supplied some time after construction.
//
// The wrapper starts out empty; a value is installed through assignment and
// read back through the conversion operator.
template <typename T>
class OptionalWrapper final {
 public:
  // Starts out holding no value.
  OptionalWrapper() : value_(std::nullopt) {}

  // Stores `value`, replacing whatever was held before.
  template <typename U>
  OptionalWrapper& operator=(U&& value) {
    value_ = std::forward<U>(value);
    return *this;
  }

  // Converts to a copy of the stored value.
  //
  // A `PW_ASSERT` fires if nothing has been assigned yet.
  operator T() {
    PW_ASSERT(value_.has_value());
    T& stored = *value_;
    return stored;
  }

 private:
  std::optional<T> value_;
};
78
79// A container for a to-be-produced value of type `T`.
80//
81// This is designed to allow avoiding the overhead of `std::optional` when
82// `T` is default-initializable.
83//
84// Values of this type begin as either:
85// - a default-initialized `T` if `T` is default-initializable or
86// - `std::nullopt`
87template <typename T>
88using OptionalOrDefault =
89 std::conditional<std::is_default_constructible<T>::value,
90 T,
91 OptionalWrapper<T>>::type;
92
93// A wrapper for `std::coroutine_handle` that assumes unique ownership of the
94// underlying `PromiseType`.
95//
96// This type will `destroy()` the underlying promise in its destructor, or
97// when `Release()` is called.
98template <typename PromiseType>
100 public:
101 // Construct a null (`!IsValid()`) handle.
102 OwningCoroutineHandle(std::nullptr_t) : promise_handle_(nullptr) {}
103
105 OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
106 : promise_handle_(std::move(promise_handle)) {}
107
108 // Empties out `other` and transfers ownership of its `promise_handle`
109 // to `this`.
111 : promise_handle_(std::move(other.promise_handle_)) {
112 other.promise_handle_ = nullptr;
113 }
114
115 // Empty out `other` and transfers ownership of its `promise_handle`
116 // to `this`.
117 OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
118 Release();
119 promise_handle_ = std::move(other.promise_handle_);
120 other.promise_handle_ = nullptr;
121 return *this;
122 }
123
124 // `destroy()`s the underlying `promise_handle` if valid.
125 ~OwningCoroutineHandle() { Release(); }
126
127 // Return whether or not this value contains a `promise_handle`.
128 //
129 // This will return `false` if this `OwningCoroutineHandle` was
130 // `nullptr`-initialized, moved from, or if `Release` was invoked.
131 [[nodiscard]] bool IsValid() const {
132 return promise_handle_.address() != nullptr;
133 }
134
135 // Return a reference to the underlying `PromiseType`.
136 //
137 // Precondition: `IsValid()` must be `true`.
138 [[nodiscard]] PromiseType& promise() const {
139 return promise_handle_.promise();
140 }
141
142 // Whether or not the underlying coroutine has completed.
143 //
144 // Precondition: `IsValid()` must be `true`.
145 [[nodiscard]] bool done() const { return promise_handle_.done(); }
146
147 // Resume the underlying coroutine.
148 //
149 // Precondition: `IsValid()` must be `true`, and `done()` must be
150 // `false`.
151 void resume() { promise_handle_.resume(); }
152
153 // Invokes `destroy()` on the underlying promise and deallocates its
154 // associated storage.
155 void Release() {
156 // DOCSTAG: [pw_async2-coro-release]
157 void* address = promise_handle_.address();
158 if (address != nullptr) {
159 pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
160 promise_handle_.destroy();
161 promise_handle_ = nullptr;
162 dealloc.Deallocate(address);
163 }
164 // DOCSTAG: [pw_async2-coro-release]
165 }
166
167 private:
168 std::coroutine_handle<PromiseType> promise_handle_;
169};
170
171// Forward-declare the wrapper type for values passed to `co_await`.
172template <typename Pendable, typename PromiseType>
173class Awaitable;
174
175// A container for values passed in and out of the promise.
176//
177// The C++20 coroutine `resume()` function cannot accept arguments no return
178// values, so instead coroutine inputs and outputs are funneled through this
179// type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
180// so that the coroutine object can access it.
181template <typename T>
182struct InOut final {
183 // The `Context` passed into the coroutine via `Pend`.
184 Context* input_cx;
185
186 // The output assigned to by the coroutine if the coroutine is `done()`.
187 OptionalOrDefault<T>* output;
188};
189
190// Attempt to complete the current pendable value passed to `co_await`,
191// storing its return value inside the `Awaitable` object so that it can
192// be retrieved by the coroutine.
193//
194// Each `co_await` statement creates an `Awaitable` object whose `Pend`
195// method must be completed before the coroutine's `resume()` function can
196// be invoked.
197//
198// `sizeof(void*)` is used as the size since only one pointer capture is
199// required in all cases.
200using PendFillReturnValueFn =
201 pw::InlineFunction<Poll<>(Context&), sizeof(void*)>;
202
203// The `promise_type` of `Coro<T>`.
204//
205// To understand this type, it may be necessary to refer to the reference
206// documentation for the C++20 coroutine API.
207template <typename T>
208class CoroPromiseType final {
209 public:
210 // Construct the `CoroPromiseType` using the arguments passed to a
211 // function returning `Coro<T>`.
212 //
213 // The first argument *must* be a `CoroContext`. The other
214 // arguments are unused, but must be accepted in order for this to compile.
215 template <typename... Args>
216 CoroPromiseType(CoroContext& cx, const Args&...)
217 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
218
219 // Method-receiver version.
220 template <typename MethodReceiver, typename... Args>
221 CoroPromiseType(const MethodReceiver&, CoroContext& cx, const Args&...)
222 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
223
224 // Get the `Coro<T>` after successfully allocating the coroutine space
225 // and constructing `this`.
226 Coro<T> get_return_object();
227
228 // Do not begin executing the `Coro<T>` until `resume()` has been invoked
229 // for the first time.
230 std::suspend_always initial_suspend() { return {}; }
231
232 // Unconditionally suspend to prevent `destroy()` being invoked.
233 //
234 // The caller of `resume()` needs to first observe `done()` before the
235 // state can be destroyed.
236 //
237 // Setting this to suspend means that the caller is responsible for invoking
238 // `destroy()`.
239 std::suspend_always final_suspend() noexcept { return {}; }
240
241 // Store the `co_return` argument in the `InOut<T>` object provided by
242 // the `Pend` wrapper.
243 template <std::convertible_to<T> From>
244 void return_value(From&& value) {
245 *in_out_->output = std::forward<From>(value);
246 }
247
248 // Ignore exceptions in coroutines.
249 //
250 // Pigweed is not designed to be used with exceptions: `Result` or a
251 // similar type should be used to propagate errors.
252 void unhandled_exception() { PW_ASSERT(false); }
253
254 // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
255 static Coro<T> get_return_object_on_allocation_failure();
256
257 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
258 // state.
259 //
260 // This override does not accept alignment.
261 template <typename... Args>
262 static void* operator new(std::size_t size,
263 CoroContext& coro_cx,
264 const Args&...) noexcept {
265 return SharedNew(coro_cx, size, alignof(std::max_align_t));
266 }
267
268 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
269 // state.
270 //
271 // This override accepts alignment.
272 template <typename... Args>
273 static void* operator new(std::size_t size,
274 std::align_val_t align,
275 CoroContext& coro_cx,
276 const Args&...) noexcept {
277 return SharedNew(coro_cx, size, static_cast<size_t>(align));
278 }
279
280 // Method-receiver form.
281 //
282 // This override does not accept alignment.
283 template <typename MethodReceiver, typename... Args>
284 static void* operator new(std::size_t size,
285 const MethodReceiver&,
286 CoroContext& coro_cx,
287 const Args&...) noexcept {
288 return SharedNew(coro_cx, size, alignof(std::max_align_t));
289 }
290
291 // Method-receiver form.
292 //
293 // This accepts alignment.
294 template <typename MethodReceiver, typename... Args>
295 static void* operator new(std::size_t size,
296 std::align_val_t align,
297 const MethodReceiver&,
298 CoroContext& coro_cx,
299 const Args&...) noexcept {
300 return SharedNew(coro_cx, size, static_cast<size_t>(align));
301 }
302
303 static void* SharedNew(CoroContext& coro_cx,
304 std::size_t size,
305 std::size_t align) noexcept {
306 auto ptr = coro_cx.alloc().Allocate(pw::allocator::Layout(size, align));
307 if (ptr == nullptr) {
308 internal::LogCoroAllocationFailure(size);
309 }
310 return ptr;
311 }
312
313 // Deallocate the space for both this `CoroPromiseType<T>` and the
314 // coroutine state.
315 //
316 // In reality, we do nothing here!!!
317 //
318 // Coroutines do not support `destroying_delete`, so we can't access
319 // `dealloc_` here, and therefore have no way to deallocate.
320 // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
321 static void operator delete(void*) {}
322
323 // Handle a `co_await` call by accepting a type with a
324 // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
325 // yield a `U` once complete.
326 template <typename Pendable>
327 requires(!std::is_reference_v<Pendable>)
328 Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
329 return pendable;
330 }
331
332 template <typename Pendable>
333 Awaitable<Pendable*, CoroPromiseType> await_transform(Pendable& pendable) {
334 return &pendable;
335 }
336
337 // Returns a reference to the `Context` passed in.
338 Context& cx() { return *in_out_->input_cx; }
339
340 pw::allocator::Deallocator& dealloc_;
341 PendFillReturnValueFn currently_pending_;
342 InOut<T>* in_out_;
343};
344
345// The object created by invoking `co_await` in a `Coro<T>` function.
346//
347// This wraps a `Pendable` type and implements the awaitable interface
348// expected by the standard coroutine API.
349template <typename Pendable, typename PromiseType>
350class Awaitable final {
351 public:
352 // The `OutputType` in `Poll<OutputType> Pendable::Pend(Context&)`.
353 using OutputType = std::remove_cvref_t<
354 decltype(std::declval<std::remove_pointer_t<Pendable>>()
355 .Pend(std::declval<Context&>())
356 .value())>;
357
358 Awaitable(Pendable&& pendable) : state_(std::forward<Pendable>(pendable)) {}
359
360 // Confirms that `await_suspend` must be invoked.
361 bool await_ready() { return false; }
362
363 // Returns whether or not the current coroutine should be suspended.
364 //
365 // This is invoked once as part of every `co_await` call after
366 // `await_ready` returns `false`.
367 //
368 // In the process, this method attempts to complete the inner `Pendable`
369 // before suspending this coroutine.
370 bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
371 Context& cx = promise.promise().cx();
372 if (PendFillReturnValue(cx).IsPending()) {
374 promise.promise().currently_pending_ = [this](Context& lambda_cx) {
375 return PendFillReturnValue(lambda_cx);
376 };
377 return true;
378 }
379 return false;
380 }
381
382 // Returns `return_value`.
383 //
384 // This is automatically invoked by the language runtime when the promise's
385 // `resume()` method is called.
386 OutputType&& await_resume() {
387 return std::move(std::get<OutputType>(state_));
388 }
389
390 auto& PendableNoPtr() {
391 if constexpr (std::is_pointer_v<Pendable>) {
392 return *std::get<Pendable>(state_);
393 } else {
394 return std::get<Pendable>(state_);
395 }
396 }
397
398 // Attempts to complete the `Pendable` value, storing its return value
399 // upon completion.
400 //
401 // This method must return `Ready()` before the coroutine can be safely
402 // resumed, as otherwise the return value will not be available when
403 // `await_resume` is called to produce the result of `co_await`.
404 Poll<> PendFillReturnValue(Context& cx) {
405 Poll<OutputType> poll_res(PendableNoPtr().Pend(cx));
406 if (poll_res.IsPending()) {
407 return Pending();
408 }
409 state_ = std::move(*poll_res);
410 return Ready();
411 }
412
413 private:
414 std::variant<Pendable, OutputType> state_;
415};
416
417} // namespace internal
418
462template <std::constructible_from<pw::Status> T>
463class Coro final {
464 public:
466 static Coro Empty() {
468 }
469
474 [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
475
481 if (!IsValid()) {
482 // This coroutine failed to allocate its internal state.
483 // (Or `Pend` is being erroneously invoked after previously completing.)
484 return Ready(Status::Internal());
485 }
486
487 // If an `Awaitable` value is currently being processed, it must be
488 // allowed to complete and store its return value before we can resume
489 // the coroutine.
490 if (promise_handle_.promise().currently_pending_ != nullptr &&
491 promise_handle_.promise().currently_pending_(cx).IsPending()) {
492 return Pending();
493 }
494 // DOCSTAG: [pw_async2-coro-resume]
495 // Create the arguments (and output storage) for the coroutine.
496 internal::InOut<T> in_out;
497 internal::OptionalOrDefault<T> return_value;
498 in_out.input_cx = &cx;
499 in_out.output = &return_value;
500 promise_handle_.promise().in_out_ = &in_out;
501
502 // Resume the coroutine, triggering `Awaitable::await_resume()` and the
503 // returning of the resulting value from `co_await`.
504 promise_handle_.resume();
505 if (!promise_handle_.done()) {
506 return Pending();
507 }
508
509 // Destroy the coroutine state: it has completed, and further calls to
510 // `resume` would result in undefined behavior.
511 promise_handle_.Release();
512
513 // When the coroutine completed in `resume()` above, it stored its
514 // `co_return` value into `return_value`. This retrieves that value.
515 return return_value;
516 // DOCSTAG: [pw_async2-coro-resume]
517 }
518
522
523 private:
524 // Allow `CoroPromiseType<T>::get_return_object()` and
525 // `CoroPromiseType<T>::get_return_object_on_allocation_failure()` to
526 // use the private constructor below.
527 friend promise_type;
528
531 : promise_handle_(std::move(promise_handle)) {}
532
534};
535
536// Implement the remaining internal pieces that require a definition of
537// `Coro<T>`.
538namespace internal {
539
540template <typename T>
541Coro<T> CoroPromiseType<T>::get_return_object() {
542 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
543 std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
544}
545
546template <typename T>
547Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
548 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
549}
550
551} // namespace internal
552} // namespace pw::async2
Definition: allocator.h:34
Definition: layout.h:56
Definition: dispatcher_base.h:52
Context required for creating and executing coroutines.
Definition: coro.h:35
CoroContext(pw::allocator::Allocator &alloc)
Definition: coro.h:39
Definition: coro.h:463
static Coro Empty()
Creates an empty, invalid coroutine object.
Definition: coro.h:466
::pw::async2::internal::CoroPromiseType< T > promise_type
Definition: coro.h:521
bool IsValid() const
Definition: coro.h:474
Poll< T > Pend(Context &cx)
Definition: coro.h:480
Definition: poll.h:54
Definition: coro.h:350
bool await_suspend(const std::coroutine_handle< PromiseType > &promise)
Definition: coro.h:370
OwningCoroutineHandle(std::coroutine_handle< PromiseType > &&promise_handle)
Take ownership of promise_handle.
Definition: coro.h:105
fit::inline_function< FunctionType, inline_target_size > InlineFunction
Definition: function.h:91
Definition: coro.h:182