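Reverse-Mode Differentiation Algorithm
Having covered the forward-mode implementation earlier, we now turn to the reverse-mode implementation.
The mathematical basis of reverse-mode differentiation is the chain rule.
Differentiation starts at the function's output and proceeds layer by layer back toward the inputs, yielding the gradient at every layer along the way.
For a scalar-valued function, a single backward pass therefore yields the gradients of all variables automatically.
For example, take the following function and differentiate it with the chain rule: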
$Y=F·G,F=a+b,G=b+1$
$\frac{\partial Y}{\partial a}$
$=\frac{\partial Y}{\partial F}·\frac{\partial F}{\partial a}+\frac{\partial Y}{\partial G}·\frac{\partial G}{\partial a}$
$=G + 0$
$=b + 1$
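The same chain-rule expansion can be written out for $b$, where both $F$ and $G$ depend on the variable, so both paths contribute. This is only a worked restatement of the example above, using the same symbols:
$\frac{\partial Y}{\partial b}$
$=\frac{\partial Y}{\partial F}·\frac{\partial F}{\partial b}+\frac{\partial Y}{\partial G}·\frac{\partial G}{\partial b}$
$=G·1 + F·1$
$=(b + 1) + (a + b)$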
Its syntax tree is as follows:

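Differentiating along the syntax tree, gradients along the same path are multiplied and gradients from different paths are added: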
$\mathrm{d}a = (b+1) \times 1,\mathrm{d}b=(b+1) \times 1 + (a+b) \times 1$
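As a sanity check, the hand-derived gradients can be compared against a numerical approximation. The sketch below is not part of the original implementation: the function names (`Y`, `dY_da`, `dY_db`, `numeric_dY_da`, `numeric_dY_db`) and the step size `h` are illustrative assumptions, and $\mathrm{d}a$, $\mathrm{d}b$ are read as the gradients of $Y$ with respect to $a$ and $b$.

```cpp
#include <cassert>
#include <cmath>

// Y = F * G with F = a + b and G = b + 1 (the example from the text).
double Y(double a, double b) {
    return (a + b) * (b + 1.0);
}

// Gradients read off the syntax tree: multiply along a path, add across paths.
double dY_da(double /*a*/, double b) { return (b + 1.0) * 1.0; }
double dY_db(double a, double b) { return (b + 1.0) * 1.0 + (a + b) * 1.0; }

// Central finite differences; h is an assumed step size for illustration.
double numeric_dY_da(double a, double b, double h = 1e-4) {
    assert(h > 0.0);
    return (Y(a + h, b) - Y(a - h, b)) / (2.0 * h);
}

double numeric_dY_db(double a, double b, double h = 1e-4) {
    assert(h > 0.0);
    return (Y(a, b + h) - Y(a, b - h)) / (2.0 * h);
}
```

Calling `numeric_dY_da(2.0, 3.0)` and `dY_da(2.0, 3.0)` from a `main` function should agree up to rounding error, because $Y$ is a quadratic polynomial and the central difference has no truncation error for it.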
The gradient of each kind of node, such as addition, subtraction, multiplication, and division, is very simple to compute.
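For reference, the local rules can be written out as follows. The symbols are introduced here for illustration only: $z$ is a node with operands $u$ and $v$, and $g=\frac{\partial Y}{\partial z}$ is the gradient arriving from the node's parent. The power rule is included because the listing further down also handles it.
$z=u+v:\ \frac{\partial Y}{\partial u}=g,\ \frac{\partial Y}{\partial v}=g$
$z=u-v:\ \frac{\partial Y}{\partial u}=g,\ \frac{\partial Y}{\partial v}=-g$
$z=u·v:\ \frac{\partial Y}{\partial u}=g·v,\ \frac{\partial Y}{\partial v}=g·u$
$z=u/v:\ \frac{\partial Y}{\partial u}=g/v,\ \frac{\partial Y}{\partial v}=-g·u/v^{2}$
$z=u^{v}:\ \frac{\partial Y}{\partial u}=g·v·u^{v-1},\ \frac{\partial Y}{\partial v}=g·u^{v}·\ln u$
When a variable is reached along several paths, these contributions are summed.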
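Once a complex expression has been converted into a binary syntax tree that contains only basic operations,
propagating gradients from the top down yields the gradient of every variable.
Unlike forward-mode differentiation, which can only compute $\mathrm{d}Y$, reverse-mode differentiation yields $\frac{\partial Y}{\partial a},\frac{\partial Y}{\partial b}$ individually.
Adding them together then gives $\mathrm{d}Y$.
This is why reverse mode is the more common choice in engineering practice.
That is the basic principle of reverse-mode differentiation.
Now let's write the code that builds the binary syntax tree from a function expression.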
```cpp
#include <cmath>
#include <iostream>
using namespace std;

// Node kinds of the expression tree.
typedef enum {
    CONST,
    VAR,
    ADD,
    SUB,
    MUL,
    DIV,
    POW,
} Operator;

class Qtype {
public:
    Operator oper = CONST;
    float val = 0;           // node value; interior nodes are filled in by Forward()
    float grad = 0;          // accumulated gradient (used by VAR nodes)
    Qtype* left = nullptr;   // left operand
    Qtype* right = nullptr;  // right operand

    Qtype() { oper = CONST; val = 0; }
    Qtype(Operator oper) { this->oper = oper; }
    Qtype(float a) { val = a; oper = CONST; }
    Qtype(float a, Operator oper) { val = a; this->oper = oper; }
    Qtype(float a, float grad, Operator oper) {
        val = a;
        this->grad = grad;
        this->oper = oper;
    }

    // Each arithmetic operator allocates a new interior node that points at
    // its two operands. The nodes are never freed, so this leaks memory and
    // is only suitable for a short demo.
    Qtype& operator+(Qtype& other) {
        Qtype* res = new Qtype(ADD);
        res->left = this;
        res->right = &other;
        return *res;
    }
    Qtype& operator-(Qtype& other) {
        Qtype* res = new Qtype(SUB);
        res->left = this;
        res->right = &other;
        return *res;
    }
    Qtype& operator*(Qtype& other) {
        Qtype* res = new Qtype(MUL);
        res->left = this;
        res->right = &other;
        return *res;
    }
    Qtype& operator/(Qtype& other) {
        Qtype* res = new Qtype(DIV);
        res->left = this;
        res->right = &other;
        return *res;
    }
    Qtype& operator=(Qtype& other) {
        val = other.val;
        grad = other.grad;
        oper = other.oper;
        left = other.left;
        right = other.right;
        return *this;
    }
    Qtype& operator^(Qtype& other) {
        Qtype* res = new Qtype(POW);
        res->left = this;
        res->right = &other;
        return *res;
    }
    // Turns this node into an ADD node whose left child is a copy of its old self.
    Qtype& operator+=(Qtype& other) {
        Qtype* copy = new Qtype(*this);
        this->oper = ADD;
        this->left = copy;
        this->right = &other;
        return *this;
    }

    // Backward pass: push the incoming gradient down to the children.
    void Process(float grad) {
        switch (oper) {
        case CONST:
            // Constants receive no gradient.
            break;
        case VAR:
            // Contributions from different paths are summed.
            this->grad += grad;
            break;
        case ADD:
            left->Process(grad);
            right->Process(grad);
            break;
        case SUB:
            left->Process(grad);
            right->Process(-grad);
            break;
        case MUL:
            left->Process(right->val * grad);
            right->Process(left->val * grad);
            break;
        case DIV:
            left->Process(1 / right->val * grad);
            right->Process(-left->val / (right->val * right->val) * grad);
            break;
        case POW:
            left->Process(right->val * pow(left->val, right->val - 1) * grad);
            right->Process(val * log(left->val) * grad);
            break;
        default:
            break;
        }
    }

    // Forward pass: evaluate the tree bottom-up and reset variable gradients.
    void Forward() {
        switch (oper) {
        case CONST:
            break;
        case VAR:
            grad = 0;
            break;
        case ADD:
            left->Forward();
            right->Forward();
            val = left->val + right->val;
            break;
        case SUB:
            left->Forward();
            right->Forward();
            val = left->val - right->val;
            break;
        case MUL:
            left->Forward();
            right->Forward();
            val = left->val * right->val;
            break;
        case DIV:
            left->Forward();
            right->Forward();
            val = left->val / right->val;
            break;
        case POW:
            left->Forward();
            right->Forward();
            val = pow(left->val, right->val);
            break;
        default:
            break;
        }
    }

    // Seed the output with a gradient of 1 and propagate it down the tree.
    void Backward() { Process(1); }

    // Print the expression in infix form; every variable is printed as "x".
    void PrintTree() {
        switch (oper) {
        case CONST:
            cout << val;
            break;
        case VAR:
            cout << "x";
            break;
        case ADD:
            cout << "(";
            left->PrintTree();
            cout << "+";
            right->PrintTree();
            cout << ")";
            break;
        case SUB:
            cout << "(";
            left->PrintTree();
            cout << "-";
            right->PrintTree();
            cout << ")";
            break;
        case MUL:
            cout << "(";
            left->PrintTree();
            cout << "*";
            right->PrintTree();
            cout << ")";
            break;
        case DIV:
            cout << "(";
            left->PrintTree();
            cout << "/";
            right->PrintTree();
            cout << ")";
            break;
        case POW:
            cout << "(";
            left->PrintTree();
            cout << "^";
            right->PrintTree();
            cout << ")";
            break;
        default:
            break;
        }
    }
};
```
The code is about as complex as the forward-mode version.
It likewise builds the binary syntax tree automatically by overloading operators.
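To isolate the operator-overloading idea, here is a minimal sketch that supports only addition and multiplication. It is not the `Qtype` class from the listing: `Node`, `Expr`, `variable`, `constant`, and `backward` are hypothetical names chosen for illustration, nodes are held by `std::shared_ptr` so that they are freed automatically, and values are computed eagerly when a node is built rather than in a separate forward pass.

```cpp
#include <cassert>
#include <memory>

// Hypothetical node type for illustration (not the Qtype class above).
struct Node {
    enum class Op { Const, Var, Add, Mul };
    Op op = Op::Const;
    double val = 0.0;   // value, computed when the node is created
    double grad = 0.0;  // accumulated gradient (used by Var nodes)
    std::shared_ptr<Node> left, right;
};

// Thin handle so the operators are overloaded on our own type.
struct Expr {
    std::shared_ptr<Node> node;
};

Expr make_node(Node::Op op, double val, const Expr* l = nullptr, const Expr* r = nullptr) {
    auto n = std::make_shared<Node>();
    n->op = op;
    n->val = val;
    if (l) n->left = l->node;
    if (r) n->right = r->node;
    return Expr{n};
}

Expr variable(double v) { return make_node(Node::Op::Var, v); }
Expr constant(double v) { return make_node(Node::Op::Const, v); }

// Writing (a + b) * (b + k) calls these overloads, which build the tree as a side effect.
Expr operator+(const Expr& l, const Expr& r) {
    return make_node(Node::Op::Add, l.node->val + r.node->val, &l, &r);
}
Expr operator*(const Expr& l, const Expr& r) {
    return make_node(Node::Op::Mul, l.node->val * r.node->val, &l, &r);
}

// Top-down gradient propagation: multiply along a path, add across paths.
void backward(const std::shared_ptr<Node>& n, double g) {
    switch (n->op) {
    case Node::Op::Const:
        break;
    case Node::Op::Var:
        n->grad += g;
        break;
    case Node::Op::Add:
        backward(n->left, g);
        backward(n->right, g);
        break;
    case Node::Op::Mul:
        backward(n->left, g * n->right->val);
        backward(n->right, g * n->left->val);
        break;
    }
}

void backward(const Expr& e) { backward(e.node, 1.0); }
```

With `a = variable(2.0)`, `b = variable(3.0)`, and `k = constant(1.0)`, building `(a + b) * (b + k)` and calling `backward` on the result accumulates the gradient of `b` from both subtrees, matching the rule that different paths are added. Shared ownership is a design choice of this sketch; the listing above uses raw pointers instead.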
The test code is as follows:
```cpp
int main() {
    Qtype a(0, VAR), b(0, VAR), k(1);
    Qtype Y = (a + b) * (b + k);
    Y.Forward();
    Y.Backward();
    Y.PrintTree();
    cout << endl;
    cout << a.grad << endl;
    cout << b.grad << endl;
}
```
The output is as follows: