SimpleArray
首先实现两个数组逐元素相加、相乘的功能。
#ifndef UNTITLED1_SIMPLEARRAY_HPP_
#define UNTITLED1_SIMPLEARRAY_HPP_

// Guard renamed: identifiers containing "__" are reserved for the implementation.

#include <cstddef>   // std::size_t
#include <iostream>  // kept as in the original; code including this header may rely on it
#include <vector>

/// A fixed-size array of doubles backed by std::vector.
/// Its arithmetic operators evaluate eagerly: every operator call produces a
/// brand-new SimpleArray holding the complete result.
class SimpleArray {
 public:
  /// Creates an array holding `size` elements.
  explicit SimpleArray(std::size_t size);

  /// Mutable element access; `index` must be < size() (not bounds-checked).
  double &operator[](std::size_t index);

  /// Read-only element access; `index` must be < size() (not bounds-checked).
  double operator[](std::size_t index) const;

  /// Number of elements.
  std::size_t size() const;

 private:
  std::vector<double> data;

  /// Element-wise sum / product. Both operands must have the same size
  /// (checked with assert in the definitions).
  friend SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs);
  friend SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs);
};

#endif  // UNTITLED1_SIMPLEARRAY_HPP_
#include <cassert>#include "SimpleArray.hpp"SimpleArray::SimpleArray(size_t size): data(size) {}double &SimpleArray::operator[](size_t index) {return data[index];}double SimpleArray::operator[](size_t index) const {return data[index];}size_t SimpleArray::size() const {return data.size();}SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs) {assert(lhs.data.size() == rhs.data.size());SimpleArray result(lhs.data.size());for (size_t i = 0; i < lhs.data.size(); ++i) {result[i] = lhs[i] + rhs[i];}return result;}SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs) {assert(lhs.data.size() == rhs.data.size());SimpleArray result(lhs.data.size());for (size_t i = 0; i < lhs.data.size(); ++i) {result[i] = lhs[i] * rhs[i];}return result;}
int main() {
  size_t count = 10;
  SimpleArray simple_arrayA(count);
  SimpleArray simple_arrayB(count);

  // Fill both inputs with values from drand48(), A before B at every index.
  for (size_t idx = 0; idx != count; ++idx) {
    simple_arrayA[idx] = drand48();
    simple_arrayB[idx] = drand48();
  }

  // Eager evaluation: B * A is materialized as a whole array first,
  // and only then is A + (B * A) computed in a second pass.
  SimpleArray simple_arrayC = simple_arrayA + simple_arrayB * simple_arrayA;
  return 0;
}
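实际上这里把数组遍历了两遍（共 2n 次循环迭代），最多会产生 3 个临时 SimpleArray 对象（具体数量取决于编译器做了多少返回值优化），其中 `B * A` 的中间结果必须完整地存进一个临时数组。但我们完全可以在一次循环中直接算出每个元素 `A[i] + B[i] * A[i]`，只需 n 次迭代，也不需要任何中间数组。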
惰性计算
上面的问题在于计算得太着急了：在还不知道表达式全貌的时候就开始计算。正确的做法是等到整个表达式写完、执行赋值的那一步再去计算；在此之前并不真正计算，而是先把表达式存储起来。
先定义一个包装类 ExprArray：和 SimpleArray 一样，它把所有构造函数和下标运算符都转发给内部成员 data（data 的类型默认为 SimpleArray）。
/// First version of ExprArray: a thin wrapper that forwards construction and
/// element access to the wrapped array type (SimpleArray by default).
/// Arithmetic operators are not defined yet.
template<typename ArrayType = SimpleArray>
class ExprArray {
 public:
  /// Constructs the wrapped array with `size` elements.
  explicit ExprArray(size_t size) : data(size) {}

  /// Mutable element access, forwarded to the wrapped array.
  double &operator[](size_t index) { return data[index]; }

  /// Read-only element access, forwarded to the wrapped array.
  double operator[](size_t index) const { return data[index]; }

  /// Number of elements in the wrapped array.
  size_t size() const { return data.size(); }

  /// Read-only access to the wrapped array. The operator+ that builds
  /// SumArray nodes reads its operands through this accessor.
  const ArrayType &getData() const { return data; }

 private:
  ArrayType data;
};
使用方式与 SimpleArray 完全一致：先构造 ExprArray，ExprArray 再在内部构造它包装的数组（默认是 SimpleArray，其内部是一个 vector），下标操作同理。但目前还没有实现乘法和加法运算符，所以下面的代码会编译报错。
int main() {size_t size = 10;ExprArray a(size);ExprArray b(size);for (size_t i = 0; i < size; ++i) {a[i] = drand48();b[i] = drand48();}ExprArray c = a + b * a;return 0;}
表达式模板
我们需要一个结构来存储表达式。先考虑最简单的情况：LType 和 RType 都是 SimpleArray，通过 operator[] 逐元素获取最终结果。
SumArray 记录了三样东西：运算本身（加法，由类型 SumArray 体现）、左右两侧操作数的类型（模板参数 LType、RType），以及左右两侧的操作数。
template<typename LType, typename RType>
class SumArray;

/// Chooses how an expression node stores one operand.
/// - Leaf arrays (the primary template, e.g. SimpleArray) are held by const
///   reference, so their elements are never copied.
/// - Expression nodes are held by value. A node is owned by the temporary
///   ExprArray that operator+ returns. If it were held by reference, the outer
///   expression (e.g. `auto e = a + b + c;`) would dangle once that temporary
///   is destroyed at the end of the full expression. Nodes only contain
///   references/small nodes, so copying them is cheap.
/// A leaf that is itself a temporary is still held by reference. Such an
/// expression must be evaluated within the same full expression.
template<typename T>
struct ExprOperand {
  using type = const T &;
};

template<typename LType, typename RType>
struct ExprOperand<SumArray<LType, RType>> {
  using type = SumArray<LType, RType>;
};

/// Lazy element-wise sum: element `i` is computed only when operator[] is called.
template<typename LType, typename RType>
class SumArray {
 public:
  /// Records both operands; they must have the same size (asserted).
  SumArray(const LType &lhs, const RType &rhs) : lhs(lhs), rhs(rhs) {
    assert(lhs.size() == rhs.size());
  }

  /// Computes lhs[index] + rhs[index] on demand.
  double operator[](size_t index) const { return lhs[index] + rhs[index]; }

  size_t size() const { return lhs.size(); }

 private:
  typename ExprOperand<LType>::type lhs;
  typename ExprOperand<RType>::type rhs;
};

/// Builds (but does not evaluate) the expression `lhs + rhs`.
/// LType / RType are deduced from the wrapped types of the two ExprArrays.
template<typename LType, typename RType>
ExprArray<SumArray<LType, RType>>
operator+(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  return ExprArray<SumArray<LType, RType>>(SumArray<LType, RType>(lhs.getData(), rhs.getData()));
}
首先通过参数类型 `const ExprArray<LType> &lhs, const ExprArray<RType> &rhs`，让编译器把两个内部类型 LType、RType 从 ExprArray 中推导（剥离）出来（最简单的情况下二者都是 SimpleArray）。
返回值则用构造出的 SumArray 对象去初始化一个 `ExprArray<SumArray<LType, RType>>`。此时只是记录了表达式，并没有进行任何实际的加法计算。
#ifndef UNTITLED1__EXPRARRAY_HPP_#define UNTITLED1__EXPRARRAY_HPP_#include "SimpleArray.hpp"#include <cassert>template<typename ArrayType>class ExprArray {public:explicit ExprArray(size_t size): data(size) {};explicit ExprArray(const ArrayType &array): data(array) {};template<typename OtherType>ExprArray(const ExprArray<OtherType> &array): data(array.size()) {for (size_t i = 0; i < array.size(); ++i) {data[i] = array[i];}};double &operator[](size_t index) {return data[index];}double operator[](size_t index) const {return data[index];}size_t size() const { return data.size(); }const ArrayType &getData() const { return data; }private:ArrayType data;};template<typename LType, typename RType>class ProductArray {public:ProductArray(const LType &lhs, const RType &rhs): lhs(lhs), rhs(rhs) {assert(lhs.size() == rhs.size());};double operator[](size_t index) const {return lhs[index] * rhs[index];}size_t size() const { return lhs.size(); }private:const LType &lhs;const RType &rhs;};template<typename LType, typename RType>class SumArray {public:SumArray(const LType &lhs, const RType &rhs): lhs(lhs), rhs(rhs) {assert(lhs.size() == rhs.size());};double operator[](size_t index) const {return lhs[index] + rhs[index];}size_t size() const { return lhs.size(); }private:const LType &lhs;const RType &rhs;};template<typename LType, typename RType>ExprArray<ProductArray<LType, RType>>operator *(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {return ExprArray<ProductArray<LType, RType>>(ProductArray<LType, RType>(lhs.getData(), rhs.getData()));}template<typename LType, typename RType>ExprArray<SumArray<LType, RType>>operator +(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {return ExprArray<SumArray<LType, RType>>(SumArray<LType, RType>(lhs.getData(), rhs.getData()));}#endif //UNTITLED1__EXPRARRAY_HPP_
最后，在 `ExprArray<SimpleArray> res = a + a * b;` 中，ExprArray<SimpleArray> 的转换构造函数只用一次循环，逐元素对整个表达式求值并写入 res。
