Chapter11 折叠表达式

自从C++17起,有一个新的特性可以计算对参数包中的 所有 参数应用一个二元运算符的结果。

例如,下面的函数将会返回所有参数的总和:

  1. template<typename... T>
  2. auto foldSum (T... args) {
  3. return (... + args); // ((arg1 + arg2) + arg3)...
  4. }

注意返回语句中的括号是折叠表达式的一部分,不能被省略。

如下调用:

  1. foldSum(47, 11, val, -1);

会把模板实例化为:

  1. return 47 + 11 + val + -1;

如下调用:

  1. foldsum(std::string("hello"), "world", "!");

会把模板实例化为:

  1. return std::string("hello") + "world" + "!";

注意折叠表达式里参数的位置很重要(可能看起来还有些反直觉)。如下写法:

  1. (... + args)

会展开为:

  1. ((arg1 + arg2) + arg3) ...

这意味着折叠表达式会以“后递增式”重复展开。你也可以写:

  1. (args + ...)

这样就会以“前递增式”展开,因此结果会变为:

  1. (arg1 + (arg2 + arg3)) ...

11.1 折叠表达式的动机

折叠表达式的出现让我们不必再用递归实例化模板的方式来处理参数包。 在C++17之前,你必须这么实现:

  1. template<typename T>
  2. auto foldSumRec (T arg) {
  3. return arg;
  4. }
  5. template<typename T1, typename... Ts>
  6. auto foldSumRec (T1 arg1, Ts... otherArgs) {
  7. return arg1 + foldSumRec(otherArgs...);
  8. }

这样的实现不仅写起来麻烦,对C++编译器来说也很难处理。使用如下写法:

  1. template<typename... T>
  2. auto foldSum (T... args) {
  3. return (... + args); // arg1 + arg2 + arg3...
  4. }

能显著的减少程序员和编译器的工作量。

11.2 使用折叠表达式

给定一个参数 args 和一个操作符 op ,C++17允许我们这么写:

  • 一元左折叠

( ... op args )

将会展开为:((arg1 op arg2) op arg3) op ...

  • 一元右折叠

( args op ... )

将会展开为: arg1 op (arg2 op ... (argN-1 op argN))

括号是必须的,然而,括号和省略号(…)之间并不需要用空格分隔。

左折叠和右折叠的不同比想象中更大。例如,当你使用 + 时可能会产生不同的效果。 使用左折叠时:

  1. template<typename... T>
  2. auto foldSumL(T... args) {
  3. return (... + args); // ((arg1 + arg2) + arg3)...
  4. }

如下调用

  1. foldSumL(1, 2, 3);

会求值为

  1. ((1 + 2) + 3)

这意味着下面的例子能够通过编译:

  1. std::cout << foldSumL(std::string("hello"), "world", "!") << '\n'; // OK

记住对字符串而言只有两侧至少有一个是std::string时才能使用+。 使用左折叠时会首先计算

  1. std::string("hello") + "world"

这将返回一个std::string,因此再加上字符串字面量"!"是有效的。

然而,如下调用

  1. std::cout << foldSumL("hello", "world", std::string("!")) << '\n'; // ERROR

将不能通过编译,因为它会求值为

  1. ("hello" + "world") + std::string("!");

然而把两个字符串字面量相加是错误的。

然而如果我们把实现修改为:

  1. template<typename... T>
  2. auto foldSumR(T... args) {
  3. return (args + ...); // (arg1 + (arg2 + arg3))...
  4. }

那么如下调用

  1. foldSumR(1, 2, 3)

将求值为

  1. (1 + (2 + 3))

这意味着下面的例子不能再通过编译:

  1. std::cout << foldSumR(std::string("hello"), "world", "!") << '\n'; // ERROR

然而如下调用现在反而可以编译了:

  1. std::cout << foldSumR("hello", "world", std::string("!")) << '\n'; // OK

在任何情况下,从左向右求值都是符合直觉的。 因此,更推荐使用左折叠的语法:

  1. (... + args); // 推荐的折叠表达式语法

11.2.1 处理空参数包

当使用折叠表达式处理空参数包时,将遵循如下规则:

  • 如果使用了&&运算符,值为true
  • 如果使用了||运算符,值为false
  • 如果使用了逗号运算符,值为void()
  • 使用所有其他的运算符,都会引发格式错误

对于所有其他的情况,你可以添加一个初始值: 给定一个参数包 args ,一个初始值 value ,一个操作符 op , C++17允许我们这么写:

  • 二元左折叠

( value op ... op args )

将会展开为:(((value op arg1) op arg2) op arg3) op ...

  • 二元右折叠

( args op ... op value )

将会展开为: arg1 op (arg2 op ... (argN op value))

省略号两侧的 op 必须相同。

例如,下面的定义在进行加法时允许传递一个空参数包:

  1. template<typename... T>
  2. auto foldSum (T... s) {
  3. return (0 + ... + s); // 即使sizeof...(s)==0也能工作
  4. }

从概念上讲,不管0是第一个还是最后一个操作数应该和结果无关:

  1. template<typename... T>
  2. auto foldSum (T... s) {
  3. return (s + ... + 0); // 即使sizeof...(s)==0也能工作
  4. }

然而,对于一元折叠表达式来说,不同的求值顺序比想象中的更重要。 对于二元表达式来说,也更推荐左折叠的方式:

  1. (val + ... + args); // 推荐的二元折叠表达式语法

有时候第一个操作数是特殊的,比如下面的例子:

  1. template<typename... T>
  2. void print (const T&... args)
  3. {
  4. (std::cout << ... << args) << '\n';
  5. }

这里,传递给print()的第一个参数输出之后将返回输出流,所以后面的参数可以继续输出。 其他的实现可能不能编译或者产生一些意料之外的结果。例如,

  1. std::cout << (args << ... << '\n');

类似print(1)的调用可以编译,但会打印出1左移'\n'位之后的值, '\n'的值通常是10,所以结果是1024。

注意在这个print()的例子中,两个参数之间没有输出空格字符。 因此,如果调用print("hello", 42, "world")将会打印出:

  1. hello42world

为了用空格分隔传入的参数, 你需要一个辅助函数来确保除了第一个参数之外的剩余参数输出前都先输出一个空格。 例如,使用如下的模板spaceBefore()可以做到这一点:

  1. template<typename T>
  2. const T& spaceBefore(const T& arg) {
  3. std::cout << ' ';
  4. return arg;
  5. }
  6. template <typename First, typename... Args>
  7. void print (const First& firstarg, const Args&... args) {
  8. std::cout << firstarg;
  9. (std::cout << ... << spaceBefore(args)) << '\n';
  10. }

这里,折叠表达式

  1. (std::cout << ... << spaceBefore(args))

将会展开为:

  1. std::cout << spaceBefore(arg1) << spaceBefore(arg2) << ...

因此,对于参数包中args的每一个参数,都会调用辅助函数, 在输出参数之前先输出一个空格到std::cout。 为了确保不会对第一个参数调用辅助函数,我们添加了额外的模板参数对第一个参数进行单独处理。

注意要想让参数包正确输出需要确保对每个参数调用spaceBefore()之前左侧的所有输出都已经完成。 得益于操作符<<的有定义的求值顺序,自从C++17起将保证行为正确。

我们也可以使用lambda来在print()内定义spaceBefore()

  1. template<typename First, typename... Args>
  2. void print (const First& firstarg, const Args&... args) {
  3. std::cout << firstarg;
  4. auto spaceBefore = [](const auto& arg) {
  5. std::cout << ' ';
  6. return arg;
  7. };
  8. (std::cout << ... << spaceBefore(args)) << '\n';
  9. }

然而,注意默认情况下lambda以值返回对象,这意味着会创建参数的不必要的拷贝。 解决方法是显式指明返回类型为const auto&或者decltype(auto)

  1. template<typename First, typename... Args>
  2. void print (const First& firstarg, const Args&... args) {
  3. std::cout << firstarg;
  4. auto spaceBefore = [](const auto& arg) -> const auto& {
  5. std::cout << ' ';
  6. return arg;
  7. };
  8. (std::cout << ... << spaceBefore(args)) << '\n';
  9. }

如果你不能把它们写在一个表达式里那么C++就不是C++了:

  1. template<typename First, typename... Args>
  2. void print (const First& firstarg, const Args&... args) {
  3. std::cout << firstarg;
  4. (std::cout << ... << [] (const auto& arg) -> decltype(auto) {
  5. std::cout << ' ';
  6. return arg;
  7. }(args)) << '\n';
  8. }

不过,一个更简单的实现print()的方法是使用一个lambda输出空格和参数, 然后在一元折叠表达式里使用它:

  1. template<typename First, typename... Args>
  2. void print(First first, const Args&... args) {
  3. std::cout << first;
  4. auto outWithSpace = [] (const auto& arg) {
  5. std::cout << ' ' << arg;
  6. };
  7. (... , outWithSpace(args));
  8. std::cout << '\n';
  9. }

通过使用新的auto模板参数,我们可以使print() 变得更加灵活:可以把间隔符定义为一个参数,这个参数可以是一个字符、一个字符串或者其它任何可打印的类型。

11.2.2 支持的运算符

你可以对除了.->[]之外的所有二元运算符使用折叠表达式。

折叠函数调用

折叠表达式可以用于逗号运算符,这样就可以在一条语句里进行多次函数调用。 也就是说,你现在可以简单写出如下实现:

  1. template<typename... Types>
  2. void callFoo(const Types&... args)
  3. {
  4. ...
  5. (... , foo(args)); // 调用foo(arg1),foo(arg2),foo(arg3),...
  6. }

来对所有参数调用函数foo()

另外,如果需要支持移动语义:

  1. template<typename... Types>
  2. void callFoo(Types&&... args)
  3. {
  4. ...
  5. (... , foo(std::forward<Types>(args))); // 调用foo(arg1),foo(arg2),...
  6. }

如果foo()函数返回的类型重载了逗号运算符,那么代码行为可能会改变。 为了保证这种情况下代码依然安全,你需要把返回值转换为void

  1. template<typename... Types>
  2. void callFoo(const Types&... args)
  3. {
  4. ...
  5. (... , (void)foo(std::forward<Types>(args))); // 调用foo(arg1),foo(arg2),...
  6. }

注意自然情况下,对于逗号运算符不管我们是左折叠还是右折叠都是一样的。 函数调用们总是会从左向右执行。如下写法:

  1. (foo(args) , ...);

中的括号只是把后边的调用括在一起,因此首先是第一个foo()调用, 之后是被括起来的两个foo()调用:

  1. foo(arg1) , (foo(arg2) , foo(arg3));

然而,因为逗号表达式的求值顺序通常是自左向右,所以第一个调用通常发生在括号里的两个调用之前, 并且括号里左侧的调用在右侧的调用之前。

不过,因为左折叠更符合自然的求值顺序,因此在使用折叠表达式进行多次函数调用时还是推荐使用左折叠。

组合哈希函数

另一个使用逗号折叠表达式的例子是组合哈希函数。可以用如下的方法完成:

  1. template<typename T>
  2. void hashCombine (std::size_t& seed, const T& val)
  3. {
  4. seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed<<6) + (seed>>2);
  5. }
  6. template<typename... Types>
  7. std::size_t combinedHashValue (const Types&... args)
  8. {
  9. std::size_t seed = 0; // 初始化seed
  10. (... , hashCombine(seed, args)); // 链式调用hashCombine()
  11. return seed;
  12. }

如下调用

  1. combinedHashValue ("Hi", "World", 42);

函数中的折叠表达式将被展开为:

  1. hashCombine(seed, "Hi"), (hashCombine(seed, "World"), hashCombine(seed, 42));

有了这些定义,我们现在可以轻易的定义出一个新的哈希函数, 并将这个函数用于某一个类型例如Customer的unordered set或unordered map:

  1. sturct CustomerHash
  2. {
  3. std::size_t operator() (const Customer& c) const {
  4. return combinedHashValue(c.getFirstname(), c.getLastname(), c.getValue());
  5. }
  6. };
  7. std::unordered_set<Customer, CustomerHash> coll;
  8. std::unordered_map<Customer, std::string, CustomerHash> map;

折叠基类的函数调用

折叠表达式可以在更复杂的场景中使用。例如,你可以折叠逗号表达式来调用可变数量基类的成员函数:

  1. #include <iostream>
  2. // 可变数量基类的模板
  3. template<typename... Bases>
  4. class MultiBase : private Bases...
  5. {
  6. public:
  7. void print() {
  8. // 调用所有基类的print()函数
  9. (... , Bases::print());
  10. }
  11. };
  12. struct A {
  13. void print() { std::cout << "A::print()\n"; }
  14. }
  15. struct B {
  16. void print() { std::cout << "B::print()\n"; }
  17. }
  18. struct C {
  19. void print() { std::cout << "C::print()\n"; }
  20. }
  21. int main()
  22. {
  23. MultiBase<A, B, C> mb;
  24. mb.print();
  25. }

这里

  1. template<typename... Bases>
  2. class MultiBase : private Bases...
  3. {
  4. ...
  5. };

允许我们用可变数量的基类初始化对象:

  1. MultiBase<A, B, C> mb;

之后,通过

  1. (... , Bases::print());

这个折叠表达式将展开为调用每一个基类中的print。 也就是说,这条语句会被展开为如下代码:

  1. (A::print(), B::print()), C::print();

折叠路径遍历

你也可以使用折叠表达式通过运算符->\*遍历一个二叉树中的路径。 考虑下面的递归数据结构:

  1. // 定义二叉树结构和遍历辅助函数
  2. struct Node {
  3. int value;
  4. Node *subLeft{nullptr};
  5. Node *subRight{nullptr};
  6. Node(int i = 0) : value{i} {
  7. }
  8. int getValue() const {
  9. return value;
  10. }
  11. ...
  12. // 遍历辅助函数:
  13. static constexpr auto left = &Node::subLeft;
  14. static constexpr auto right = &Node::subRight;
  15. // 使用折叠表达式遍历树:
  16. template<typename T, typename... TP>
  17. static Node* traverse(T np, TP... paths) {
  18. return (np ->* ... ->* paths); // np ->* paths1 ->* paths2
  19. }
  20. };

这里,

  1. (np ->* ... ->* paths)

使用了折叠表达式以np为起点遍历可变长度的路径,可以像下面这样使用这个函数:

  1. #include "foldtraverse.hpp"
  2. #include <iostream>
  3. int main()
  4. {
  5. // 初始化二叉树结构:
  6. Node* root = new Node{0};
  7. root->subLeft = new Node{1};
  8. root->subLeft->subRight = new Node{2};
  9. ...
  10. // 遍历二叉树:
  11. Node* node = Node::traverse(root, Node::left, Node::right);
  12. std::cout << node->getValue() << '\n';
  13. node = root ->* Node::left ->* Node::right;
  14. std::cout << node->getValue() << '\n';
  15. node = root -> subLeft -> subRight;
  16. std::cout << node->getValue() << '\n';
  17. }

当调用

  1. Node::traverse(root, Node::left, Node::right);

时折叠表达式将展开为:

  1. root ->* Node::left ->* Node::right

结果等价于

  1. root -> subLeft -> subRight

11.2.3 使用折叠表达式处理类型

通过使用类型特征,我们也可以使用折叠表达式来处理模板参数包(任意数量的模板类型参数)。 例如,你可以使用折叠表达式来判断一些类型是否相同:

  1. #include <type_traits>
  2. // 检查是否所有类型都相同:
  3. template<typename T1, typename... TN>
  4. struct IsHomogeneous {
  5. static constexpr bool value = (std::is_same_v<T1, TN> && ...);
  6. };
  7. // 检查是否所有传入的参数类型相同:
  8. template<typename T1, typename... TN>
  9. constexpr bool isHomogeneous(T1, TN...)
  10. {
  11. return (std::is_same_v<T1, TN> && ...);
  12. }

类型特征IsHomogeneous<>可以像下面这样使用:

  1. IsHomogeneous<int, MyType, decltype(42)>::value

这种情况下,折叠表达式将会展开为:

  1. std::is_same_v<int, MyType> && std::is_same_v<int, decltype(42)>

函数模板isHomogeneous<>()可以像下面这样使用:

  1. isHomogeneous(43, -1, "hello", nullptr)

在这种情况下,折叠表达式将会展开为:

  1. std::is_same_v<int, int> && std::is_same_v<int, const char*> && is_same_v<int, std::nullptr_t>

像通常一样,运算符&&会短路求值(出现第一个false时就会停止运算)。

标准库里std::array<>的推导指引就使用了这个特性。