Commit 50398c6

294
1 parent ab94f12 commit 50398c6

File tree

3 files changed: +277 −40 lines changed

readme.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -6,7 +6,7 @@
 
 前端界的好文精读,每周更新!
 
-最新精读:<a href="./机器学习/293.%E5%AE%9E%E7%8E%B0%E4%B8%87%E8%83%BD%E8%BF%91%E4%BC%BC%E5%87%BD%E6%95%B0%3A%20%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E6%9E%B6%E6%9E%84%E8%AE%BE%E8%AE%A1.md">293.实现万能近似函数: 神经网络的架构设计</a>
+最新精读:<a href="./机器学习/294.%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%3A%20%E6%8F%AD%E7%A7%98%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%AD%A6%E4%B9%A0%E6%9C%BA%E5%88%B6.md">294.反向传播: 揭秘神经网络的学习机制</a>
 
 素材来源:[周刊参考池](https://github.com/ascoders/weekly/issues/2)
 
@@ -338,6 +338,7 @@
 - <a href="./机器学习/291.%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%80%E4%BB%8B%3A%20%E5%AF%BB%E6%89%BE%E5%87%BD%E6%95%B0%E7%9A%84%E8%89%BA%E6%9C%AF.md">291.机器学习简介: 寻找函数的艺术</a>
 - <a href="./机器学习/292.%E4%B8%87%E8%83%BD%E8%BF%91%E4%BC%BC%E5%AE%9A%E7%90%86%3A%20%E9%80%BC%E8%BF%91%E4%BB%BB%E4%BD%95%E5%87%BD%E6%95%B0%E7%9A%84%E7%90%86%E8%AE%BA.md">292.万能近似定理: 逼近任何函数的理论</a>
 - <a href="./机器学习/293.%E5%AE%9E%E7%8E%B0%E4%B8%87%E8%83%BD%E8%BF%91%E4%BC%BC%E5%87%BD%E6%95%B0%3A%20%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E6%9E%B6%E6%9E%84%E8%AE%BE%E8%AE%A1.md">293.实现万能近似函数: 神经网络的架构设计</a>
+- <a href="./机器学习/294.%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%3A%20%E6%8F%AD%E7%A7%98%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%AD%A6%E4%B9%A0%E6%9C%BA%E5%88%B6.md">294.反向传播: 揭秘神经网络的学习机制</a>
 
 ### 生活
````

机器学习/293.实现万能近似函数: 神经网络的架构设计.md

Lines changed: 56 additions & 39 deletions
````diff
@@ -86,31 +86,32 @@ neuralNetwork.fit(); // { loss: .. }
 
 - w: `number[]`, the coefficient w that each node in the previous layer is multiplied by when connecting to this node.
 - b: `number`, the constant coefficient b of this node.
-- c: `number`, the constant coefficient c of this node.
+- c: `number`, the constant coefficient c of this node (this parameter can also be omitted).
 
 We can define the neural network's data structure as follows:
 
 ```ts
+/** Neural network structure data */
 type NetworkStructor = Array<{
   // Activation function type
-  activation: "sigmoid" | "relu";
+  activation: ActivationType;
   // Nodes
-  neurals: Array<{
-    /** Current value of this node */
-    value: number | undefined;
-    /** Coefficient w applied to each previous-layer node connecting to this node */
-    w: Array<number>;
-    /** Constant coefficient b of this node */
-    b: number;
-    /** Constant coefficient c of this node */
-    c: number;
-  }>;
+  neurals: Neural[];
 }>;
+
+interface Neural {
+  /** Current value of this node */
+  value: number;
+  /** Coefficient w applied to each previous-layer node connecting to this node */
+  w: Array<number>;
+  /** Constant coefficient b of this node */
+  b: number;
+}
 ```
 
 We then initialize the neural network object from the user-supplied `layers`, giving every parameter an initial value:
 
-```ts
+```js
 class NeuralNetwork {
   // Input length
   private inputCount = 0;
````
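
Taken together, this hunk removes the per-node coefficient `c` and extracts the inline node type into a standalone `Neural` interface. For reference, a minimal sketch of the resulting types; `ActivationType` is assumed to be the `"sigmoid" | "relu"` union that the removed inline annotation spelled out, since its actual definition is not part of this diff:

```ts
/** Assumed to match the union the removed inline annotation used */
type ActivationType = "sigmoid" | "relu";

/** Neural network structure data */
type NetworkStructor = Array<{
  /** Activation function type of this layer */
  activation: ActivationType;
  /** Nodes in this layer */
  neurals: Neural[];
}>;

interface Neural {
  /** Current value of this node */
  value: number;
  /** One weight per node of the previous layer */
  w: Array<number>;
  /** Constant bias b of this node */
  b: number;
}
```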
````diff
@@ -122,23 +123,28 @@ class NeuralNetwork {
   constructor({
     trainingData,
     layers,
+    trainingCount,
   }: {
     trainingData: TraningData;
     layers: Layer[];
+    trainingCount: number;
   }) {
     this.trainingData = trainingData;
     this.inputCount = layers[0].inputCount!;
-    this.networkStructor = layers.map(({ activation, count }, index) => ({
-      activation,
-      neurals: Array.from({ length: count }).map(() => ({
-        value: undefined,
-        w: Array.from({
-          length: index === 0 ? this.inputCount : layers[index - 1].count,
-        }).map(() => getRandomNumber()),
-        b: getRandomNumber(),
-        c: getRandomNumber(),
-      })),
-    }));
+    this.trainingCount = trainingCount;
+    this.networkStructor = layers.map(({ activation, count }, index) => {
+      const previousNeuralCount = index === 0 ? this.inputCount : layers[index - 1].count;
+      return {
+        activation,
+        neurals: Array.from({ length: count }).map(() => ({
+          value: 0,
+          w: Array.from({
+            length: previousNeuralCount,
+          }).map(() => getRandomNumber()),
+          b: getRandomNumber(),
+        })),
+      };
+    });
   }
 }
 ```
````
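
The constructor now also accepts a `trainingCount`, hoists the previous-layer size into `previousNeuralCount`, initializes `value` to `0` instead of `undefined`, and no longer creates a `c` coefficient. Below is the same initialization logic as a standalone sketch; the body of `getRandomNumber` and the shape of `Layer` are assumptions reconstructed from how the constructor uses them, not part of this diff:

```ts
// Assumed helper: small random initial parameter values.
const getRandomNumber = () => Math.random() - 0.5;

// Assumed shape, reconstructed from constructor usage.
interface Layer {
  activation: ActivationType;
  count: number;
  inputCount?: number;
}

// Build the network structure (types as in the sketch above): every neuron
// gets one random weight per previous-layer node and a random bias.
function initNetwork(layers: Layer[], inputCount: number): NetworkStructor {
  return layers.map(({ activation, count }, index) => {
    // The first layer connects to the raw inputs, later layers to the layer before.
    const previousNeuralCount = index === 0 ? inputCount : layers[index - 1].count;
    return {
      activation,
      neurals: Array.from({ length: count }).map(() => ({
        value: 0,
        w: Array.from({ length: previousNeuralCount }).map(() => getRandomNumber()),
        b: getRandomNumber(),
      })),
    };
  });
}
```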
````diff
@@ -159,7 +165,7 @@ type TrainingItem = [number[], number[]]
 
 So item 0 of a `traningItem` is the input, and `modelFunction` derives the predicted output from that input.
 
-```ts
+```js
 class NeuralNetwork {
   /** Get the value of each node in the previous layer */
   private getPreviousLayerValues(layerIndex: number, trainingItem: TraningItem) {
@@ -173,20 +179,22 @@
     this.networkStructor.forEach((layer, layerIndex) => {
       layer.neurals.forEach((neural) => {
         // Sum of each preceding node's value * w
-        let weightCount = 0;
+        let previousValueCountWithWeight = 0;
         this.getPreviousLayerValues(layerIndex - 1, trainingItem).forEach(
           (value, index) => {
-            weightCount += value * neural.w[index];
+            previousValueCountWithWeight += value * neural.w[index];
           },
         );
-        const activateResult = activate(layer.activation)(weightCount + neural.b);
-        neural.value = neural.c * activateResult;
+        const activateResult = activate(layer.activation)(
+          previousValueCountWithWeight + neural.b,
+        );
+        neural.value = activateResult;
       });
     });
 
     // Output the values of the last layer
     return this.networkStructor[this.networkStructor.length - 1].neurals.map(
-      (neural) => neural.value
+      (neural) => neural.value,
    );
  }
}
````
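
With `c` gone, the forward pass is the plain textbook form: every neuron computes `activate(sum(previousValue * w) + b)`. A minimal standalone sketch of that computation over the types above; the body of `activate` is an assumption using the standard sigmoid and ReLU formulas (the real helper is defined elsewhere in the article):

```ts
// Assumed activation dispatcher: standard sigmoid and ReLU.
const activate = (type: ActivationType) => (x: number) =>
  type === "sigmoid" ? 1 / (1 + Math.exp(-x)) : Math.max(0, x);

// Forward pass: fill every neuron's value layer by layer, return the last layer.
function forward(network: NetworkStructor, input: number[]): number[] {
  network.forEach((layer, layerIndex) => {
    // The first layer reads the raw input, later layers read the previous layer's values.
    const previousValues =
      layerIndex === 0 ? input : network[layerIndex - 1].neurals.map((n) => n.value);
    layer.neurals.forEach((neural) => {
      // Weighted sum of everything feeding this node.
      let sum = 0;
      previousValues.forEach((value, index) => {
        sum += value * neural.w[index];
      });
      neural.value = activate(layer.activation)(sum + neural.b);
    });
  });
  return network[network.length - 1].neurals.map((neural) => neural.value);
}
```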
````diff
@@ -220,22 +228,31 @@ The loss function's input is also a Training item, and its output is also a number; this
 
 There are many ways to compute the loss; we choose one of the simplest, mean squared error:
 
-```ts
+```js
 class NeuralNetwork {
   private lossFunction(trainingItem: TraningItem) {
     // Predicted values
-    const y = this.modelFunction(trainingItem);
+    const xList = this.modelFunction(trainingItem);
     // Actual values
-    const t = trainingItem[1];
-    // Final loss value
-    let loss = 0;
+    const tList = trainingItem[1];
+
+    const lastLayer = this.networkStructor[this.networkStructor.length - 1];
+    const lastLayerNeuralCount = lastLayer.neurals.length;
+    // Loss of each last-layer neuron for this sample
+    const lossList: number[] = Array.from({ length: lastLayerNeuralCount }).map(() => 0);
+    // Derivative of each last-layer neuron's loss for this sample
+    const dlossByDxList: number[] = Array.from({ length: lastLayerNeuralCount }).map(
+      () => 0,
+    );
 
-    for (let i = 0; i < y.length; i++) {
-      // l(t,y) = (t-y)²
-      loss += Math.pow(t[i] - y[i]!, 2);
+    for (let i = 0; i < xList.length; i++) {
+      // loss(x) = (x-t)²
+      lossList[i] = Math.pow(tList[i] - xList[i]!, 2);
+      // ∂loss/∂x = 2 * (x-t)
+      dlossByDxList[i] += 2 * (xList[i]! - tList[i]);
     }
 
-    return loss / y.length;
+    return { lossList, dlossByDxList };
  }
}
 ```
````
