SimpleArray

首先实现两个数组逐元素相加、相乘的功能:

  #ifndef UNTITLED1__SIMPLEARRAY_HPP_
  #define UNTITLED1__SIMPLEARRAY_HPP_
  #include <cstddef>
  #include <iostream>
  #include <vector>

  /// Fixed-size array of doubles backed by std::vector, with eager arithmetic.
  ///
  /// operator+ and operator* are defined out of line. Each one runs its own
  /// loop and returns a newly created result array, so a compound expression
  /// such as `a + b * a` builds an intermediate array for every operator.
  class SimpleArray {
   public:
    /// Creates an array holding `size` elements.
    explicit SimpleArray(std::size_t size);

    /// Mutable element access; the index is not bounds-checked.
    double &operator[](std::size_t index);

    /// Read-only element access (returns a copy); the index is not bounds-checked.
    double operator[](std::size_t index) const;

    /// Number of elements.
    std::size_t size() const;

   private:
    std::vector<double> data;

    // Element-wise operators read `data` directly. Both operands are expected
    // to have the same size; the implementation only checks this with assert.
    friend SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs);
    friend SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs);
  };
  #endif  // UNTITLED1__SIMPLEARRAY_HPP_
  1. #include <cassert>
  2. #include "SimpleArray.hpp"
  3. SimpleArray::SimpleArray(size_t size): data(size) {}
  4. double &SimpleArray::operator[](size_t index) {
  5. return data[index];
  6. }
  7. double SimpleArray::operator[](size_t index) const {
  8. return data[index];
  9. }
  10. size_t SimpleArray::size() const {
  11. return data.size();
  12. }
  13. SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs) {
  14. assert(lhs.data.size() == rhs.data.size());
  15. SimpleArray result(lhs.data.size());
  16. for (size_t i = 0; i < lhs.data.size(); ++i) {
  17. result[i] = lhs[i] + rhs[i];
  18. }
  19. return result;
  20. }
  21. SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs) {
  22. assert(lhs.data.size() == rhs.data.size());
  23. SimpleArray result(lhs.data.size());
  24. for (size_t i = 0; i < lhs.data.size(); ++i) {
  25. result[i] = lhs[i] * rhs[i];
  26. }
  27. return result;
  28. }
  1. int main() {
  2. size_t size = 10;
  3. SimpleArray simple_arrayA(size);
  4. SimpleArray simple_arrayB(size);
  5. for (size_t i = 0; i < size; ++i) {
  6. simple_arrayA[i] = drand48();
  7. simple_arrayB[i] = drand48();
  8. }
  9. SimpleArray simple_arrayC = simple_arrayA + simple_arrayB * simple_arrayA;
  10. return 0;
  11. }

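实际上,计算 simple_arrayA + simple_arrayB * simple_arrayA 时要先后执行两个循环(共 2n 次迭代),还产生了若干个临时对象:包括保存 simple_arrayB * simple_arrayA 结果的中间数组,以及两个运算符内部的 result,实际数量取决于编译器是否进行返回值优化。而我们完全可以在一次循环中直接算出每个元素 A[i] + B[i] * A[i]:只需 n 次迭代,也不需要中间数组。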

惰性计算

上面的问题在于计算得太着急了:在还不知道表达式全貌的情况下就开始计算。更好的做法是等整个表达式写完、执行赋值的那一步再去计算;在此之前并不真正计算,而只是把表达式的结构存储起来。

先定义一个包装类 ExprArray,它的接口与 SimpleArray 一样:构造函数、下标运算符和 size() 都转发给内部的 data 成员。

  1. template<typename ArrayType=SimpleArray>
  2. class ExprArray {
  3. public:
  4. explicit ExprArray(size_t size): data(size) {};
  5. double &operator[](size_t index) {return data[index];}
  6. double operator[](size_t index) const {return data[index];}
  7. size_t size() const { return data.size(); }
  8. private:
  9. ArrayType data;
  10. };

用法也与 SimpleArray 完全一致:构造 ExprArray 时会初始化其内部的 SimpleArray(进而初始化其中的 vector),下标操作也同样被转发。不过目前还没有实现加法和乘法,因此下面的代码会编译报错。

  1. int main() {
  2. size_t size = 10;
  3. ExprArray a(size);
  4. ExprArray b(size);
  5. for (size_t i = 0; i < size; ++i) {
  6. a[i] = drand48();
  7. b[i] = drand48();
  8. }
  9. ExprArray c = a + b * a;
  10. return 0;
  11. }

表达式模板

我们需要一个结构来存储表达式。先考虑最简单的情况:LType 和 RType 都是 SimpleArray,再通过 operator[] 按需计算出每个位置的最终结果。

SumArray 保存了一次运算的全部信息,包括以下三项:
- 运算符:由类型本身体现,即加法;
- 两侧操作数的类型:模板参数 LType 和 RType;
- 对两侧操作数对象的引用。

  1. template<typename LType, typename RType>
  2. class SumArray {
  3. public:
  4. SumArray(const LType &lhs, const RType &rhs): lhs(lhs), rhs(rhs) {
  5. assert(lhs.size() == rhs.size());
  6. };
  7. double operator[](size_t index) const {
  8. return lhs[index] + rhs[index];
  9. }
  10. size_t size() const { return lhs.size(); }
  11. private:
  12. const LType &lhs;
  13. const RType &rhs;
  14. };
  15. template<typename LType, typename RType>
  16. ExprArray<SumArray<LType, RType>>
  17. operator +(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  18. return ExprArray<SumArray<LType, RType>>(SumArray<LType, RType>(lhs.getData(), rhs.getData()));
  19. }

首先,通过形参 const ExprArray<LType> &lhs, const ExprArray<RType> &rhs 让编译器从 ExprArray 中推导出(剥离出)两侧的类型 LType 和 RType(最简单的情况下都是 SimpleArray)。

返回值再用 SumArray 对象去初始化 ExprArray 对象。因此 ExprArray 需要提供接受 ArrayType 对象的构造函数,以及供运算符取出内部数据的 getData()。乘法 ProductArray 与 SumArray 完全对称。完整的头文件如下:

  1. #ifndef UNTITLED1__EXPRARRAY_HPP_
  2. #define UNTITLED1__EXPRARRAY_HPP_
  3. #include "SimpleArray.hpp"
  4. #include <cassert>
  5. template<typename ArrayType>
  6. class ExprArray {
  7. public:
  8. explicit ExprArray(size_t size): data(size) {};
  9. explicit ExprArray(const ArrayType &array): data(array) {};
  10. template<typename OtherType>
  11. ExprArray(const ExprArray<OtherType> &array): data(array.size()) {
  12. for (size_t i = 0; i < array.size(); ++i) {
  13. data[i] = array[i];
  14. }
  15. };
  16. double &operator[](size_t index) {return data[index];}
  17. double operator[](size_t index) const {return data[index];}
  18. size_t size() const { return data.size(); }
  19. const ArrayType &getData() const { return data; }
  20. private:
  21. ArrayType data;
  22. };
  23. template<typename LType, typename RType>
  24. class ProductArray {
  25. public:
  26. ProductArray(const LType &lhs, const RType &rhs): lhs(lhs), rhs(rhs) {
  27. assert(lhs.size() == rhs.size());
  28. };
  29. double operator[](size_t index) const {
  30. return lhs[index] * rhs[index];
  31. }
  32. size_t size() const { return lhs.size(); }
  33. private:
  34. const LType &lhs;
  35. const RType &rhs;
  36. };
  37. template<typename LType, typename RType>
  38. class SumArray {
  39. public:
  40. SumArray(const LType &lhs, const RType &rhs): lhs(lhs), rhs(rhs) {
  41. assert(lhs.size() == rhs.size());
  42. };
  43. double operator[](size_t index) const {
  44. return lhs[index] + rhs[index];
  45. }
  46. size_t size() const { return lhs.size(); }
  47. private:
  48. const LType &lhs;
  49. const RType &rhs;
  50. };
  51. template<typename LType, typename RType>
  52. ExprArray<ProductArray<LType, RType>>operator *(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  53. return ExprArray<ProductArray<LType, RType>>(ProductArray<LType, RType>(lhs.getData(), rhs.getData()));
  54. }
  55. template<typename LType, typename RType>
  56. ExprArray<SumArray<LType, RType>>operator +(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  57. return ExprArray<SumArray<LType, RType>>(SumArray<LType, RType>(lhs.getData(), rhs.getData()));
  58. }
  59. #endif //UNTITLED1__EXPRARRAY_HPP_

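最后,执行 ExprArray<SimpleArray> res = a + a * b; 时,由 ExprArray<SimpleArray> 的模板转换构造函数在一个循环里逐元素计算 a[i] + a[i] * b[i]:只遍历一次,也不产生中间数组。注意 res 的类型要显式写成 ExprArray<SimpleArray>。若写成 auto,res 只是一个尚未求值的表达式对象,它通过引用指向 a、b 等操作数,使用时需要注意这些对象的生命周期。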