#61 Re: 完了,deepseek把GPU上的fp8的汇编代码开源了。
发表于 : 2025年 2月 26日 03:43

内行的人看来只是一笑,不想浪费时间。不过也不奇怪,估计老黄心里也在笑,不过不能公开说,还是得客气一下,免得伤了玻璃心

内行的人看来只是一笑,不想浪费时间。不过也不奇怪,估计老黄心里也在笑,不过不能公开说,还是得客气一下,免得伤了玻璃心
不过也不奇怪,一帮小年轻,怎么可能有底层硬件经验呢
看能不能振臂一呼形成新的业界标准 对抗NVIDIA 这样堡垒就从帝国内部分裂了supercnm 写了: 2025年 2月 26日 03:10 感觉主要是提供了一种思路,硬件公司Nvidia的软件部门太烂,有大量的空间可以优化
老黄现在应该做的是,赶紧对新的blackwell芯片做指令优化,或者把blackwell的指令做加密,不准别人优化
fp8不像是正道drifter 写了: 2025年 2月 26日 04:20 看能不能振臂一呼形成新的业界标准 对抗NVIDIA 这样堡垒就从帝国内部分裂了
当然美帝可以下令 不准intel AMD之类的厂商使用这种标准
那你说说为啥要求sm_90?xexz 写了: 2025年 2月 25日 23:21 上面是个JIT虚拟机,下层用的nvGPU的汇编,
意思是,其他家的GPU只要实现NV的汇编(这个对那些GPU硬件厂家跟没要求一样),用谁的GPU都一样一样的。
是指当下、现在的硬件平台支持情况,
nv当然有,按你帝的说法,有点DEI罢了
在Linux核里,改变privilege, 调用co processor 指令等,都要嵌入汇编,因为用C无法实现。
fangkuuaih 写了: 2025年 2月 26日 07:02 在Linux核里,改变privilege, 调用co processor 指令等,都要嵌入汇编,因为用C无法实现。
这是叔还是junior工程师时就做过的事。
女大里一堆OS 核工程师,硬件工程师,这些都是小儿科的玩意。
xexz 写了: 2025年 2月 25日 23:11 #pragma once
#include <cuda.h>
#include "utils.cuh"
namespace deep_gemm {
struct SM90_64x16x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
" %8,"
" %9,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 16;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x24x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %14, 0;\n"
"wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11},"
" %12,"
" %13,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 24;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x32x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15},"
" %16,"
" %17,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 32;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x40x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %22, 0;\n"
"wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19},"
" %20,"
" %21,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 40;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x48x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23},"
" %24,"
" %25,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 48;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x56x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %30, 0;\n"
"wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27}, "
" %28,"
" %29,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 56;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x64x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31}, "
" %32,"
" %33,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 64;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x72x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %38, 0;\n"
"wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35}, "
" %36,"
" %37,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 72;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x80x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %42, 0;\n"
"wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39}, "
" %40,"
" %41,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 80;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
struct SM90_64x88x32_F32E4M3E4M3_SS {
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b,
float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07,
float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15,
float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23,
float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31,
float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39,
float& d40, float& d41, float& d42, float& d43,
bool scale_d) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %46, 0;\n"
"wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3"
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43}, "
" %44,"
" %45,"
" p , 1, 1;\n"
"}\n"
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d)));
}
__device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
wgmma(desc_a, desc_b,
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7],
d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15],
d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
d[40], d[41], d[42], d[43],
scale_d);
}
static constexpr int M = 64;
static constexpr int N = 88;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};