SimpleArray
首先实现两个数组逐元素相加、相乘的功能:
#ifndef UNTITLED1__SIMPLEARRAY_HPP_
#define UNTITLED1__SIMPLEARRAY_HPP_
#include <vector>
#include <iostream>
// Fixed-size array of doubles with *eager* element-wise arithmetic: every
// `+` / `*` immediately loops over all elements and returns a new array.
class SimpleArray {
  // Element storage; the friend operators below read it directly.
  std::vector<double> data;

  friend SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs);
  friend SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs);

public:
  // Creates an array holding `size` elements.
  explicit SimpleArray(size_t size);

  // Number of elements.
  size_t size() const;

  // Read-only element access (returns the value).
  double operator[](size_t index) const;
  // Mutable element access (returns a reference to the element).
  double &operator[](size_t index);
};
#endif //UNTITLED1__SIMPLEARRAY_HPP_
#include "SimpleArray.hpp"
#include <cassert>
#include <random>
// Allocates `size` elements, every one of them equal to 0.0.
SimpleArray::SimpleArray(size_t size) : data(size, 0.0) {}
// Mutable element access.
// An out-of-range index is a programmer error. It is caught by the assert in
// debug builds (the same policy operator+/operator* use for size mismatches)
// instead of silently reading or writing past the end of the vector.
double &SimpleArray::operator[](size_t index) {
  assert(index < data.size());
  return data[index];
}
// Read-only element access; returns the element by value.
// An out-of-range index is a programmer error, asserted in debug builds
// (consistent with the size assert in operator+/operator*).
double SimpleArray::operator[](size_t index) const {
  assert(index < data.size());
  return data[index];
}
// Number of elements held by the array.
auto SimpleArray::size() const -> size_t { return data.size(); }
// Eager element-wise sum: allocates a new array and fills it in one pass.
// Both operands must have the same length (checked in debug builds).
SimpleArray operator+(const SimpleArray &lhs, const SimpleArray &rhs) {
  const size_t n = lhs.data.size();
  assert(n == rhs.data.size());
  SimpleArray sum(n);
  for (size_t k = 0; k != n; ++k) {
    sum.data[k] = lhs.data[k] + rhs.data[k];
  }
  return sum;
}
// Eager element-wise product: allocates a new array and fills it in one pass.
// Both operands must have the same length (checked in debug builds).
SimpleArray operator*(const SimpleArray &lhs, const SimpleArray &rhs) {
  const size_t n = lhs.data.size();
  assert(n == rhs.data.size());
  SimpleArray product(n);
  for (size_t k = 0; k != n; ++k) {
    product.data[k] = lhs.data[k] * rhs.data[k];
  }
  return product;
}
int main() {
size_t size = 10;
SimpleArray simple_arrayA(size);
SimpleArray simple_arrayB(size);
for (size_t i = 0; i < size; ++i) {
simple_arrayA[i] = drand48();
simple_arrayB[i] = drand48();
}
SimpleArray simple_arrayC = simple_arrayA + simple_arrayB * simple_arrayA;
return 0;
}
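实际上,这里把数组完整遍历了两遍(乘法一遍、加法一遍,共 2n 次循环迭代),并且最多产生了 3 个中间对象:乘法函数内的 result、乘法的返回临时对象、加法函数内的 result(编译器做返回值优化时可以只剩乘法结果这 1 个)。而我们完全可以只用一次循环,对每个下标 i 直接计算 `a[i] + b[i] * a[i]`:只需遍历一遍(n 次迭代),也不需要任何中间数组。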
惰性计算
上面的问题在于计算得太着急了:在还不知道表达式全貌的情况下就开始计算。正确的做法是等到表达式被赋值的那一步再统一计算;在此之前不做真正的计算,而是先把表达式本身存储起来。
ExprArray 与 SimpleArray 的接口保持一致:构造函数和下标操作符全部转发给内部的 data 成员(其类型默认为 SimpleArray)。
// First sketch of ExprArray: for now it only wraps an array (SimpleArray by
// default) and forwards construction, subscripting and size() to it.
template<typename ArrayType=SimpleArray>
class ExprArray {
  ArrayType data;  // the wrapped array; every call below is forwarded to it

public:
  explicit ExprArray(size_t size) : data(size) {}

  size_t size() const { return data.size(); }

  // Read-only access returns the element by value.
  double operator[](size_t index) const { return data[index]; }
  // Mutable access returns a reference into the wrapped array.
  double &operator[](size_t index) { return data[index]; }
};
使用方式也与 SimpleArray 完全一致:先构造 ExprArray,ExprArray 在内部构造 SimpleArray(进而构造 vector),下标操作同理。但目前还没有为 ExprArray 实现乘法和加法,所以下面的代码会编译报错。
int main() {
size_t size = 10;
ExprArray a(size);
ExprArray b(size);
for (size_t i = 0; i < size; ++i) {
a[i] = drand48();
b[i] = drand48();
}
ExprArray c = a + b * a;
return 0;
}
表达式模板
我们需要一个结构来存储表达式。先考虑最简单的情况:LType 和 RType 都是 SimpleArray,最终结果通过 operator[] 按下标逐个获取。
SumArray 记录了运算的种类(由类型本身表示为加法)、左右操作数的类型,以及对左右操作数对象的引用。注意它只保存引用,所以 SumArray 对象不能比它所引用的操作数活得更久。
// Lazy element-wise sum node: remembers both operands and computes
// lhs[i] + rhs[i] only when element i is requested.
// Only references are stored, so a SumArray must not outlive either operand.
template<typename LType, typename RType>
class SumArray {
  const LType &left_;   // left operand (not owned)
  const RType &right_;  // right operand (not owned)

public:
  // Binds both operands; debug builds assert that their lengths match.
  SumArray(const LType &lhs, const RType &rhs) : left_(lhs), right_(rhs) {
    assert(lhs.size() == rhs.size());
  }

  // Length of the expression, taken from the left operand.
  size_t size() const { return left_.size(); }

  // Evaluates element `index` on demand.
  double operator[](size_t index) const {
    return left_[index] + right_[index];
  }
};
// Builds, without evaluating anything, the lazy sum of two ExprArrays.
// The returned node refers to the operands' data, so it has to be consumed
// (e.g. assigned to an ExprArray<SimpleArray>) while the operands are alive.
template<typename LType, typename RType>
ExprArray<SumArray<LType, RType>>
operator +(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  using Node = SumArray<LType, RType>;
  const Node node(lhs.getData(), rhs.getData());
  return ExprArray<Node>(node);
}
首先通过形参 `const ExprArray<LType> &lhs, const ExprArray<RType> &rhs`
把两个操作数的类型从 ExprArray 中剥离出来(最简单的情况下都是 SimpleArray)。
然后再用构造好的 SumArray 对象初始化 ExprArray 作为返回值。这里用到的 getData() 和以 ArrayType 为参数的构造函数,会在下面完整的 ExprArray.hpp 中补上。
#ifndef UNTITLED1__EXPRARRAY_HPP_
#define UNTITLED1__EXPRARRAY_HPP_
#include "SimpleArray.hpp"
#include <cassert>
// Wrapper that either owns a concrete array (e.g. ExprArray<SimpleArray>) or
// carries an unevaluated expression node (e.g. ExprArray<SumArray<...>>).
// The expression is only evaluated when it is converted into another
// ExprArray through the templated converting constructor.
template<typename ArrayType>
class ExprArray {
public:
  // Constructs the wrapped array from a size; only usable when ArrayType is
  // constructible from a size (as SimpleArray is).
  explicit ExprArray(size_t size) : data(size) {}

  // Wraps an existing array or expression node by copying it into `data`.
  explicit ExprArray(const ArrayType &array) : data(array) {}

  // Evaluates another ExprArray (typically an expression such as
  // `a + a * b`) element by element in a single loop.
  // Deliberately non-explicit so that
  // `ExprArray<SimpleArray> res = a + a * b;` compiles.
  template<typename OtherType>
  ExprArray(const ExprArray<OtherType> &array) : data(array.size()) {
    const size_t n = array.size();  // loop-invariant
    for (size_t i = 0; i < n; ++i) {
      data[i] = array[i];
    }
  }

  // Element access on a non-const wrapper.
  // decltype(auto) yields `double &` when the wrapped type's subscript
  // returns a reference (SimpleArray), and plain `double` for read-only
  // expression nodes (SumArray, ProductArray), whose only subscript is a
  // const member returning by value. A hard-coded `double &` return type
  // fails to compile for those nodes.
  decltype(auto) operator[](size_t index) { return data[index]; }

  // Read-only element access; evaluates expression nodes on demand.
  double operator[](size_t index) const { return data[index]; }

  size_t size() const { return data.size(); }

  // The wrapped array or expression node; operator+ / operator* build new
  // nodes that refer to it.
  const ArrayType &getData() const { return data; }

private:
  ArrayType data;
};
// Lazy element-wise product node: remembers both operands and computes
// lhs[i] * rhs[i] only when element i is requested.
// Only references are stored, so a ProductArray must not outlive either
// operand.
template<typename LType, typename RType>
class ProductArray {
  const LType &left_;   // left operand (not owned)
  const RType &right_;  // right operand (not owned)

public:
  // Binds both operands; debug builds assert that their lengths match.
  ProductArray(const LType &lhs, const RType &rhs) : left_(lhs), right_(rhs) {
    assert(lhs.size() == rhs.size());
  }

  // Length of the expression, taken from the left operand.
  size_t size() const { return left_.size(); }

  // Evaluates element `index` on demand.
  double operator[](size_t index) const {
    return left_[index] * right_[index];
  }
};
// Lazy element-wise sum node: remembers both operands and computes
// lhs[i] + rhs[i] only when element i is requested.
// Only references are stored, so a SumArray must not outlive either operand.
// For example, in `a + b * a` the right operand is the data of the temporary
// returned by operator*, which lives only until the end of the full
// expression.
template<typename LType, typename RType>
class SumArray {
  const LType &left_;   // left operand (not owned)
  const RType &right_;  // right operand (not owned)

public:
  // Binds both operands; debug builds assert that their lengths match.
  SumArray(const LType &lhs, const RType &rhs) : left_(lhs), right_(rhs) {
    assert(lhs.size() == rhs.size());
  }

  // Length of the expression, taken from the left operand.
  size_t size() const { return left_.size(); }

  // Evaluates element `index` on demand.
  double operator[](size_t index) const {
    return left_[index] + right_[index];
  }
};
// Builds, without evaluating anything, the lazy product of two ExprArrays.
// The returned node refers to the operands' data, so it has to be consumed
// (e.g. assigned to an ExprArray<SimpleArray>) while the operands are alive.
template<typename LType, typename RType>
ExprArray<ProductArray<LType, RType>>operator *(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  using Node = ProductArray<LType, RType>;
  const Node node(lhs.getData(), rhs.getData());
  return ExprArray<Node>(node);
}
// Builds, without evaluating anything, the lazy sum of two ExprArrays.
// The returned node refers to the operands' data, so it has to be consumed
// (e.g. assigned to an ExprArray<SimpleArray>) while the operands are alive.
template<typename LType, typename RType>
ExprArray<SumArray<LType, RType>>operator +(const ExprArray<LType> &lhs, const ExprArray<RType> &rhs) {
  using Node = SumArray<LType, RType>;
  const Node node(lhs.getData(), rhs.getData());
  return ExprArray<Node>(node);
}
#endif //UNTITLED1__EXPRARRAY_HPP_
最后,在 `ExprArray<SimpleArray> res = a + a * b;` 这一步,ExprArray 的模板转换构造函数用一次循环逐个下标求值,一次性完成整个表达式的计算。