-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTransformers.rs
More file actions
252 lines (196 loc) · 6.8 KB
/
Transformers.rs
File metadata and controls
252 lines (196 loc) · 6.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Model hyperparameters (GPT-3-scale sizes).
const N_LAYERS: usize = 96; // Amount of transformer layers
const D_MODEL: usize = 12288; // Embedding size/hidden dimension
const D_MLP: usize = 4 * D_MODEL; // Dimension of MLP
const D_HEAD: usize = 128; // Dimension of each attention head
const N_HEADS: usize = D_MODEL / D_HEAD; // Number of heads
const N_VOCAB: usize = 50000; // Number of distinct tokens in vocab
// A token is an integer id indexing into the vocabulary.
type Token = u64;
// One f32 score per vocabulary entry — the model's output for one position.
type Logits = [f32; N_VOCAB];
/// An autoregressive model: maps a token sequence to one `Logits`
/// array per input position (the score for each possible next token).
trait ARModel {
fn apply(&self, tokens: &[Token]) -> Vec<Logits>;
}
// This defines a Rust trait called ARModel: a contract of behavior that types can implement.
// Any type implementing ARModel must define the function `apply`.
/*
&self: takes an immutable reference to the object (self)
tokens: &[Token]: takes a slice of Token (a borrowed view into a list of tokens)
-> Vec<Logits>: returns a Vec of Logits — one array of output scores per input position
*/
// `State` must be Clone: `State::update` and `Embedding::apply` both call
// `.clone()` on it, and a plain f32 array derives Clone for free.
#[derive(Clone)]
// The state/residual vector: one D_MODEL-dimensional f32 vector.
struct State([f32; D_MODEL]);
// At every point inside the model, we have one State per token
// position.
// This is a slice because you want dynamic length (the number of
// token positions varies per input).
type ResidualStream = [State];
// Query vector: dotted against a State to read a single scalar out of it.
type Query = State;
// Update vector: added component-wise onto the state vector.
type Update = State;
impl State {
    /// The all-zeros state vector.
    fn zero() -> Self {
        State([0.0; D_MODEL])
    }

    /// Return a new state equal to `self` plus `right`, component-wise.
    fn update(&self, right: &Update) -> State {
        let mut sum = self.clone();
        for (acc, delta) in sum.0.iter_mut().zip(right.0.iter()) {
            *acc += delta;
        }
        sum
    }

    /// Querying means dotting this state with the query vector,
    /// producing a single floating-point value.
    fn query(&self, right: &Query) -> f32 {
        dot(&self.0, &right.0)
    }
}
// `<const N: usize>` is a const generic parameter: the array length is part of the type.
// `&` takes a shared reference — a borrowed pointer to a value the function does not own.
// `out` is the accumulator whose final value the function returns.
// Querying something means dotting a vector with the query vector to create a single floating-point value.
/// Inner product of two fixed-size f32 vectors of the same length `N`.
fn dot<const N: usize>(l: &[f32; N], r: &[f32; N]) -> f32 {
    l.iter().zip(r.iter()).map(|(a, b)| a * b).sum()
}
// The full model: a token embedding, N_LAYERS residual blocks, and an
// unembedding that reads logits out of the final states.
struct Transformer {
embedding: Embedding,
layers: [ResBlock; N_LAYERS],
unembedding: Unembedding,
}
// The embedding table: one initial State per vocabulary entry.
struct Embedding([State; N_VOCAB]);

// Just a method you can call on "Embedding"
impl Embedding {
    /// Look up the initial residual-stream vector for token `tok`.
    fn apply(&self, tok: Token) -> State {
        let Embedding(table) = self;
        table[tok as usize].clone()
    }
}
// One per-vocabulary-entry readout: a stored query vector that is
// dotted with a state to produce a single logit.
struct LogitFn(Query);

impl LogitFn {
    /// Produce this vocabulary entry's logit for the given state.
    fn apply(&self, st: &State) -> f32 {
        let LogitFn(query) = self;
        query.query(st)
    }
}
// The unembedding: one LogitFn per vocabulary entry.
struct Unembedding([LogitFn; N_VOCAB]);

impl Unembedding {
    /// Read one logit per vocabulary entry out of a final state.
    fn apply(&self, state: &State) -> Logits {
        let mut logits: Logits = [0.0; N_VOCAB];
        for (slot, logit_fn) in logits.iter_mut().zip(self.0.iter()) {
            *slot = logit_fn.apply(state);
        }
        logits
    }
}
// Each vocabulary element has a LogitFn which converts the state into
// a single floating-point value by querying the state according to some particular query.
// One residual block: an attention layer followed by an MLP layer,
// each producing updates that are added back into the residual stream.
struct ResBlock {
attn: AttnLayer,
mlps: MLPLayer,
}
// ATTENTION LAYER
struct AttnLayer {
heads: [AttnHead; N_HEADS],
}
type AttnVector = [f32; D_HEAD];
// The per-head working vector: D_HEAD floats, much smaller than the
// D_MODEL-sized residual State.
// One attention head: four learned projections. W_Q/W_K/W_V map a State
// down into head space; W_O maps a head-space vector back up to an Update.
struct AttnHead {
W_Q: Box<dyn Fn(&State) -> AttnVector>,
W_K: Box<dyn Fn(&State) -> AttnVector>,
W_V: Box<dyn Fn(&State) -> AttnVector>,
W_O: Box<dyn Fn(&AttnVector) -> Update>,
}
// Attention Head Implementation
impl AttnHead {
    /// Run this head over every position, returning one Update per position.
    ///
    /// Causal (autoregressive) masking: position `src` only attends to
    /// positions `0..=src` — we can't look ahead.
    fn apply(&self, states: &[State]) -> Vec<Update> {
        // Apply the Q, K, and V projections to produce Q, K, and V.
        let qs: Vec<AttnVector> = states.iter().map(&self.W_Q).collect();
        let ks: Vec<AttnVector> = states.iter().map(&self.W_K).collect();
        let vs: Vec<AttnVector> = states.iter().map(&self.W_V).collect();

        // Accumulate the attention-weighted value vector at each position.
        let mut values: Vec<AttnVector> = states.iter().map(|_| [0.0; D_HEAD]).collect();
        for (src, my_q) in qs.iter().enumerate() {
            // Positions visible to `src` under the causal mask (inclusive).
            let visible = 0..=src;

            // Raw attention scores: q · k for each visible key. The loop
            // pushes `src + 1` scores, so reserve exactly that many
            // (the original reserved only `src`, forcing one reallocation).
            let mut scores = Vec::with_capacity(src + 1);
            for k in &ks[visible.clone()] {
                scores.push(dot(my_q, k));
            }

            // Turn scores into a probability distribution.
            softmax(&mut scores);

            // Weight each visible V vector by its attention weight and sum.
            for (score, v) in scores.iter().zip(&vs[visible]) {
                for (acc, vj) in values[src].iter_mut().zip(v.iter()) {
                    *acc += vj * score;
                }
            }
        }

        // Project each head-space value back into the residual stream.
        values.iter().map(&self.W_O).collect()
    }
}
// Attention Layer: applies every head and sums their outputs.
impl AttnLayer {
    /// Sum the per-position updates produced by each head in this layer.
    fn apply(&self, states: &[State]) -> Vec<Update> {
        let mut totals: Vec<Update> = states.iter().map(|_| State::zero()).collect();
        for head in self.heads.iter() {
            for (total, delta) in totals.iter_mut().zip(head.apply(states)) {
                *total = total.update(&delta);
            }
        }
        totals
    }
}
// A single MLP neuron: `read` is queried against the state to get a
// scalar; `write` (scaled by the activated scalar) is added back.
struct Neuron {
read: Query,
write: Update,
}
// The MLP block of a residual layer: D_MLP independent neurons plus a
// pointwise nonlinearity applied to each neuron's pre-activation.
struct MLPLayer {
mlps: [Neuron; D_MLP],
nonlinear: fn(f32) -> f32,
}
impl MLPLayer {
    /// Apply every neuron to `state` and sum their contributions
    /// into a single residual-stream update.
    fn apply(&self, state: &State) -> Update {
        let mut total: Update = State::zero();
        for neuron in self.mlps.iter() {
            // Pre-activation: how strongly this neuron fires on the state,
            // then passed through the layer's nonlinearity.
            let activation = (self.nonlinear)(neuron.read.query(state));
            // Scale the neuron's write vector by its activation.
            let contribution: Update = State(neuron.write.0.map(|w| w * activation));
            total = total.update(&contribution);
        }
        total
    }
}
// Now here is the full Transformer Model!
impl ARModel for Transformer {
fn apply(&self, tokens: &[Token]) -> Vec <Logits> {
// Embeddings: Convert tokens to initial states
let mut states = tokens
.iter()
.map(|t| self.embedding.apply(*t))
.collect::<Vec<_>>();
// At this point we have all the initial states after embedding
// Pass the initial hidden state through each layer
// This applies the operations of all the attention layers
for layer in self.layers.iter(){
let attn_out = layer.attn.apply(&states);
states = states
.iter()
.zip(attn_out.iter())
.map(|(l, r)| l.update(r))
.collect();
for i in 0..states.len() {
// Apply the mlps to each state
let mlp_out = layer.mlps.apply(&states[i]);
states[i] = states[i].update(&mlp_out);
}
}
// Apply the unembedding to get out the logits
states.iter().map(|s| self.unembedding.apply(s)).collect()
}
}