1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file implements semantic analysis for C++ Coroutines.
12 //===----------------------------------------------------------------------===//
14 #include "clang/Sema/SemaInternal.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/ExprCXX.h"
17 #include "clang/AST/StmtCXX.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Sema/Initialization.h"
20 #include "clang/Sema/Overload.h"
21 using namespace clang;
// Resolves the promise type of a coroutine from its function type: finds the
// class template 'coroutine_traits' in namespace std, forms
// coroutine_traits<R, P1, ..., Pn> from the return and parameter types of
// FnType, and looks up that specialization's 'promise_type' member. Every
// failure handled below is reported through S.Diag.
// NOTE(review): failure paths presumably produce a null QualType -- confirm.
24 /// Look up the std::coroutine_traits<...>::promise_type for the given
/// function type.
26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
28 // FIXME: Cache std::coroutine_traits once we've found it.
29 NamespaceDecl *Std = S.getStdNamespace();
// Report that std::coroutine_traits could not be found.
// NOTE(review): presumably guarded by a null check on Std -- confirm.
31 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
// Qualified ordinary-name lookup of 'coroutine_traits' inside namespace std.
35 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
36 Loc, Sema::LookupOrdinaryName);
37 if (!S.LookupQualifiedName(Result, Std)) {
38 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
// The lookup result is only usable if it is a single class template.
42 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
44 Result.suppressDiagnostics();
45 // We found something weird. Complain about the first thing we found.
46 NamedDecl *Found = *Result.begin();
47 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
51 // Form template argument list for coroutine_traits<R, P1, P2, ...>.
52 TemplateArgumentListInfo Args(Loc, Loc);
// First argument: the function's return type.
53 Args.addArgument(TemplateArgumentLoc(
54 TemplateArgument(FnType->getReturnType()),
55 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
56 // FIXME: If the function is a non-static member function, add the type
57 // of the implicit object parameter before the formal parameters.
// Remaining arguments: one per formal parameter type, in declaration order.
58 for (QualType T : FnType->getParamTypes())
59 Args.addArgument(TemplateArgumentLoc(
60 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
62 // Build the template-id.
64 S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
65 if (CoroTrait.isNull())
// The specialization must be a complete type before its members can be
// looked up; otherwise err_coroutine_traits_missing_specialization is issued.
67 if (S.RequireCompleteType(Loc, CoroTrait,
68 diag::err_coroutine_traits_missing_specialization))
71 CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
72 assert(RD && "specialization of class template is not a class?");
74 // Look up the ::promise_type member.
75 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
76 Sema::LookupOrdinaryName);
77 S.LookupQualifiedName(R, RD);
// Only a single type declaration is accepted as 'promise_type'.
78 auto *Promise = R.getAsSingle<TypeDecl>();
80 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
85 // The promise type is required to be a class type.
86 QualType PromiseType = S.Context.getTypeDeclType(Promise);
87 if (!PromiseType->getAsCXXRecordDecl()) {
88 // Use the fully-qualified name of the type.
// The nested-name-specifier is built as std:: followed by the
// coroutine_traits specialization; the elaborated type is created for the
// not-a-class diagnostic below.
89 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std);
90 NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
91 CoroTrait.getTypePtr());
92 PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
94 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
// Validates that a coroutine keyword may be used at Loc and makes sure the
// current function scope has a promise variable.
//   S       - semantic analysis object used for diagnostics and lookups.
//   Loc     - location of the keyword; used for every diagnostic issued here.
//   Keyword - spelling of the keyword, streamed into each diagnostic.
// NOTE(review): callers appear to use the returned FunctionScopeInfo as the
// coroutine's scope, with null presumably signalling a diagnosed error --
// confirm against the return statements.
102 /// Check that this is a context in which a coroutine suspension can appear.
103 static FunctionScopeInfo *
104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
105 // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
106 if (S.isUnevaluatedContext()) {
107 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
111 // Any other usage must be within a function.
112 // FIXME: Reject a coroutine with a deduced return type.
113 auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
// Not inside a FunctionDecl: Objective-C methods get their own diagnostic,
// everything else is reported as "outside a function".
115 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
116 ? diag::err_coroutine_objc_method
117 : diag::err_coroutine_outside_function) << Keyword;
118 } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
119 // Coroutines TS [special]/6:
120 // A special member function shall not be a coroutine.
122 // FIXME: We assume that this really means that a coroutine cannot
123 // be a constructor or destructor.
// The first streamed argument is true for a destructor and false for a
// constructor.
124 S.Diag(Loc, diag::err_coroutine_ctor_dtor)
125 << isa<CXXDestructorDecl>(FD) << Keyword;
126 } else if (FD->isConstexpr()) {
// constexpr functions are rejected as coroutines.
127 S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
128 } else if (FD->isVariadic()) {
// Variadic functions are rejected as coroutines.
129 S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
131 auto *ScopeInfo = S.getCurFunction();
132 assert(ScopeInfo && "missing function scope for function");
134 // If we don't have a promise variable, build one now.
135 if (!ScopeInfo->CoroutinePromise) {
// The promise type T is the dependent placeholder type when the function
// type is dependent; otherwise it is resolved via lookupPromiseType.
137 FD->getType()->isDependentType()
138 ? S.Context.DependentTy
139 : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
144 // Create and default-initialize the promise.
// The variable is named "__promise", is declared in FD at FD's location,
// and is stored on the function scope.
145 ScopeInfo->CoroutinePromise =
146 VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
147 &S.PP.getIdentifierTable().get("__promise"), T,
148 S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
149 S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
// Skip initialization when the type check marked the declaration invalid.
150 if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
151 S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
160 /// Build a call to 'operator co_await' if there is a suitable operator for
161 /// the given expression.
// SemaRef - semantic analysis object.
// S       - scope passed to the overloaded-operator name lookup.
// Loc     - location given to the resulting unary operator expression.
// E       - operand; its type is what the operator lookup is keyed on.
162 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
163 SourceLocation Loc, Expr *E) {
// Candidate 'operator co_await' declarations found by the lookup below.
164 UnresolvedSet<16> Functions;
165 SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
// Hand the operand and candidate set to the overloaded unary operator
// builder, using the UO_Coawait opcode.
167 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
// Result bundle produced by buildCoawaitCalls. As used in this file it
// carries an IsInvalid flag and a Results array indexed 0..2 that holds the
// await_ready, await_suspend and await_resume call expressions in that order.
170 struct ReadySuspendResumeResult {
// Builds the call expression 'Base.Name(Args...)'.
//   Base - object expression; accessed with '.' (IsPtr is false), not '->'.
//   Loc  - location used for the member name and for both call parentheses.
//   Name - identifier of the member to call.
//   Args - arguments forwarded to the call.
175 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
177 MutableArrayRef<Expr *> Args) {
// Intern Name as an identifier so it can be used as a declaration name.
178 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
180 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
// Step 1: form the member reference 'Base.Name'.
182 ExprResult Result = S.BuildMemberReferenceExpr(
183 Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
184 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
186 if (Result.isInvalid())
// Step 2: call the member reference with the supplied arguments.
189 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
192 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
/// operand expression E. The three calls are stored, in that order, in the
/// Results array of the returned ReadySuspendResumeResult.
194 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
196 // Assume invalid until we see otherwise.
197 ReadySuspendResumeResult Calls = {true, {}};
// Member names in the order their calls are stored in Calls.Results.
199 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
200 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
// Each call gets its own OpaqueValueExpr wrapping E, typed as an lvalue of
// E's type, so the three calls all refer to the same operand expression.
201 Expr *Operand = new (S.Context) OpaqueValueExpr(
202 Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
204 // FIXME: Pass coroutine handle to await_suspend.
// All three members are currently called with no arguments.
205 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
206 if (Result.isInvalid())
208 Calls.Results[I] = Result.get();
// Reaching this point clears the pessimistic flag set at the top.
211 Calls.IsInvalid = false;
// Handles a 'co_await E' expression: resolves a placeholder-typed operand,
// applies 'operator co_await' to it through buildOperatorCoawaitCall, and
// passes the resulting awaitable to BuildCoawaitExpr.
//   S   - scope used for the operator lookup.
//   Loc - location of the co_await keyword.
//   E   - operand expression.
215 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
// Placeholder-typed operands are checked first; a failure ends the analysis
// with ExprError().
216 if (E->getType()->isPlaceholderType()) {
217 ExprResult R = CheckPlaceholderExpr(E);
218 if (R.isInvalid()) return ExprError();
222 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
223 if (Awaitable.isInvalid())
225 return BuildCoawaitExpr(Loc, Awaitable.get());
// Builds the CoawaitExpr node for an operand E, without performing an
// 'operator co_await' lookup. The created node is recorded in the enclosing
// function scope's CoroutineStmts list.
//   Loc - location of the co_await keyword.
//   E   - the awaitable operand.
227 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
// Validate the context and obtain the coroutine's function scope.
228 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
232 if (E->getType()->isPlaceholderType()) {
233 ExprResult R = CheckPlaceholderExpr(E);
234 if (R.isInvalid()) return ExprError();
// Dependent operand: create a CoawaitExpr of dependent type from E alone,
// with no await_* calls, and record it.
238 if (E->getType()->isDependentType()) {
239 Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
240 Coroutine->CoroutineStmts.push_back(Res);
244 // If the expression is a temporary, materialize it as an lvalue so that we
245 // can use it multiple times.
246 if (E->getValueKind() == VK_RValue)
247 E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
249 // Build the await_ready, await_suspend, await_resume calls.
250 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
// Results[0] and Results[1] are the await_ready and await_suspend calls.
254 Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
256 Coroutine->CoroutineStmts.push_back(Res);
// Builds a call to a member function of the coroutine's promise object.
//   Coroutine - function scope of the coroutine; its CoroutinePromise must
//               already be set (asserted below).
//   Loc       - location used for the promise reference and the call.
//   Name      - name of the promise member to call.
//   Args      - arguments forwarded to that member.
260 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
261 SourceLocation Loc, StringRef Name,
262 MutableArrayRef<Expr *> Args) {
263 assert(Coroutine->CoroutinePromise && "no promise for coroutine");
265 // Form a reference to the promise.
266 auto *Promise = Coroutine->CoroutinePromise;
// The reference is an lvalue of the promise variable's non-reference type.
267 ExprResult PromiseRef = S.BuildDeclRefExpr(
268 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
269 if (PromiseRef.isInvalid())
272 // Call the promise member named by Name, passing in Args.
273 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
// Handles a 'co_yield E' expression: calls yield_value(E) on the promise,
// applies 'operator co_await' to the result, and passes the resulting
// awaitable to BuildCoyieldExpr.
//   S   - scope used for the operator lookup.
//   Loc - location of the co_yield keyword.
//   E   - yielded operand.
276 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
// Validate the context; this also makes sure the promise variable exists
// before it is referenced below.
277 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
281 // Build yield_value call.
282 ExprResult Awaitable =
283 buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
284 if (Awaitable.isInvalid())
287 // Build 'operator co_await' call.
288 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
289 if (Awaitable.isInvalid())
292 return BuildCoyieldExpr(Loc, Awaitable.get());
// Builds the CoyieldExpr node for an awaitable operand E. The steps mirror
// BuildCoawaitExpr, and the created node is recorded in the enclosing
// function scope's CoroutineStmts list.
//   Loc - location of the co_yield keyword.
//   E   - the awaitable operand.
294 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
// Validate the context and obtain the coroutine's function scope.
295 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
299 if (E->getType()->isPlaceholderType()) {
300 ExprResult R = CheckPlaceholderExpr(E);
301 if (R.isInvalid()) return ExprError();
// Dependent operand: create a CoyieldExpr of dependent type from E alone,
// with no await_* calls, and record it.
305 if (E->getType()->isDependentType()) {
306 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
307 Coroutine->CoroutineStmts.push_back(Res);
311 // If the expression is a temporary, materialize it as an lvalue so that we
312 // can use it multiple times.
313 if (E->getValueKind() == VK_RValue)
314 E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
316 // Build the await_ready, await_suspend, await_resume calls.
317 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
// Results[0] and Results[1] are the await_ready and await_suspend calls.
321 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
323 Coroutine->CoroutineStmts.push_back(Res);
// Handles a 'co_return' statement by forwarding directly to
// BuildCoreturnStmt.
//   Loc - location of the co_return keyword.
//   E   - operand expression; BuildCoreturnStmt accepts a null operand.
327 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
328 return BuildCoreturnStmt(Loc, E);
// Builds a CoreturnStmt and the promise call that goes with it:
// return_value(E) when E is present and non-void, return_void() otherwise.
// The statement is recorded in the function scope's CoroutineStmts list.
//   Loc - location of the co_return keyword.
//   E   - operand expression; may be null.
330 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
// Validate the context and obtain the coroutine's function scope.
331 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
// Placeholder operands are checked here, except overload-set placeholders,
// which are left untouched.
335 if (E && E->getType()->isPlaceholderType() &&
336 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
337 ExprResult R = CheckPlaceholderExpr(E);
338 if (R.isInvalid()) return StmtError();
342 // FIXME: If the operand is a reference to a variable that's about to go out
343 // of scope, we should treat the operand as an xvalue for this overload
346 if (E && !E->getType()->isVoidType()) {
347 PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
// Otherwise E is turned into a discarded-value full-expression and the
// promise's return_void() is called with no arguments.
349 E = MakeFullDiscardedValueExpr(E).get();
350 PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
// Finish the promise call as a full-expression before storing it in the
// statement.
355 Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
357 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
358 Coroutine->CoroutineStmts.push_back(Res);
// Performs the checks and implicit-statement construction for a function
// body already known to be a coroutine (the function scope has at least one
// recorded coroutine statement). On success Body is replaced with a
// CoroutineBodyStmt wrapping the original body; on the failures checked
// below FD is marked invalid.
//   FD   - the coroutine's function declaration.
//   Body - in/out: the original body on entry, the wrapper on success.
362 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
363 FunctionScopeInfo *Fn = getCurFunction();
364 assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
366 // Coroutines [stmt.return]p1:
367 // A return statement shall not appear in a coroutine.
368 if (Fn->FirstReturnLoc.isValid()) {
369 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
// Attach a note at the first recorded coroutine statement. The streamed
// index is 0 for a CoawaitExpr, 1 for a CoyieldExpr, and 2 otherwise.
370 auto *First = Fn->CoroutineStmts[0];
371 Diag(First->getLocStart(), diag::note_declared_coroutine_here)
372 << (isa<CoawaitExpr>(First) ? 0 :
373 isa<CoyieldExpr>(First) ? 1 : 2);
// Scan the recorded statements for any co_await or co_yield.
376 bool AnyCoawaits = false;
377 bool AnyCoyields = false;
378 for (auto *CoroutineStmt : Fn->CoroutineStmts) {
379 AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
380 AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
// A coroutine with neither gets the
// ext_coroutine_without_co_await_co_yield diagnostic at its first recorded
// statement.
383 if (!AnyCoawaits && !AnyCoyields)
384 Diag(Fn->CoroutineStmts.front()->getLocStart(),
385 diag::ext_coroutine_without_co_await_co_yield);
// All implicit constructs built below use the function's own location.
387 SourceLocation Loc = FD->getLocation();
389 // Form a declaration statement for the promise declaration, so that AST
390 // visitors can more easily find it.
391 StmtResult PromiseStmt =
392 ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
393 if (PromiseStmt.isInvalid())
394 return FD->setInvalidDecl();
396 // Form and check implicit 'co_await p.initial_suspend();' statement.
397 ExprResult InitialSuspend =
398 buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
399 // FIXME: Support operator co_await here.
400 if (!InitialSuspend.isInvalid())
401 InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
402 InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
403 if (InitialSuspend.isInvalid())
404 return FD->setInvalidDecl();
406 // Form and check implicit 'co_await p.final_suspend();' statement.
407 ExprResult FinalSuspend =
408 buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
409 // FIXME: Support operator co_await here.
410 if (!FinalSuspend.isInvalid())
411 FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
412 FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
413 if (FinalSuspend.isInvalid())
414 return FD->setInvalidDecl();
416 // FIXME: Perform analysis of set_exception call.
418 // FIXME: Try to form 'p.return_void();' expression statement to handle
419 // control flowing off the end of the coroutine.
421 // Build implicit 'p.get_return_object()' expression and form initialization
422 // of return type from it.
423 ExprResult ReturnObject =
424 buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
425 if (ReturnObject.isInvalid())
426 return FD->setInvalidDecl();
427 QualType RetType = FD->getReturnType();
// The initialization of the function result is only formed when the return
// type is not dependent.
428 if (!RetType->isDependentType()) {
429 InitializedEntity Entity =
430 InitializedEntity::InitializeResult(Loc, RetType, false);
431 ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
433 if (ReturnObject.isInvalid())
434 return FD->setInvalidDecl();
436 ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
437 if (ReturnObject.isInvalid())
438 return FD->setInvalidDecl();
440 // FIXME: Perform move-initialization of parameters into frame-local copies.
// Nothing is added to this list in the code shown here; it is passed to the
// wrapper statement as is.
441 SmallVector<Expr*, 16> ParamMoves;
443 // Build body for the coroutine wrapper statement.
// The SetException and Fallthrough slots are passed as null (see the FIXMEs
// above).
444 Body = new (Context) CoroutineBodyStmt(
445 Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
446 /*SetException*/nullptr, /*Fallthrough*/nullptr,
447 ReturnObject.get(), ParamMoves);