-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathmodel.h
More file actions
234 lines (180 loc) · 9.71 KB
/
model.h
File metadata and controls
234 lines (180 loc) · 9.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
// Copyright (c) 2025, IST Austria, developed by Erik Schultheis
// SPDX-License-Identifier: Apache-2.0
//
#ifndef LLMQ_SRC_TRAINING_MODEL_H
#define LLMQ_SRC_TRAINING_MODEL_H
#include <cstddef>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "kernels/kernels.h"
#include "utilities/stack.h"
#include "utilities/tensor.h"
#include "training/transformer_config.h"
// Forward declarations: this header only needs these names by reference/pointer,
// so the heavier headers that define them are not pulled in here.
enum class EMatmulBackend;

class AdamWStateManager;
class DataLoader;
class GenericTensorContainer;
class ITensorContainer;
class NCCLCommunicator;
class TensorAllocator;

// Opaque library handle types, declared the same way the cuDNN / cuBLASLt headers do
// (a pointer to an incomplete struct), so those headers need not be included here.
struct cudnnContext;
struct cublasLtContext;
using cudnnHandle_t = cudnnContext*;
using cublasLtHandle_t = cublasLtContext*;

// Needed by IModel::get_run_state before IRunState is defined below.
class IRunState;
//! \brief Abstract model base class.
//! \details Provides access to the different underlying tensor containers.
//! The destructor is protected and non-virtual, so a model cannot be deleted through an
//! `IModel*`; ownership has to be held via the concrete derived type.
class IModel {
public:
    //! \brief Runs the forward pass until just before the logit calculation.
    //! \details This function is asynchronous. You need to wait on `run_state.ForwardDone`
    //! before accessing any of the results (or run subsequent work on `run_state.MainStream`).
    //! However, it is guaranteed that `inputs` have been copied to the GPU-side buffer
    //! before this function returns.
    //! Note: We do not calculate the logits here, so that we have more freedom to optimize
    //! this large matmul, e.g., calculating it in chunks, by including it in the backward pass.
    virtual void forward(Tensor inputs, NCCLCommunicator& comm, int micro_step) = 0;

    //! \brief Runs the forward pass and calculates the loss w.r.t. `targets`.
    //! \returns a pair containing the full loss, and the loss over the first 1k tokens.
    // TODO fix this function and interface; make async with accumulation.
    virtual std::pair<float, float> validate(Tensor inputs, Tensor targets, NCCLCommunicator& comm, int micro_step) = 0;

    //! \brief Runs the backward pass.
    //! \details This function is asynchronous. You need to wait on `run_state.BackwardDone`
    //! before accessing any of the results (or run subsequent work on `run_state.MainStream`).
    //! However, it is guaranteed that `inputs` and `targets` have been copied to the GPU-side buffer
    //! before this function returns.
    //! \param z_loss Strength of the z-loss regularization term `1/2 * z_loss * log²(sum_i z_i)`.
    //!        NOTE(review): z_i presumably denotes exp(logit_i), which would make the log term the
    //!        per-token log-sum-exp (cf. `IRunState::LSE`) — confirm against the loss kernel.
    //! \param grad_accum_steps Number of gradient-accumulation micro-steps; the loss is complete
    //!        once `micro_step == grad_accum_steps` (see `IRunState::get_loss`).
    virtual void backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, float z_loss, int grad_accum_steps, int micro_step) = 0;

    //! \brief Runs the AdamW update step.
    //! \details Runs asynchronously, signalling completion through the OptimizerDone event.
    //! `beta_1`, `beta_2`, `epsilon` and `weight_decay` are the usual AdamW hyper-parameters;
    //! `t` is the optimizer step counter and `grad_clip` the gradient-clipping threshold
    //! (NOTE(review): exact clipping semantics live in the implementation — confirm there).
    virtual void update(NCCLCommunicator& comm, float learning_rate, float beta_1, float beta_2, int t, float epsilon, float weight_decay, float grad_clip) = 0;

    //! Gets the loss of the preceding validate or backward call (forward does _not_ calculate the loss).
    //! \param max_pos Maximum position inside the sequence that is considered for loss calculation.
    //! -1 indicates the full sequence. Must be a multiple of 512.
    float get_loss(int max_pos=-1) const;
    //! Gets the gradient norm of the preceding update call.
    float get_norm() const;

    //! Gets the tensor into which model inputs are to be placed.
    virtual Tensor& get_input_buffer();
    //! Gets the tensor into which model targets are to be placed.
    virtual Tensor& get_target_buffer();

    //! Model (master) weights. Sharded.
    virtual ITensorContainer& weights() = 0;
    //! AdamW optimizer state (e.g. the first-order momentum). Sharded.
    virtual AdamWStateManager& optimizer() = 0;

    //! Get the current RNG state, serialized as raw bytes.
    virtual std::vector<std::byte> rng_state() const = 0;
    //! Set the RNG state from checkpoint data (as produced by `rng_state()`).
    virtual void set_rng_state(const std::vector<std::byte>& state) = 0;

    //! Randomly initialize the model weights.
    virtual void init_weights(NCCLCommunicator& comm) = 0;
    //! Import the model weights from a file. This may be different than just reading into `weights()`,
    //! because it may involve dtype conversion (`allow_cast=true`), and even rearrange some data
    //! (e.g., fused vs unfused QKV).
    virtual void import_weights(const std::string& file_name, bool allow_cast, NCCLCommunicator& comm) = 0;
    //! This function needs to be called after the model has been restored from a checkpoint.
    virtual void on_restore_checkpoint(NCCLCommunicator& comm) = 0;
    //! Export the model weights to a safetensors file.
    virtual void export_weights(const std::string& file_name, NCCLCommunicator& comm) = 0;

    //! Get the model type identifier.
    virtual std::string_view model_type() const = 0;
    //! Get a reference to the model's RunState.
    //! Note: the returned reference is mutable even though this member function is `const`.
    virtual IRunState& get_run_state() const = 0;

    // Generic model parameter utilities: "block" tensors are the per-layer (transformer block)
    // parameters, "non-block" tensors everything else.
    //! Number of tensors in one block.
    virtual std::size_t num_block_tensors() const = 0;
    //! Fill `target` with the shapes/dtypes of one block's tensors for the given config.
    virtual void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0;
    //! Number of tensors outside of the blocks.
    virtual std::size_t num_non_block_tensors() const = 0;
    //! Fill `target` with the shapes/dtypes of the non-block tensors for the given config.
    virtual void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0;
    //! Create a container describing one block's tensors (see `fill_block_shapes`).
    GenericTensorContainer create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const;
    //! Create a container describing the non-block tensors (see `fill_non_block_shapes`).
    GenericTensorContainer create_non_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const;

protected:
    // Protected and non-virtual: IModel is never destroyed polymorphically.
    ~IModel() = default;
};
/*!
 * \brief Architecture-agnostic base class for model run states
 * \details Contains model run data that is independent of the actual
 * model architecture, e.g., cublas handles, generic cuda events, etc.
 *
 * NOTE(review): this class declares no destructor and no virtual functions, and its
 * move operations are defaulted. A defaulted move copies the raw stream/event/library
 * handles and host pointers, so the moved-from object still holds the same values;
 * whatever code releases these resources must make sure each handle is released only
 * once. Deleting a derived run state through an `IRunState*` would also be undefined
 * behavior, as there is no virtual destructor.
 */
class IRunState {
    friend class IModel;
    friend std::string save_checkpoint(std::string checkpoint_directory, int step, IModel& model,
                                       const DataLoader* loader, NCCLCommunicator& comm);
    friend void load_checkpoint(std::string checkpoint_directory, int step, IModel& model,
                                DataLoader* loader, NCCLCommunicator& comm);
public:
    IRunState() = default;
    IRunState(TransformerConfig config, long batch_size, long seq_len, std::shared_ptr<TensorAllocator> alloc);
    IRunState(IRunState&&) = default;
    IRunState& operator=(IRunState&&) = default;

    //! gets the accumulated loss after the last backward operation.
    //! only should be called after the last backward step (i.e., where `micro_step==grad_accum_steps`)
    //! will block the caller until `backward` is done, i.e., the function is safe to call without
    //! additional synchronization.
    float get_loss(int max_pos=-1) const;
    //! gets the global gradient norm.
    //! will block the caller until `update` has finished the norm calculation,
    //! i.e., the function is safe to call without additional synchronization.
    float get_norm() const;
    //! gets the maximum log-sum-exp of any token's logits
    float get_lse_max() const;
    //! gets the sum of all tokens' logits' log-sum-exp
    float get_lse_sum() const;

    // temporary buffers
    Tensor temp_alloc(ETensorDType dtype, const std::vector<long>& shape);
    void temp_acquire(Tensor& target);
    void temp_free(Tensor& tensor);

    TransformerConfig Config;
    // B/T are value-initialized so that a default-constructed run state does not
    // carry indeterminate sizes; the sizing constructor overrides them.
    long B = 0; //!< Batch size
    long T = 0; //!< Sequence length

    std::shared_ptr<TensorAllocator> Allocator;
    DeviceMemoryStack Stack;

    Tensor Inputs;          // (B, T) Int32
    Tensor Targets;         // (B, T) Int32
    Tensor Losses;          // (B, T) FP32
    Tensor LSE;             // (B, T) FP32; log-sum-exp of logits
    Tensor GroupedLosses;   // (T / 512) FP32

    float* NormHost = nullptr;  // single value
    float* LossHost = nullptr;  // single value
    float* LSEHost = nullptr;   // two values

    //! Records a step's loss and gradient norm in the outlier-detection history.
    //! NOTE(review): the meaning of the returned pair (presumably per-metric outlier scores
    //! from the loss / norm detectors) is defined in the implementation — confirm there.
    std::pair<float, float> record_step(float loss, float norm);

    // Value-initialized (zeroed) so a default-constructed run state has no garbage properties.
    cudaDeviceProp DeviceProp{};
    cudaStream_t MainStream = nullptr;
    cudaEvent_t ForwardDone = nullptr;      //!< recorded at the end of the forward pass
    cudaEvent_t BackwardDone = nullptr;     //!< recorded at the end of the backward pass
    cudaEvent_t TransferDone = nullptr;     //!< recorded once CPU-side buffers have been copied to GPU
    cudaEvent_t NormDone = nullptr;         //!< recorded after norm calculation completes
    cudaEvent_t LSEDone = nullptr;          //!< recorded after logit lse (z) has been computed
    cudaEvent_t OptimizerDone = nullptr;    //!< recorded after the optimizer completes

    cudnnHandle_t CudnnHandle = nullptr;
    cublasLtHandle_t CublasLtHandle = nullptr;
    Tensor CuBlasWorkspace;
    EMatmulBackend MatmulBackend = EMatmulBackend{0};

    // events for debugging timings
    void setup_timing_events(int micro_steps);
    cudaEvent_t TimingOptimizerStart = nullptr;
    cudaEvent_t TimingOptimizerEnd = nullptr;
    std::vector<cudaEvent_t> TimingForwardStart;
    std::vector<cudaEvent_t> TimingForwardEnd;
    std::vector<cudaEvent_t> TimingHeadStart;
    std::vector<cudaEvent_t> TimingHeadEnd;
    std::vector<cudaEvent_t> TimingBackwardStart;
    std::vector<cudaEvent_t> TimingBackwardEnd;

private:
    Tensor Inputs_CPU;  // (B, T) Int32
    Tensor Targets_CPU; // (B, T) Int32

    // Ring buffer over the most recent `mWindowSize` values of one metric, keeping a running
    // sum and sum of squares; used for both the loss and the gradient-norm history.
    struct OutlierDetector {
        OutlierDetector(int window_size=100);
        void record(float value);
        float eval(float value) const;
        void re_evaluate();
        void reset(int window_size, int index, std::vector<float> values);

        int mWindowSize;
        int mIndex = 0;
        std::vector<float> mValues;
        double mSum = 0.0;
        double mSumSq = 0.0;
    };

    OutlierDetector LossOutliers;
    OutlierDetector NormOutliers;
};
#endif //LLMQ_SRC_TRAINING_MODEL_H