tastynoob
Articles58
Tags18
Categories7
反向求导算法

反向求导算法

之前讲了前向求导的实现,现在就来讲一下反向求导的实现

反向求导的数学基础是链式法则

即求导的过程是从函数输出层开始,逐层向前求导,直到输入层,这样就可以求出每一层的梯度

对于标量求导,即可以自动求出所有变量的梯度

比如对于函数,利用链式法则求导

$Y=F·G,F=a+b,G=b+1$
$\frac{\partial Y}{\partial a}$
$=\frac{\partial Y}{\partial F}·\frac{\partial F}{\partial a}+\frac{\partial Y}{\partial G}·\frac{\partial G}{\partial a}$
$=G + 0$
$=b + 1$

其语法树如下:

根据语法树求导,同路径下梯度相乘,不同路径梯度相加

$\mathrm{d}a = (b+1) \times 1,\mathrm{d}b=(b+1) \times 1 + (a+b) \times 1$

而不同的节点,比如加减乘除,其梯度的计算十分简单

通过把一个复杂的表达式转换为一棵只包含基本运算的语法二叉树

从上往下进行梯度迭代即可计算出每个变量的梯度

这与前向求导只能计算$\mathrm{d}Y$不同,反向求导可以算出$\frac{\partial Y}{\partial a},\frac{\partial Y}{\partial b}$

相加即可算出 $\mathrm{d}Y$

因此在工程应用时也是反向求导居多

以上就是反向求导的基本原理

现在我们来编写代码,实现根据函数表达式构造语法二叉树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#include <cmath>
#include <iostream>
#include <math.h>
using namespace std;

/// Kind of a node in the expression tree built by Qtype's operators.
enum Operator {
    CONST,  ///< Leaf with a fixed value; ignores incoming gradient.
    VAR,    ///< Leaf variable; accumulates incoming gradient in `grad`.
    ADD,    ///< left + right
    SUB,    ///< left - right
    MUL,    ///< left * right
    DIV,    ///< left / right
    POW,    ///< left ^ right (left raised to the power right)
};

/// Node of a binary expression tree supporting reverse-mode differentiation.
///
/// Arithmetic operators do not compute anything; they allocate a new node
/// whose children point at the operands. Call Forward() on the root to
/// evaluate `val` for every node, then Backward() to accumulate dRoot/dLeaf
/// into the `grad` field of every VAR leaf.
///
/// Lifetime notes:
///  - Children are stored as raw, non-owning pointers to the operands, so
///    every operand must outlive the expression that refers to it.
///  - Nodes allocated by the operators are never freed by this class.
///  - Copying a node is shallow: the copy shares the same children.
class Qtype {
public:
    Operator oper = CONST;
    // Defaulted so an operator node never exposes an indeterminate value
    // before the first Forward().
    float val = 0;
    float grad = 0;
    Qtype* left = nullptr;
    Qtype* right = nullptr;

    /// Constant leaf with value 0.
    Qtype() = default;
    /// Node of the given kind with value 0 and no children.
    Qtype(Operator oper) : oper(oper) {}
    /// Constant leaf with value `a`.
    Qtype(float a) : val(a) {}
    /// Leaf of kind `oper` (normally CONST or VAR) with value `a`.
    Qtype(float a, Operator oper) : oper(oper), val(a) {}
    /// Leaf of kind `oper` with value `a` and an initial gradient.
    Qtype(float a, float grad, Operator oper) : oper(oper), val(a), grad(grad) {}

    /// Shallow copy: kind, value, gradient and child pointers.
    Qtype(const Qtype&) = default;
    /// Shallow assignment; takes a const reference so const objects and
    /// temporaries can be assigned from.
    Qtype& operator=(const Qtype& other) = default;

    /// Each operator returns a reference to a newly allocated node whose
    /// children are `this` and `other`.
    Qtype& operator+(Qtype& other) { return MakeNode(ADD, this, &other); }
    Qtype& operator-(Qtype& other) { return MakeNode(SUB, this, &other); }
    Qtype& operator*(Qtype& other) { return MakeNode(MUL, this, &other); }
    Qtype& operator/(Qtype& other) { return MakeNode(DIV, this, &other); }
    Qtype& operator^(Qtype& other) { return MakeNode(POW, this, &other); }

    /// Turns this node, in place, into `old_self + other`.
    ///
    /// The previous contents are moved into a heap copy that becomes the left
    /// child. If this node was a VAR, that copy is the variable from now on,
    /// so gradients are accumulated in the copy, not in this object.
    Qtype& operator+=(Qtype& other) {
        Qtype* snapshot = new Qtype(*this);  // keep the old contents as a child
        oper = ADD;
        left = snapshot;
        // `x += x` must not make the node its own child: that would be a
        // cycle and every traversal would recurse forever.
        right = (&other == this) ? snapshot : &other;
        return *this;
    }

    /// Reverse pass: pushes `grad` (d root / d this node) down to the children
    /// using the local derivative of each operation. Relies on `val` being up
    /// to date, i.e. Forward() must have run first.
    void Process(float grad) {
        switch (oper) {
        case CONST:
            break;  // constants take no gradient
        case VAR:
            this->grad += grad;  // contributions from different paths add up
            break;
        case ADD:
            left->Process(grad);
            right->Process(grad);
            break;
        case SUB:
            left->Process(grad);
            right->Process(-grad);
            break;
        case MUL:
            left->Process(right->val * grad);
            right->Process(left->val * grad);
            break;
        case DIV:
            left->Process(1 / right->val * grad);
            right->Process(-left->val / (right->val * right->val) * grad);
            break;
        case POW:
            // d(l^r)/dl = r * l^(r-1)
            left->Process(right->val * std::pow(left->val, right->val - 1) * grad);
            // d(l^r)/dr = l^r * ln(l). A constant exponent discards the
            // value anyway, so skip it and avoid calling log() on a
            // non-positive base.
            if (right->oper != CONST) {
                right->Process(val * std::log(left->val) * grad);
            }
            break;
        default:
            break;
        }
    }

    /// Forward pass: recomputes `val` bottom-up and resets the gradient of
    /// every VAR leaf that is reached to 0.
    void Forward() {
        switch (oper) {
        case CONST:
            return;
        case VAR:
            grad = 0;
            return;
        case ADD:
        case SUB:
        case MUL:
        case DIV:
        case POW:
            break;
        default:
            return;
        }

        left->Forward();
        right->Forward();
        const float lhs = left->val;
        const float rhs = right->val;
        switch (oper) {
        case ADD: val = lhs + rhs; break;
        case SUB: val = lhs - rhs; break;
        case MUL: val = lhs * rhs; break;
        case DIV: val = lhs / rhs; break;
        case POW: val = std::pow(lhs, rhs); break;
        default: break;
        }
    }

    /// Starts the reverse pass at this node with a seed gradient of 1.
    /// Gradients accumulate, so call Forward() before each Backward().
    void Backward() {
        Process(1);
    }

    /// Writes the expression to std::cout, fully parenthesised. Constants
    /// print their value; every variable prints as "x".
    void PrintTree() const {
        switch (oper) {
        case CONST: std::cout << val; break;
        case VAR:   std::cout << "x"; break;
        case ADD:   PrintBinary("+"); break;
        case SUB:   PrintBinary("-"); break;
        case MUL:   PrintBinary("*"); break;
        case DIV:   PrintBinary("/"); break;
        case POW:   PrintBinary("^"); break;
        default:    break;
        }
    }

private:
    /// Allocates an operator node of kind `op` with the given children.
    static Qtype& MakeNode(Operator op, Qtype* lhs, Qtype* rhs) {
        Qtype* node = new Qtype(op);
        node->left = lhs;
        node->right = rhs;
        return *node;
    }

    /// Prints "(" left symbol right ")".
    void PrintBinary(const char* symbol) const {
        std::cout << "(";
        left->PrintTree();
        std::cout << symbol;
        right->PrintTree();
        std::cout << ")";
    }
};

代码的复杂度与前向求导实际上差不多
同样是通过重载运算符来实现自动构造语法二叉树

测试代码如下

1
2
3
4
5
6
7
8
9
10
int main() {
Qtype a(0, VAR), b(0, VAR),k(1);
Qtype Y = (a + b) * (b + k);
Y.Forward();//先前向传播一次计算各个节点的值,再反向传播计算梯度
Y.Backward();
Y.PrintTree();//打印语法树
cout << endl;
cout << a.grad << endl;//打印a的梯度
cout << b.grad << endl;//打印b的梯度
}

输出如下

1
2
3
((x+x)*(x+1))
1
1
Author:tastynoob
Link:https://tastynoob.github.io/1970/01/01/%E7%AE%97%E6%B3%95/%E5%8F%8D%E5%90%91%E6%B1%82%E5%AF%BC%E7%AE%97%E6%B3%95/
版权声明:本文采用 CC BY-NC-SA 3.0 CN 协议进行许可
×