# 算子的设计

算子是神经网络中用来处理加法 / 减法 / 激活函数等的抽象

# 算子基类

  • 层的名字 / 类型:

  • 线性 / Encode/Embedding/Softmax/Add 等

  • 算子类型: dataType: fp32/int32 等这种

  • 设备类型: CPU/GPU 即 Layer 即可以在 CPU 上实现,也可以在 GPU 上实现

  • 算子有输入,且输入可能并不唯一,所以设计上通常会提供一个接口: set_input (index,Tensor) 这样,第一个参数表示输入参数的下标.

  • 输出同理

  • 算子的计算过程:通常以 base_forward 来命名,每个算子重写 base_forward 来进行算子的计算,同时将数据存放到 output 中

  • 算子计算前的输入检测,因为传入的是 Tensor,所以需要检查输入的 Tensor 是否可以进行当前算子的计算、或者输入与层是否都在 CPU/GPU 上等

  • 含参数的算子:通常加入一个 weights<Tensor 列表> 代表算子权重 (参数)

如 AddKernel 的实现,在 get_kenel 阶段根据层的类型选择对应的 kernel 类型,如 CPU/GPU:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void add_kernel_cpu(const tensor::Tensor& input1, const tensor::Tensor& input2,
const tensor::Tensor& output) {
arma::fvec input_vec1(const_cast<float*>(input1.ptr<float>()), input1.size(), false,
true);
arma::fvec input_vec2(const_cast<float*>(input2.ptr<float>()), input2.size(), false,
true);
arma::fvec output_vec(const_cast<float*>(output.ptr<float>()), output.size(), false,
true);
output_vec = input_vec1 + input_vec2;
}

AddKernel get_add_kernel(base::DeviceType device_type) {
if (device_type == base::DeviceType::kDeviceCPU) {
return add_kernel_cpu;
} else {
LOG(FATAL) << "Unknown device type for get a add kernel.";
return nullptr;
}
}

# 以 RMSNorm 算子为例,它的 CUDA 实现

RSMNorm 算子公式:

scalar=1di=1dxi2rsqrt=1scalar+epsy=xrsqrtwscalar = \frac{1}{d}\sum_{i=1}^{d}x_i^{2} \\ rsqrt = \frac{1}{\sqrt{scalar+eps}} \\ y = x \cdot rsqrt \cdot w

# CPU 上 RMSNorm 的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void rmsnorm_kernel_cpu(int32_t dim, const tensor::Tensor& input,
const tensor::Tensor& weight, const tensor::Tensor& output) {
const float* in_ptr = input.ptr<float>();
const float* wei_ptr = weight.ptr<float>();
const float* out_ptr = output.ptr<float>();

int size = static_cast<int32_t>(input.size());
float sum = 0.f;
for(int i=0;i<size;++i){
float input_value = input.index<float>(i);
sum += input_value * input_value;
}
const float eps = 1e-5f;
float mean = sum / float(size)+eps;

const float rsqrt = 1.f/std::sqrt(mean);
for(int i=0;i<size;++i){
*(out_ptr+i)=weight.index<float>(i)*(rsqrt*(*(in_ptr+i)));
}
}

# CUDA 上的算子实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

static __global__ void row_rmsnorm_f32(const float* in,const float* wei,const float* ourt,int size,const float eps){
//求和
const int tid = threadIdx.x;
const int lane_id = tid % warpSize;

float sum = 0.0f;
//这里只是子线程的和
for (int i = lane_id; i < size; i += warpSize) {
sum += in[i] * in[i];
}

//最终需要把所有局部和编程全局和,看下面的介绍
using WarpReduce = cub::WarpReduce<float, 32>;
__shared__ typename WarpReduce::TempStorage temp;
__shared__ float shared_val;
sum = WarpReduce(temp).Reduce(sum, cub::Sum());
//最终结果放在线程0上,然后对共享变量赋值成线程0的sum
if (threadIdx.x == 0) {
shared_val = sum;
}
__syncthreads();
//就是最终到这里的时候,所有线程的sum都是一个值了
sum = shared_val;

//求最终的式子 y = x*scale*w
const float scale = rsqrtf(sum / static_cast<float>(size) + eps);
for (int i = lane_id; i < size; i += warpSize) {
out[i] = scale * in[i] * wei[i];
}
}

WrapReduce 实际上用的是 ShuffleWarpReduce,即:

1 1 1 1 1 1 1 1
2 2 2 2
4 4
8

这样规约,这样规约可以避免线程阻塞,因为可以确保每一轮都是同时进行的,所以可以直接把后面的加到前面的,然后不断地缩减规模,达到类似归并的效果,结果在线程 0 上

# 启动 CUDA 上的 rmsnorm 核函数

和直接调用函数即为相似

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
void rmsnorm_kernel_cu(const tensor::Tensor& input, const tensor::Tensor& weight,
const tensor::Tensor& output, void* stream) {
CHECK(!input.is_empty());
CHECK(!weight.is_empty());
CHECK(!output.is_empty());

CHECK(input.device_type() == base::DeviceType::kDeviceCUDA &&
weight.device_type() == base::DeviceType::kDeviceCUDA &&
output.device_type() == base::DeviceType::kDeviceCUDA);

const float eps = 1e-5f;
const int32_t size = static_cast<int32_t>(input.size());
const float* in_ptr = input.ptr<float>();
const float* wei_ptr = weight.ptr<float>();
float* out_ptr = const_cast<float*>(output.ptr<float>());
if (size < 1024) {
constexpr int threads_num = 128;
if (stream) {
cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
row_rmsnorm_f32<<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
} else {
row_rmsnorm_f32<<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
}
} else {
constexpr int threads_num = 1024;
if (stream) {
cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
row_rmsnorm_f32<<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
} else {
row_rmsnorm_f32<<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
}
}
}
更新于

请我喝[茶]~( ̄▽ ̄)~*

Solvarg 微信支付

微信支付

Solvarg 支付宝

支付宝

Solvarg 贝宝

贝宝