diff --git a/.github/workflows/test-enjoy-app.yml b/.github/workflows/test-enjoy-app.yml index 4d35d310..36e60a45 100644 --- a/.github/workflows/test-enjoy-app.yml +++ b/.github/workflows/test-enjoy-app.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-12, macos-14, windows-latest, ubuntu-latest] + os: [macos-12, macos-14, windows-latest, ubuntu-24.04] steps: - uses: actions/checkout@v4 diff --git a/1000-hours/package.json b/1000-hours/package.json index b4f4568d..4804d73e 100644 --- a/1000-hours/package.json +++ b/1000-hours/package.json @@ -7,9 +7,9 @@ "markdown-it-mathjax3": "^4.3.2", "markdown-it-sub": "^2.0.0", "markdown-it-sup": "^2.0.0", - "mermaid": "^10.9.0", - "sass": "^1.77.1", - "vitepress": "^1.1.4", + "mermaid": "^10.9.1", + "sass": "^1.77.2", + "vitepress": "^1.2.0", "vitepress-plugin-mermaid": "^2.0.16", "vue": "^3.4.27" }, diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/bench b/enjoy/lib/whisper.cpp/arm64/darwin/bench deleted file mode 100755 index 5611654c..00000000 Binary files a/enjoy/lib/whisper.cpp/arm64/darwin/bench and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal b/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal index a7d3f9ef..57fdf564 100644 --- a/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal +++ b/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal @@ -1,3 +1,7 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#include "ggml-common.h" + #include using namespace metal; @@ -6,46 +10,11 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK5_0 32 -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; - -#define QK5_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; // general-purpose kernel for addition, multiplication and division of two tensors @@ -244,6 +213,15 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + kernel void kernel_relu( device const float * src0, device float * dst, @@ -251,6 +229,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + kernel void kernel_tanh( device const float * src0, device float * dst, @@ -264,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -277,6 +271,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -286,6 +289,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -348,15 +359,20 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -366,15 +382,27 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -399,7 +427,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -434,15 +462,20 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +485,26 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +530,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -523,6 +567,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -862,6 +914,7 @@ void mul_vec_q_n_f32_impl( int64_t ne1, uint r2, uint r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -938,7 +991,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -964,7 +1017,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -990,7 +1043,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1016,7 +1069,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1026,18 +1079,19 @@ void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1115,7 +1169,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_F32_F32 4 @@ -1124,24 +1178,24 @@ void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -1394,24 +1448,24 @@ void kernel_mul_mv_f16_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F16_F32; @@ -1544,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4( } } -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -1775,9 +1775,29 @@ kernel void kernel_rope( template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; -kernel void kernel_im2col_f16( +typedef void (im2col_t)( device const float * x, - device half * dst, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, constant int32_t & ofs0, constant int32_t & ofs1, constant int32_t & IW, @@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16( (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + device T * pdst = (device T *) (dst); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; + pdst[offset_dst] = 0.0f; } else { const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; } } +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + kernel void kernel_upscale_f32( device const char * src0, device char * dst, @@ -1899,11 +1924,56 @@ kernel void kernel_pad_f32( } } +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -1912,33 +1982,42 @@ kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols) return; + if (col >= ncols_pad) return; - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -1946,10 +2025,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0, @@ -1959,6 +2043,654 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + //const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_float8x8 mscale(scale); + + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + if (mask != q) { + // mqk = mqk*scale + mask*slope + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + } else { + // mqk = mqk*scale + simdgroup_multiply(mqk, mscale, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -2079,7 +2811,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; } } @@ -2316,6 +3049,242 @@ kernel void kernel_cpy_f32_q4_1( } } +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, @@ -2372,98 +3341,23 @@ kernel void kernel_concat( } } -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -//====================================== dot products ========================= - void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -2623,7 +3517,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -2631,18 +3525,19 @@ void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -2798,6 +3693,7 @@ void kernel_mul_mv_q3_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2887,7 +3783,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -2895,18 +3791,19 @@ void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -3017,6 +3914,7 @@ void kernel_mul_mv_q4_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3125,25 +4023,26 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3331,25 +4230,26 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -3465,7 +4365,1182 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_xxs)/2; + q3 += nb*sizeof(block_iq3_xxs); + gas += nb*sizeof(block_iq3_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_s)/2; + qs += nb*sizeof(block_iq3_s); + qh += nb*sizeof(block_iq3_s); + sc += nb*sizeof(block_iq3_s); + signs += nb*sizeof(block_iq3_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + +#if QK_K != 64 + iq1m_scale_t scale; +#endif + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); +#if QK_K == 64 + const float d = (float) *((device const half *)(sc - 1)); + sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + + (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); +#else + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); +#endif + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +#if QK_K != 64 +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#endif + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + +#if QK_K == 64 + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#else + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#endif } //============================= templates and their specializations ============================= @@ -3620,8 +5695,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; il = (il/2) & 3; const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); @@ -3663,6 +5738,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float dl = d * sc[0]; const float ml = min * sc[1]; #else + (void) get_scale_min_k4_just2; + q = q + 16 * (il&1); device const uint8_t * s = xb->scales; device const half2 * dh = (device const half2 *)xb->d; @@ -3688,7 +5765,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; + const float d = il < 2 ? xb->d : xb->d / 16.f; const float min = xb->dmin; const float dl = d * sc[0]; const float ml = min * sc[1]; @@ -3721,17 +5798,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; + float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; #else ql = ql + 16 * (il&1); - half sc = scales[il]; + float sc = scales[il]; #endif const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); @@ -3739,6 +5816,215 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; +#if QK_K == 64 + const float d = xb->d; +#else + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; +#endif + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; +#if QK_K == 64 + const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); +#else + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); +#endif + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { +#if QK_K == 64 + dequantize_iq4_nl(xb, il, reg); +#else + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +#endif +} + template kernel void kernel_get_rows( device const void * src0, @@ -4002,25 +6288,25 @@ void kernel_mul_mm_impl(device const uchar * src0, } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids template void kernel_mul_mm_id_impl( device const uchar * src0, device const uchar * src1, - thread short * src1ids, + threadgroup ushort2 * rowids, device float * dst, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant int64_t & ne0, int64_t ne1, - constant uint & r2, - constant uint & r3, + int64_t ne0ne1, threadgroup uchar * shared_memory, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -4031,7 +6317,6 @@ void kernel_mul_mm_id_impl( const uint r0 = tgpig.y; const uint r1 = tgpig.x; - const uint im = tgpig.z; if (r1 * BLOCK_SIZE_N >= ne1) return; @@ -4049,19 +6334,16 @@ void kernel_mul_mm_id_impl( for (int i = 0; i < 8; i++){ c_res[i] = make_filled_simdgroup_matrix(0.f); } - short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb12 * id[1] + + nb11 * (id[0] % ne11) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { @@ -4090,11 +6372,11 @@ void kernel_mul_mm_id_impl( for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; @@ -4116,11 +6398,13 @@ void kernel_mul_mm_id_impl( threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); } } } @@ -4171,14 +6455,18 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, constant uint64_t & nb10, @@ -4187,55 +6475,52 @@ kernel void kernel_mul_mm_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - // expert id - const int32_t id = tgpig.z/(ne12*ne13); + const int32_t i02 = tgpig.z; + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + device const uchar * src0 = src0s + i02*nb02; - // row indices of src1 for expert id + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + + // TODO: parallelize this loop int64_t _ne1 = 0; - short src1ids[512]; - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } } } + threadgroup_barrier(mem_flags::mem_threadgroup); + kernel_mul_mm_id_impl( - src0s[id], + src0, src1, - src1ids, + rowids, dst, ne00, ne02, nb01, nb02, + ne11, ne12, nb10, nb11, nb12, ne0, _ne1, - r2, - r3, + ne0*ne1, shared_memory, tgpig, tiitg, @@ -4278,238 +6563,201 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +#if QK_K == 64 +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#else +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#endif // // matrix-matrix multiplication // -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm) mat_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#else +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif // // indirect matrix-matrix multiplication // -typedef void (mat_mm_id_t)( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#else +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#endif // // matrix-vector multiplication // -[[host_name("kernel_mul_mv_id_f32_f32")]] -kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); + +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f32_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } -[[host_name("kernel_mul_mv_id_f16_f32")]] -kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f16_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -[[host_name("kernel_mul_mv_id_q8_0_f32")]] -kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, +typedef decltype(mmv_fn) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4527,610 +6775,80 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; - const int64_t bid = tgpig.z/(ne12*ne13); + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; - kernel_mul_mv_q8_0_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + impl_fn( + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, + shared_values, tgpig, + tiitg, tiisg, sgitg); } -[[host_name("kernel_mul_mv_id_q4_0_f32")]] -kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; - const int64_t bid = tgpig.z/(ne12*ne13); +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#if QK_K != 64 +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#endif - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_1_f32")]] -kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_0_f32")]] -kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_1_f32")]] -kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q2_K_f32")]] -kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q2_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q3_K_f32")]] -kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q3_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_K_f32")]] -kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q4_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_K_f32")]] -kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q5_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q6_K_f32")]] -kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q6_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/main b/enjoy/lib/whisper.cpp/arm64/darwin/main index aa2921bc..a64d3fe6 100755 Binary files a/enjoy/lib/whisper.cpp/arm64/darwin/main and b/enjoy/lib/whisper.cpp/arm64/darwin/main differ diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/quantize b/enjoy/lib/whisper.cpp/arm64/darwin/quantize deleted file mode 100755 index da51a0a6..00000000 Binary files a/enjoy/lib/whisper.cpp/arm64/darwin/quantize and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/darwin/bench b/enjoy/lib/whisper.cpp/x64/darwin/bench deleted file mode 100755 index 689596c1..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/darwin/bench and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal b/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal index a7d3f9ef..57fdf564 100644 --- a/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal +++ b/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal @@ -1,3 +1,7 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#include "ggml-common.h" + #include using namespace metal; @@ -6,46 +10,11 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK5_0 32 -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; - -#define QK5_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; // general-purpose kernel for addition, multiplication and division of two tensors @@ -244,6 +213,15 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + kernel void kernel_relu( device const float * src0, device float * dst, @@ -251,6 +229,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + kernel void kernel_tanh( device const float * src0, device float * dst, @@ -264,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -277,6 +271,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -286,6 +289,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -348,15 +359,20 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -366,15 +382,27 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -399,7 +427,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -434,15 +462,20 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +485,26 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +530,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -523,6 +567,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -862,6 +914,7 @@ void mul_vec_q_n_f32_impl( int64_t ne1, uint r2, uint r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -938,7 +991,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -964,7 +1017,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -990,7 +1043,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1016,7 +1069,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1026,18 +1079,19 @@ void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1115,7 +1169,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_F32_F32 4 @@ -1124,24 +1178,24 @@ void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -1394,24 +1448,24 @@ void kernel_mul_mv_f16_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F16_F32; @@ -1544,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4( } } -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -1775,9 +1775,29 @@ kernel void kernel_rope( template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; -kernel void kernel_im2col_f16( +typedef void (im2col_t)( device const float * x, - device half * dst, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, constant int32_t & ofs0, constant int32_t & ofs1, constant int32_t & IW, @@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16( (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + device T * pdst = (device T *) (dst); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; + pdst[offset_dst] = 0.0f; } else { const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; } } +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + kernel void kernel_upscale_f32( device const char * src0, device char * dst, @@ -1899,11 +1924,56 @@ kernel void kernel_pad_f32( } } +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -1912,33 +1982,42 @@ kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols) return; + if (col >= ncols_pad) return; - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -1946,10 +2025,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0, @@ -1959,6 +2043,654 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + //const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_float8x8 mscale(scale); + + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + if (mask != q) { + // mqk = mqk*scale + mask*slope + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + } else { + // mqk = mqk*scale + simdgroup_multiply(mqk, mscale, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -2079,7 +2811,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; } } @@ -2316,6 +3049,242 @@ kernel void kernel_cpy_f32_q4_1( } } +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, @@ -2372,98 +3341,23 @@ kernel void kernel_concat( } } -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -//====================================== dot products ========================= - void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -2623,7 +3517,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -2631,18 +3525,19 @@ void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -2798,6 +3693,7 @@ void kernel_mul_mv_q3_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2887,7 +3783,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -2895,18 +3791,19 @@ void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -3017,6 +3914,7 @@ void kernel_mul_mv_q4_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3125,25 +4023,26 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3331,25 +4230,26 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -3465,7 +4365,1182 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_xxs)/2; + q3 += nb*sizeof(block_iq3_xxs); + gas += nb*sizeof(block_iq3_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_s)/2; + qs += nb*sizeof(block_iq3_s); + qh += nb*sizeof(block_iq3_s); + sc += nb*sizeof(block_iq3_s); + signs += nb*sizeof(block_iq3_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + +#if QK_K != 64 + iq1m_scale_t scale; +#endif + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); +#if QK_K == 64 + const float d = (float) *((device const half *)(sc - 1)); + sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + + (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); +#else + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); +#endif + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +#if QK_K != 64 +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#endif + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + +#if QK_K == 64 + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#else + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#endif } //============================= templates and their specializations ============================= @@ -3620,8 +5695,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; il = (il/2) & 3; const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); @@ -3663,6 +5738,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float dl = d * sc[0]; const float ml = min * sc[1]; #else + (void) get_scale_min_k4_just2; + q = q + 16 * (il&1); device const uint8_t * s = xb->scales; device const half2 * dh = (device const half2 *)xb->d; @@ -3688,7 +5765,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; + const float d = il < 2 ? xb->d : xb->d / 16.f; const float min = xb->dmin; const float dl = d * sc[0]; const float ml = min * sc[1]; @@ -3721,17 +5798,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; + float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; #else ql = ql + 16 * (il&1); - half sc = scales[il]; + float sc = scales[il]; #endif const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); @@ -3739,6 +5816,215 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg } } +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; +#if QK_K == 64 + const float d = xb->d; +#else + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; +#endif + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; +#if QK_K == 64 + const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); +#else + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); +#endif + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { +#if QK_K == 64 + dequantize_iq4_nl(xb, il, reg); +#else + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +#endif +} + template kernel void kernel_get_rows( device const void * src0, @@ -4002,25 +6288,25 @@ void kernel_mul_mm_impl(device const uchar * src0, } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids template void kernel_mul_mm_id_impl( device const uchar * src0, device const uchar * src1, - thread short * src1ids, + threadgroup ushort2 * rowids, device float * dst, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant int64_t & ne0, int64_t ne1, - constant uint & r2, - constant uint & r3, + int64_t ne0ne1, threadgroup uchar * shared_memory, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -4031,7 +6317,6 @@ void kernel_mul_mm_id_impl( const uint r0 = tgpig.y; const uint r1 = tgpig.x; - const uint im = tgpig.z; if (r1 * BLOCK_SIZE_N >= ne1) return; @@ -4049,19 +6334,16 @@ void kernel_mul_mm_id_impl( for (int i = 0; i < 8; i++){ c_res[i] = make_filled_simdgroup_matrix(0.f); } - short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb12 * id[1] + + nb11 * (id[0] % ne11) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { @@ -4090,11 +6372,11 @@ void kernel_mul_mm_id_impl( for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; @@ -4116,11 +6398,13 @@ void kernel_mul_mm_id_impl( threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); } } } @@ -4171,14 +6455,18 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, constant uint64_t & nb10, @@ -4187,55 +6475,52 @@ kernel void kernel_mul_mm_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - // expert id - const int32_t id = tgpig.z/(ne12*ne13); + const int32_t i02 = tgpig.z; + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + device const uchar * src0 = src0s + i02*nb02; - // row indices of src1 for expert id + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + + // TODO: parallelize this loop int64_t _ne1 = 0; - short src1ids[512]; - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } } } + threadgroup_barrier(mem_flags::mem_threadgroup); + kernel_mul_mm_id_impl( - src0s[id], + src0, src1, - src1ids, + rowids, dst, ne00, ne02, nb01, nb02, + ne11, ne12, nb10, nb11, nb12, ne0, _ne1, - r2, - r3, + ne0*ne1, shared_memory, tgpig, tiitg, @@ -4278,238 +6563,201 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +#if QK_K == 64 +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#else +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#endif // // matrix-matrix multiplication // -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm) mat_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#else +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif // // indirect matrix-matrix multiplication // -typedef void (mat_mm_id_t)( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#else +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#endif // // matrix-vector multiplication // -[[host_name("kernel_mul_mv_id_f32_f32")]] -kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); + +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f32_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } -[[host_name("kernel_mul_mv_id_f16_f32")]] -kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f16_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -[[host_name("kernel_mul_mv_id_q8_0_f32")]] -kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, +typedef decltype(mmv_fn) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4527,610 +6775,80 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; - const int64_t bid = tgpig.z/(ne12*ne13); + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; - kernel_mul_mv_q8_0_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + impl_fn( + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, + shared_values, tgpig, + tiitg, tiisg, sgitg); } -[[host_name("kernel_mul_mv_id_q4_0_f32")]] -kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; +typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; - const int64_t bid = tgpig.z/(ne12*ne13); +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#if QK_K != 64 +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#endif - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_1_f32")]] -kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_0_f32")]] -kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_1_f32")]] -kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q2_K_f32")]] -kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q2_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q3_K_f32")]] -kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q3_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_K_f32")]] -kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q4_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_K_f32")]] -kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q5_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q6_K_f32")]] -kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q6_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} diff --git a/enjoy/lib/whisper.cpp/x64/darwin/main b/enjoy/lib/whisper.cpp/x64/darwin/main index 42426045..f937f800 100755 Binary files a/enjoy/lib/whisper.cpp/x64/darwin/main and b/enjoy/lib/whisper.cpp/x64/darwin/main differ diff --git a/enjoy/lib/whisper.cpp/x64/darwin/quantize b/enjoy/lib/whisper.cpp/x64/darwin/quantize deleted file mode 100755 index 0531a529..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/darwin/quantize and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/linux/bench b/enjoy/lib/whisper.cpp/x64/linux/bench deleted file mode 100755 index 09dcafa8..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/linux/bench and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/linux/main b/enjoy/lib/whisper.cpp/x64/linux/main index 7c4b3ba8..a3c25087 100755 Binary files a/enjoy/lib/whisper.cpp/x64/linux/main and b/enjoy/lib/whisper.cpp/x64/linux/main differ diff --git a/enjoy/lib/whisper.cpp/x64/linux/quantize b/enjoy/lib/whisper.cpp/x64/linux/quantize deleted file mode 100755 index 2880bb5f..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/linux/quantize and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/bench.exe b/enjoy/lib/whisper.cpp/x64/win32/bench.exe deleted file mode 100644 index dca5fb6f..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/win32/bench.exe and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/main.exe b/enjoy/lib/whisper.cpp/x64/win32/main.exe index 53c47058..65c45db5 100644 Binary files a/enjoy/lib/whisper.cpp/x64/win32/main.exe and b/enjoy/lib/whisper.cpp/x64/win32/main.exe differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/quantize.exe b/enjoy/lib/whisper.cpp/x64/win32/quantize.exe deleted file mode 100644 index 8105373b..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/win32/quantize.exe and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/whisper.dll b/enjoy/lib/whisper.cpp/x64/win32/whisper.dll index addfc979..6f63a22e 100644 Binary files a/enjoy/lib/whisper.cpp/x64/win32/whisper.dll and b/enjoy/lib/whisper.cpp/x64/win32/whisper.dll differ diff --git a/enjoy/package.json b/enjoy/package.json index ba70f41d..d862a648 100644 --- a/enjoy/package.json +++ b/enjoy/package.json @@ -49,20 +49,20 @@ "@types/fluent-ffmpeg": "^2.1.24", "@types/html-to-text": "^9.0.4", "@types/intl-tel-input": "^18.1.4", - "@types/lodash": "^4.17.1", + "@types/lodash": "^4.17.4", "@types/mark.js": "^8.11.12", - "@types/node": "^20.12.11", + "@types/node": "^20.12.12", "@types/react": "^18.3.2", "@types/react-dom": "^18.3.0", "@types/validator": "^13.11.10", "@types/wavesurfer.js": "^6.0.12", - "@typescript-eslint/eslint-plugin": "^7.8.0", - "@typescript-eslint/parser": "^7.8.0", + "@typescript-eslint/eslint-plugin": "^7.9.0", + "@typescript-eslint/parser": "^7.9.0", "@vitejs/plugin-react": "^4.2.1", "autoprefixer": "^10.4.19", - "electron": "^30.0.3", + "electron": "^30.0.6", "electron-playwright-helpers": "^1.7.1", - "eslint": "^9.2.0", + "eslint": "^9.3.0", "eslint-import-resolver-typescript": "^3.6.1", "eslint-plugin-import": "^2.29.1", "flora-colossus": "^2.0.0", @@ -77,14 +77,14 @@ "typescript": "^5.4.5", "vite": "^5.2.11", "vite-plugin-static-copy": "^1.0.4", - "zx": "^8.0.2" + "zx": "^8.1.0" }, "dependencies": { "@andrkrn/ffprobe-static": "^5.2.0", "@electron-forge/publisher-s3": "^7.4.0", - "@hookform/resolvers": "^3.3.4", - "@langchain/community": "^0.0.56", - "@langchain/google-genai": "^0.0.12", + "@hookform/resolvers": "^3.4.0", + "@langchain/community": "^0.2.0", + "@langchain/google-genai": "^0.0.14", "@mozilla/readability": "^0.5.0", "@radix-ui/react-accordion": "^1.1.2", "@radix-ui/react-alert-dialog": "^1.0.5", @@ -117,7 +117,7 @@ "axios": "^1.6.8", "camelcase": "^8.0.0", "camelcase-keys": "^9.1.3", - "chart.js": "^4.4.2", + "chart.js": "^4.4.3", "cheerio": "^1.0.0-rc.12", "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", @@ -133,22 +133,22 @@ "electron-context-menu": "^4.0.0", "electron-log": "^5.1.4", "electron-settings": "^4.0.4", - "electron-squirrel-startup": "^1.0.0", + "electron-squirrel-startup": "^1.0.1", "ffmpeg-static": "^5.2.0", "fluent-ffmpeg": "^2.1.2", "fs-extra": "^11.2.0", "html-to-text": "^9.0.5", "https-proxy-agent": "^7.0.4", "i18next": "^23.11.4", - "intl-tel-input": "^23.0.4", + "intl-tel-input": "^23.0.7", "js-md5": "^0.8.3", - "langchain": "^0.1.37", + "langchain": "^0.2.0", "lodash": "^4.17.21", "lucide-react": "^0.378.0", "mark.js": "^8.11.1", "microsoft-cognitiveservices-speech-sdk": "^1.36.0", "next-themes": "^0.3.0", - "openai": "^4.45.0", + "openai": "^4.47.1", "pitchfinder": "^2.3.2", "postcss": "^8.4.38", "proxy-agent": "^6.4.0", @@ -170,7 +170,7 @@ "tailwind-scrollbar-hide": "^1.1.7", "umzug": "^3.8.0", "update-electron-app": "^3.0.0", - "wavesurfer.js": "^7.7.14", + "wavesurfer.js": "^7.7.15", "zod": "^3.23.8", "zod-to-json-schema": "^3.23.0" } diff --git a/enjoy/src/renderer/hooks/use-conversation.tsx b/enjoy/src/renderer/hooks/use-conversation.tsx index c72352ae..9ee887d6 100644 --- a/enjoy/src/renderer/hooks/use-conversation.tsx +++ b/enjoy/src/renderer/hooks/use-conversation.tsx @@ -3,7 +3,7 @@ import { AISettingsProviderContext, } from "@renderer/context"; import { useContext } from "react"; -import { ChatMessageHistory, BufferMemory } from "langchain/memory"; +import { ChatMessageHistory, BufferMemory } from "langchain/memory/index"; import { ConversationChain } from "langchain/chains"; import { ChatOpenAI } from "@langchain/openai"; import { ChatOllama } from "@langchain/community/chat_models/ollama"; @@ -13,7 +13,7 @@ import { MessagesPlaceholder, } from "@langchain/core/prompts"; import OpenAI from "openai"; -import { type Generation } from "langchain/dist/schema"; +import { type LLMResult } from "@langchain/core/outputs"; import { v4 } from "uuid"; export const useConversation = () => { @@ -163,7 +163,7 @@ export const useConversation = () => { prompt, verbose: true, }); - let response: Generation[] = []; + let response: LLMResult["generations"][0] = []; await chain.call({ input: message.content }, [ { handleLLMEnd: async (output) => { diff --git a/yarn.lock b/yarn.lock index 9c6b8ec9..89f5ba4b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -15,9 +15,9 @@ __metadata: markdown-it-mathjax3: "npm:^4.3.2" markdown-it-sub: "npm:^2.0.0" markdown-it-sup: "npm:^2.0.0" - mermaid: "npm:^10.9.0" - sass: "npm:^1.77.1" - vitepress: "npm:^1.1.4" + mermaid: "npm:^10.9.1" + sass: "npm:^1.77.2" + vitepress: "npm:^1.2.0" vitepress-plugin-mermaid: "npm:^2.0.16" vue: "npm:^3.4.27" languageName: unknown @@ -234,23 +234,6 @@ __metadata: languageName: node linkType: hard -"@anthropic-ai/sdk@npm:^0.9.1": - version: 0.9.1 - resolution: "@anthropic-ai/sdk@npm:0.9.1" - dependencies: - "@types/node": "npm:^18.11.18" - "@types/node-fetch": "npm:^2.6.4" - abort-controller: "npm:^3.0.0" - agentkeepalive: "npm:^4.2.1" - digest-fetch: "npm:^1.3.0" - form-data-encoder: "npm:1.7.2" - formdata-node: "npm:^4.3.2" - node-fetch: "npm:^2.6.7" - web-streams-polyfill: "npm:^3.2.1" - checksum: 10c0/de00ec551a5731a30254e9a3a3ed76f05c0865da6cddf32d4010659f72ca2d25985dff7173af850f92ede1e3695e00105f9689aeb875f0e82dc93f1597be0d05 - languageName: node - linkType: hard - "@aws-crypto/crc32@npm:3.0.0": version: 3.0.0 resolution: "@aws-crypto/crc32@npm:3.0.0" @@ -2729,9 +2712,9 @@ __metadata: languageName: node linkType: hard -"@eslint/eslintrc@npm:^3.0.2": - version: 3.0.2 - resolution: "@eslint/eslintrc@npm:3.0.2" +"@eslint/eslintrc@npm:^3.1.0": + version: 3.1.0 + resolution: "@eslint/eslintrc@npm:3.1.0" dependencies: ajv: "npm:^6.12.4" debug: "npm:^4.3.2" @@ -2742,14 +2725,14 @@ __metadata: js-yaml: "npm:^4.1.0" minimatch: "npm:^3.1.2" strip-json-comments: "npm:^3.1.1" - checksum: 10c0/d8c92f06bdf8e2be9fcc0eeac4a9351745174adfcc72571ef3d179101cb55e19f15f6385c2a4f4945a3ba9245802d3371208e2e1e4f00f6bcf6b8711656af85a + checksum: 10c0/5b7332ed781edcfc98caa8dedbbb843abfb9bda2e86538529c843473f580e40c69eb894410eddc6702f487e9ee8f8cfa8df83213d43a8fdb549f23ce06699167 languageName: node linkType: hard -"@eslint/js@npm:9.2.0": - version: 9.2.0 - resolution: "@eslint/js@npm:9.2.0" - checksum: 10c0/89632466d329d9dd68c6ec24290e407f0950ca8c4b7f3750b82457daa7f6233799ccbc956cd84231f9544efbefddd69833ee82658883ca673cfca9e4b8e0713a +"@eslint/js@npm:9.3.0": + version: 9.3.0 + resolution: "@eslint/js@npm:9.3.0" + checksum: 10c0/1550c68922eb17c3ce541d9ffbb687e656bc0d4e2fff2d9cb0a75101a9cdb6184b7af7378ccf46c6ca750b280d25d45cc5e7374a83a4ee89d404d3717502fd9d languageName: node linkType: hard @@ -2805,12 +2788,12 @@ __metadata: languageName: node linkType: hard -"@hookform/resolvers@npm:^3.3.4": - version: 3.3.4 - resolution: "@hookform/resolvers@npm:3.3.4" +"@hookform/resolvers@npm:^3.4.0": + version: 3.4.0 + resolution: "@hookform/resolvers@npm:3.4.0" peerDependencies: react-hook-form: ^7.0.0 - checksum: 10c0/f7a5b8ba59cbb0859e7a212bd5cbcbc70bbdddd21d4fbe9f4a96d149b5756470cb29857a50334d8c1c64392e21007ccf5288d26aa407431784d4006a2570cb36 + checksum: 10c0/017b2620b2496c0a2aa4f30f16a58e10b05a12fd3a603338c9aae92b20404c30dfaf86fb1f4f873db95f4756707cf5c4b88cd4e22327f2ca9a43daf495946f86 languageName: node linkType: hard @@ -2846,10 +2829,10 @@ __metadata: languageName: node linkType: hard -"@humanwhocodes/retry@npm:^0.2.3": - version: 0.2.3 - resolution: "@humanwhocodes/retry@npm:0.2.3" - checksum: 10c0/0913d89bb5cb1f0a049a6c068dee312d30920d5cce5a07588cd91fcb5453af52f2a9826d07d465066b92ad7bc0545e9f59384c414abe27745c79141c78a25099 +"@humanwhocodes/retry@npm:^0.3.0": + version: 0.3.0 + resolution: "@humanwhocodes/retry@npm:0.3.0" + checksum: 10c0/7111ec4e098b1a428459b4e3be5a5d2a13b02905f805a2468f4fa628d072f0de2da26a27d04f65ea2846f73ba51f4204661709f05bfccff645e3cedef8781bb6 languageName: node linkType: hard @@ -2935,14 +2918,17 @@ __metadata: languageName: node linkType: hard -"@langchain/community@npm:^0.0.56": - version: 0.0.56 - resolution: "@langchain/community@npm:0.0.56" +"@langchain/community@npm:^0.2.0": + version: 0.2.0 + resolution: "@langchain/community@npm:0.2.0" dependencies: - "@langchain/core": "npm:~0.1.60" + "@langchain/core": "npm:~0.2.0" "@langchain/openai": "npm:~0.0.28" + binary-extensions: "npm:^2.2.0" expr-eval: "npm:^2.0.2" flat: "npm:^5.0.2" + js-yaml: "npm:^4.1.0" + langchain: "npm:~0.2.0" langsmith: "npm:~0.1.1" uuid: "npm:^9.0.0" zod: "npm:^3.22.3" @@ -2954,10 +2940,12 @@ __metadata: "@aws-sdk/client-dynamodb": ^3.310.0 "@aws-sdk/client-kendra": ^3.352.0 "@aws-sdk/client-lambda": ^3.310.0 + "@aws-sdk/client-s3": ^3.310.0 "@aws-sdk/client-sagemaker-runtime": ^3.310.0 "@aws-sdk/client-sfn": ^3.310.0 "@aws-sdk/credential-provider-node": ^3.388.0 "@azure/search-documents": ^12.0.0 + "@azure/storage-blob": ^12.15.0 "@clickhouse/client": ^0.2.5 "@cloudflare/ai": "*" "@datastax/astra-db-ts": ^1.0.0 @@ -2967,10 +2955,14 @@ __metadata: "@gomomento/sdk": ^1.51.1 "@gomomento/sdk-core": ^1.51.1 "@google-ai/generativelanguage": ^0.2.1 + "@google-cloud/storage": ^6.10.1 || ^7.7.0 "@gradientai/nodejs-sdk": ^1.2.0 "@huggingface/inference": ^2.6.4 + "@mendable/firecrawl-js": ^0.0.13 + "@mlc-ai/web-llm": ^0.2.35 "@mozilla/readability": "*" "@neondatabase/serverless": "*" + "@notionhq/client": ^2.2.10 "@opensearch-project/opensearch": "*" "@pinecone-database/pinecone": "*" "@planetscale/database": ^1.8.0 @@ -2995,9 +2987,12 @@ __metadata: "@xata.io/client": ^0.28.0 "@xenova/transformers": ^2.5.4 "@zilliz/milvus2-sdk-node": ">=2.2.7" + apify-client: ^2.7.1 + assemblyai: ^4.0.0 better-sqlite3: ^9.4.0 cassandra-driver: ^4.7.2 cborg: ^4.1.1 + cheerio: ^1.0.0-rc.12 chromadb: "*" closevector-common: 0.1.3 closevector-node: 0.1.6 @@ -3005,15 +3000,18 @@ __metadata: cohere-ai: "*" convex: ^1.3.1 couchbase: ^4.3.0 + d3-dsv: ^2.0.0 discord.js: ^14.14.1 dria: ^0.0.3 duck-duck-scrape: ^2.2.5 + epub2: ^3.0.1 faiss-node: ^0.5.1 firebase-admin: ^11.9.0 || ^12.0.0 google-auth-library: ^8.9.0 googleapis: ^126.0.1 hnswlib-node: ^3.0.0 html-to-text: ^9.0.5 + ignore: ^5.2.0 interface-datastore: ^8.2.11 ioredis: ^5.3.2 it-all: ^3.0.4 @@ -3022,16 +3020,24 @@ __metadata: llmonitor: ^0.5.9 lodash: ^4.17.21 lunary: ^0.6.11 + mammoth: ^1.6.0 mongodb: ">=5.2.0" mysql2: ^3.3.3 neo4j-driver: "*" node-llama-cpp: "*" + notion-to-md: ^3.1.0 + officeparser: ^4.0.4 + pdf-parse: 1.1.1 pg: ^8.11.0 pg-copy-streams: ^6.0.5 pickleparser: ^0.2.1 + playwright: ^1.32.1 portkey-ai: ^0.1.11 + puppeteer: ^19.7.2 redis: "*" replicate: ^0.18.0 + sonix-speech-recognition: ^2.1.1 + srt-parser-2: ^1.2.3 typeorm: ^0.3.12 typesense: ^1.5.3 usearch: ^1.1.1 @@ -3040,6 +3046,8 @@ __metadata: weaviate-ts-client: "*" web-auth-library: ^1.0.3 ws: ^8.14.2 + youtube-transcript: ^1.0.6 + youtubei.js: ^9.1.0 peerDependenciesMeta: "@aws-crypto/sha256-js": optional: true @@ -3053,6 +3061,8 @@ __metadata: optional: true "@aws-sdk/client-lambda": optional: true + "@aws-sdk/client-s3": + optional: true "@aws-sdk/client-sagemaker-runtime": optional: true "@aws-sdk/client-sfn": @@ -3061,6 +3071,8 @@ __metadata: optional: true "@azure/search-documents": optional: true + "@azure/storage-blob": + optional: true "@clickhouse/client": optional: true "@cloudflare/ai": @@ -3079,14 +3091,22 @@ __metadata: optional: true "@google-ai/generativelanguage": optional: true + "@google-cloud/storage": + optional: true "@gradientai/nodejs-sdk": optional: true "@huggingface/inference": optional: true + "@mendable/firecrawl-js": + optional: true + "@mlc-ai/web-llm": + optional: true "@mozilla/readability": optional: true "@neondatabase/serverless": optional: true + "@notionhq/client": + optional: true "@opensearch-project/opensearch": optional: true "@pinecone-database/pinecone": @@ -3135,12 +3155,18 @@ __metadata: optional: true "@zilliz/milvus2-sdk-node": optional: true + apify-client: + optional: true + assemblyai: + optional: true better-sqlite3: optional: true cassandra-driver: optional: true cborg: optional: true + cheerio: + optional: true chromadb: optional: true closevector-common: @@ -3155,12 +3181,16 @@ __metadata: optional: true couchbase: optional: true + d3-dsv: + optional: true discord.js: optional: true dria: optional: true duck-duck-scrape: optional: true + epub2: + optional: true faiss-node: optional: true firebase-admin: @@ -3173,6 +3203,8 @@ __metadata: optional: true html-to-text: optional: true + ignore: + optional: true interface-datastore: optional: true ioredis: @@ -3189,6 +3221,8 @@ __metadata: optional: true lunary: optional: true + mammoth: + optional: true mongodb: optional: true mysql2: @@ -3197,18 +3231,32 @@ __metadata: optional: true node-llama-cpp: optional: true + notion-to-md: + optional: true + officeparser: + optional: true + pdf-parse: + optional: true pg: optional: true pg-copy-streams: optional: true pickleparser: optional: true + playwright: + optional: true portkey-ai: optional: true + puppeteer: + optional: true redis: optional: true replicate: optional: true + sonix-speech-recognition: + optional: true + srt-parser-2: + optional: true typeorm: optional: true typesense: @@ -3225,302 +3273,35 @@ __metadata: optional: true ws: optional: true - checksum: 10c0/2206682c3bed853c38807b0cf8916b72bc013cab4f7635c7c45340ced1e58dabf697339e78cc304117067249e2f634234323f94238e0f751f623cc1e6742719a + youtube-transcript: + optional: true + youtubei.js: + optional: true + checksum: 10c0/b69c85e396307408c29f1515ffba560a0806dc1f0f60df35d8c314f98e30d8d2311f25d0dd88745eea1452e320f66f739919fa620e85815a8acd741d5af40963 languageName: node linkType: hard -"@langchain/community@npm:~0.0.47": - version: 0.0.47 - resolution: "@langchain/community@npm:0.0.47" +"@langchain/core@npm:>0.1.5 <0.3.0, @langchain/core@npm:~0.2.0": + version: 0.2.0 + resolution: "@langchain/core@npm:0.2.0" dependencies: - "@langchain/core": "npm:~0.1.56" - "@langchain/openai": "npm:~0.0.28" - expr-eval: "npm:^2.0.2" - flat: "npm:^5.0.2" - langsmith: "npm:~0.1.1" + ansi-styles: "npm:^5.0.0" + camelcase: "npm:6" + decamelize: "npm:1.2.0" + js-tiktoken: "npm:^1.0.12" + langsmith: "npm:~0.1.7" + ml-distance: "npm:^4.0.0" + mustache: "npm:^4.2.0" + p-queue: "npm:^6.6.2" + p-retry: "npm:4" uuid: "npm:^9.0.0" - zod: "npm:^3.22.3" - zod-to-json-schema: "npm:^3.22.5" - peerDependencies: - "@aws-crypto/sha256-js": ^5.0.0 - "@aws-sdk/client-bedrock-agent-runtime": ^3.485.0 - "@aws-sdk/client-bedrock-runtime": ^3.422.0 - "@aws-sdk/client-dynamodb": ^3.310.0 - "@aws-sdk/client-kendra": ^3.352.0 - "@aws-sdk/client-lambda": ^3.310.0 - "@aws-sdk/client-sagemaker-runtime": ^3.310.0 - "@aws-sdk/client-sfn": ^3.310.0 - "@aws-sdk/credential-provider-node": ^3.388.0 - "@azure/search-documents": ^12.0.0 - "@clickhouse/client": ^0.2.5 - "@cloudflare/ai": "*" - "@datastax/astra-db-ts": ^0.1.4 - "@elastic/elasticsearch": ^8.4.0 - "@getmetal/metal-sdk": "*" - "@getzep/zep-js": ^0.9.0 - "@gomomento/sdk": ^1.51.1 - "@gomomento/sdk-core": ^1.51.1 - "@google-ai/generativelanguage": ^0.2.1 - "@gradientai/nodejs-sdk": ^1.2.0 - "@huggingface/inference": ^2.6.4 - "@mozilla/readability": "*" - "@opensearch-project/opensearch": "*" - "@pinecone-database/pinecone": "*" - "@planetscale/database": ^1.8.0 - "@premai/prem-sdk": ^0.3.25 - "@qdrant/js-client-rest": ^1.2.0 - "@raycast/api": ^1.55.2 - "@rockset/client": ^0.9.1 - "@smithy/eventstream-codec": ^2.0.5 - "@smithy/protocol-http": ^3.0.6 - "@smithy/signature-v4": ^2.0.10 - "@smithy/util-utf8": ^2.0.0 - "@supabase/postgrest-js": ^1.1.1 - "@supabase/supabase-js": ^2.10.0 - "@tensorflow-models/universal-sentence-encoder": "*" - "@tensorflow/tfjs-converter": "*" - "@tensorflow/tfjs-core": "*" - "@upstash/redis": ^1.20.6 - "@upstash/vector": ^1.0.2 - "@vercel/kv": ^0.2.3 - "@vercel/postgres": ^0.5.0 - "@writerai/writer-sdk": ^0.40.2 - "@xata.io/client": ^0.28.0 - "@xenova/transformers": ^2.5.4 - "@zilliz/milvus2-sdk-node": ">=2.2.7" - better-sqlite3: ^9.4.0 - cassandra-driver: ^4.7.2 - cborg: ^4.1.1 - chromadb: "*" - closevector-common: 0.1.3 - closevector-node: 0.1.6 - closevector-web: 0.1.6 - cohere-ai: "*" - convex: ^1.3.1 - couchbase: ^4.3.0 - discord.js: ^14.14.1 - dria: ^0.0.3 - duck-duck-scrape: ^2.2.5 - faiss-node: ^0.5.1 - firebase-admin: ^11.9.0 || ^12.0.0 - google-auth-library: ^8.9.0 - googleapis: ^126.0.1 - hnswlib-node: ^1.4.2 - html-to-text: ^9.0.5 - interface-datastore: ^8.2.11 - ioredis: ^5.3.2 - it-all: ^3.0.4 - jsdom: "*" - jsonwebtoken: ^9.0.2 - llmonitor: ^0.5.9 - lodash: ^4.17.21 - lunary: ^0.6.11 - mongodb: ">=5.2.0" - mysql2: ^3.3.3 - neo4j-driver: "*" - node-llama-cpp: "*" - pg: ^8.11.0 - pg-copy-streams: ^6.0.5 - pickleparser: ^0.2.1 - portkey-ai: ^0.1.11 - redis: "*" - replicate: ^0.18.0 - typeorm: ^0.3.12 - typesense: ^1.5.3 - usearch: ^1.1.1 - vectordb: ^0.1.4 - voy-search: 0.6.2 - weaviate-ts-client: "*" - web-auth-library: ^1.0.3 - ws: ^8.14.2 - peerDependenciesMeta: - "@aws-crypto/sha256-js": - optional: true - "@aws-sdk/client-bedrock-agent-runtime": - optional: true - "@aws-sdk/client-bedrock-runtime": - optional: true - "@aws-sdk/client-dynamodb": - optional: true - "@aws-sdk/client-kendra": - optional: true - "@aws-sdk/client-lambda": - optional: true - "@aws-sdk/client-sagemaker-runtime": - optional: true - "@aws-sdk/client-sfn": - optional: true - "@aws-sdk/credential-provider-node": - optional: true - "@azure/search-documents": - optional: true - "@clickhouse/client": - optional: true - "@cloudflare/ai": - optional: true - "@datastax/astra-db-ts": - optional: true - "@elastic/elasticsearch": - optional: true - "@getmetal/metal-sdk": - optional: true - "@getzep/zep-js": - optional: true - "@gomomento/sdk": - optional: true - "@gomomento/sdk-core": - optional: true - "@google-ai/generativelanguage": - optional: true - "@gradientai/nodejs-sdk": - optional: true - "@huggingface/inference": - optional: true - "@mozilla/readability": - optional: true - "@opensearch-project/opensearch": - optional: true - "@pinecone-database/pinecone": - optional: true - "@planetscale/database": - optional: true - "@premai/prem-sdk": - optional: true - "@qdrant/js-client-rest": - optional: true - "@raycast/api": - optional: true - "@rockset/client": - optional: true - "@smithy/eventstream-codec": - optional: true - "@smithy/protocol-http": - optional: true - "@smithy/signature-v4": - optional: true - "@smithy/util-utf8": - optional: true - "@supabase/postgrest-js": - optional: true - "@supabase/supabase-js": - optional: true - "@tensorflow-models/universal-sentence-encoder": - optional: true - "@tensorflow/tfjs-converter": - optional: true - "@tensorflow/tfjs-core": - optional: true - "@upstash/redis": - optional: true - "@upstash/vector": - optional: true - "@vercel/kv": - optional: true - "@vercel/postgres": - optional: true - "@writerai/writer-sdk": - optional: true - "@xata.io/client": - optional: true - "@xenova/transformers": - optional: true - "@zilliz/milvus2-sdk-node": - optional: true - better-sqlite3: - optional: true - cassandra-driver: - optional: true - cborg: - optional: true - chromadb: - optional: true - closevector-common: - optional: true - closevector-node: - optional: true - closevector-web: - optional: true - cohere-ai: - optional: true - convex: - optional: true - couchbase: - optional: true - discord.js: - optional: true - dria: - optional: true - duck-duck-scrape: - optional: true - faiss-node: - optional: true - firebase-admin: - optional: true - google-auth-library: - optional: true - googleapis: - optional: true - hnswlib-node: - optional: true - html-to-text: - optional: true - interface-datastore: - optional: true - ioredis: - optional: true - it-all: - optional: true - jsdom: - optional: true - jsonwebtoken: - optional: true - llmonitor: - optional: true - lodash: - optional: true - lunary: - optional: true - mongodb: - optional: true - mysql2: - optional: true - neo4j-driver: - optional: true - node-llama-cpp: - optional: true - pg: - optional: true - pg-copy-streams: - optional: true - pickleparser: - optional: true - portkey-ai: - optional: true - redis: - optional: true - replicate: - optional: true - typeorm: - optional: true - typesense: - optional: true - usearch: - optional: true - vectordb: - optional: true - voy-search: - optional: true - weaviate-ts-client: - optional: true - web-auth-library: - optional: true - ws: - optional: true - checksum: 10c0/564842266ef71d5c138d01539cb301edbb12b77eaa2c809705b1fe7c28134c260d200e4bfe501b09a3b613c0c7c5574146d211013f8acf91506ac1fce634ab26 + zod: "npm:^3.22.4" + zod-to-json-schema: "npm:^3.22.3" + checksum: 10c0/408ecc311635de1ce51fef8123e1a465a8b6a4fda8f997ac306087d9f29855cb3cd96a86f295a7708dc416bf1c2583b68bf9cb7a6bf95b4c0453f9f3d528dbe8 languageName: node linkType: hard -"@langchain/core@npm:~0.1, @langchain/core@npm:~0.1.60": +"@langchain/core@npm:~0.1": version: 0.1.61 resolution: "@langchain/core@npm:0.1.61" dependencies: @@ -3540,25 +3321,6 @@ __metadata: languageName: node linkType: hard -"@langchain/core@npm:~0.1.5": - version: 0.1.40 - resolution: "@langchain/core@npm:0.1.40" - dependencies: - ansi-styles: "npm:^5.0.0" - camelcase: "npm:6" - decamelize: "npm:1.2.0" - js-tiktoken: "npm:^1.0.8" - langsmith: "npm:~0.1.7" - ml-distance: "npm:^4.0.0" - p-queue: "npm:^6.6.2" - p-retry: "npm:4" - uuid: "npm:^9.0.0" - zod: "npm:^3.22.4" - zod-to-json-schema: "npm:^3.22.3" - checksum: 10c0/dc2880bcf185f6fbe7838179bd5ae04144628213c646fcb62b092426e2a590665ae266482ee5efe5abf45fa21cd7127f9a2c398211a08d29290809f1f36d8176 - languageName: node - linkType: hard - "@langchain/core@npm:~0.1.56": version: 0.1.57 resolution: "@langchain/core@npm:0.1.57" @@ -3578,13 +3340,13 @@ __metadata: languageName: node linkType: hard -"@langchain/google-genai@npm:^0.0.12": - version: 0.0.12 - resolution: "@langchain/google-genai@npm:0.0.12" +"@langchain/google-genai@npm:^0.0.14": + version: 0.0.14 + resolution: "@langchain/google-genai@npm:0.0.14" dependencies: "@google/generative-ai": "npm:^0.7.0" - "@langchain/core": "npm:~0.1.5" - checksum: 10c0/d2849b9441590199374eb74e4c0c7a9cfdf365ad2745907c928476f48e85f7207a148cf0311ced777423634c844b43d454e0dc81283936eba4c7d486766c4159 + "@langchain/core": "npm:>0.1.5 <0.3.0" + checksum: 10c0/2fbf4fa1e435fb77ab78d3c662e1531548704b5ce26530a344fbacbd9435a52b107bc47004bd1e2357a8accf8053b51615edc45751f3b12156a7863b2945eef8 languageName: node linkType: hard @@ -5556,19 +5318,19 @@ __metadata: languageName: node linkType: hard -"@shikijs/core@npm:1.3.0, @shikijs/core@npm:^1.3.0": - version: 1.3.0 - resolution: "@shikijs/core@npm:1.3.0" - checksum: 10c0/0b11f11961a563ef91122f2a17b66af7502b0f8665f53c435aaadf39006722466379fb1acc5ed64d7d381921f6900a13647b27d0bf3fb0d619471a8bbfa34a9b +"@shikijs/core@npm:1.6.0, @shikijs/core@npm:^1.5.2": + version: 1.6.0 + resolution: "@shikijs/core@npm:1.6.0" + checksum: 10c0/017c8ff1264edf4d9cb9d556b8e6614eab1ccc11d74b89eca7afb39b6c72072e34d32a97b4fb7828acc40e4f38162ae049e287fd66830891eb55e983d6940bd6 languageName: node linkType: hard -"@shikijs/transformers@npm:^1.3.0": - version: 1.3.0 - resolution: "@shikijs/transformers@npm:1.3.0" +"@shikijs/transformers@npm:^1.5.2": + version: 1.6.0 + resolution: "@shikijs/transformers@npm:1.6.0" dependencies: - shiki: "npm:1.3.0" - checksum: 10c0/c2ff13d48ea7e5c83dbde7cfd8c1147445442b21ddf506a8463c0f53f5a81beacdc7cd0ea88ae4123477d3d2e9ae88864e4172da1b807cc851e3c356b5919393 + shiki: "npm:1.6.0" + checksum: 10c0/c204a6d8eeadf0c7b8dda09617d5eae9cbe08c13763a8efe759cf1d5681a249ed6dd104ddaa770dff05f06641329a7b6c94e5dd7ebe9ed726652213f0e717904 languageName: node linkType: hard @@ -6969,13 +6731,6 @@ __metadata: languageName: node linkType: hard -"@types/json-schema@npm:^7.0.15": - version: 7.0.15 - resolution: "@types/json-schema@npm:7.0.15" - checksum: 10c0/a996a745e6c5d60292f36731dd41341339d4eeed8180bb09226e5c8d23759067692b1d88e5d91d72ee83dfc00d3aca8e7bd43ea120516c17922cbcb7c3e252db - languageName: node - linkType: hard - "@types/json5@npm:^0.0.29": version: 0.0.29 resolution: "@types/json5@npm:0.0.29" @@ -7001,17 +6756,17 @@ __metadata: languageName: node linkType: hard -"@types/linkify-it@npm:*": - version: 3.0.5 - resolution: "@types/linkify-it@npm:3.0.5" - checksum: 10c0/696e09975991c649ba37c5585714929fdebf5c64a8bfb99910613ef838337dbbba6c608fccdfa03d6347432586ef12e139bc0e947ae6fec569096fef5cc1c550 +"@types/linkify-it@npm:^5": + version: 5.0.0 + resolution: "@types/linkify-it@npm:5.0.0" + checksum: 10c0/7bbbf45b9dde17bf3f184fee585aef0e7342f6954f0377a24e4ff42ab5a85d5b806aaa5c8d16e2faf2a6b87b2d94467a196b7d2b85c9c7de2f0eaac5487aaab8 languageName: node linkType: hard -"@types/lodash@npm:^4.17.1": - version: 4.17.1 - resolution: "@types/lodash@npm:4.17.1" - checksum: 10c0/af2ad8a3c8d7deb170a7ec6e18afc5ae8980576e5f7fe798d8a95a1df7222c15bdf967a25a35879f575a3b64743de00145710ee461a0051e055e94e4fe253f45 +"@types/lodash@npm:^4.17.4": + version: 4.17.4 + resolution: "@types/lodash@npm:4.17.4" + checksum: 10c0/0124c64cb9fe7a0f78b6777955abd05ef0d97844d49118652eae45f8fa57bfb7f5a7a9bccc0b5a84c0a6dc09631042e4590cb665acb9d58dfd5e6543c75341ec languageName: node linkType: hard @@ -7024,13 +6779,13 @@ __metadata: languageName: node linkType: hard -"@types/markdown-it@npm:^14.0.1": - version: 14.0.1 - resolution: "@types/markdown-it@npm:14.0.1" +"@types/markdown-it@npm:^14.1.1": + version: 14.1.1 + resolution: "@types/markdown-it@npm:14.1.1" dependencies: - "@types/linkify-it": "npm:*" - "@types/mdurl": "npm:*" - checksum: 10c0/d1f9fde06f6511d4518040fac296d14f2b3c42e00bb955545ef0db4e7063208f60ff7fc766305f66b539de4393f027dcdc83db53f5c0672ffd97c04409437b2c + "@types/linkify-it": "npm:^5" + "@types/mdurl": "npm:^2" + checksum: 10c0/f029e471d6927c78879de8de085302544cd43643e2435775fbe0d49be998d4a23890555b0c090c154925e3a9df01ba48514582671e4f8488d8dcf992e2d155ef languageName: node linkType: hard @@ -7052,10 +6807,10 @@ __metadata: languageName: node linkType: hard -"@types/mdurl@npm:*": - version: 1.0.5 - resolution: "@types/mdurl@npm:1.0.5" - checksum: 10c0/8991c781eb94fb3621e48e191251a94057908fc14be60f52bdd7c48684af923ffa77559ea979450a0475f85c08f8a472f99ff9c2ca4308961b9b9d35fd7584f7 +"@types/mdurl@npm:^2": + version: 2.0.0 + resolution: "@types/mdurl@npm:2.0.0" + checksum: 10c0/cde7bb571630ed1ceb3b92a28f7b59890bb38b8f34cd35326e2df43eebfc74985e6aa6fd4184e307393bad8a9e0783a519a3f9d13c8e03788c0f98e5ec869c5e languageName: node linkType: hard @@ -7099,12 +6854,12 @@ __metadata: languageName: node linkType: hard -"@types/node@npm:>=20.12.5": - version: 20.12.5 - resolution: "@types/node@npm:20.12.5" +"@types/node@npm:>=20.12.11, @types/node@npm:^20.12.12": + version: 20.12.12 + resolution: "@types/node@npm:20.12.12" dependencies: undici-types: "npm:~5.26.4" - checksum: 10c0/2da65516fba98f0417620e42bddbe53e144d4782d69cd37f99df2537c6850b9cfbdb8a017f02c61e9a074bcac84f9f3f221b250474ac8c6b95d507a47e8d53f9 + checksum: 10c0/f374b763c744e8f16e4f38cf6e2c0eef31781ec9228c9e43a6f267880fea420fab0a238b59f10a7cb3444e49547c5e3785787e371fc242307310995b21988812 languageName: node linkType: hard @@ -7124,15 +6879,6 @@ __metadata: languageName: node linkType: hard -"@types/node@npm:^20.12.11": - version: 20.12.11 - resolution: "@types/node@npm:20.12.11" - dependencies: - undici-types: "npm:~5.26.4" - checksum: 10c0/efaa7b08c50ba6e6982ac7d531ba08d5935539ba76e954807df1ff9382a319ead7e003ccaca5ad7cf33ecf53f212507f7c1f2794bd2ff52df6365fef21f6e1fb - languageName: node - linkType: hard - "@types/normalize-package-data@npm:^2.4.0": version: 2.4.4 resolution: "@types/normalize-package-data@npm:2.4.4" @@ -7200,13 +6946,6 @@ __metadata: languageName: node linkType: hard -"@types/semver@npm:^7.5.8": - version: 7.5.8 - resolution: "@types/semver@npm:7.5.8" - checksum: 10c0/8663ff927234d1c5fcc04b33062cb2b9fcfbe0f5f351ed26c4d1e1581657deebd506b41ff7fdf89e787e3d33ce05854bc01686379b89e9c49b564c4cfa988efa - languageName: node - linkType: hard - "@types/sizzle@npm:*": version: 2.3.8 resolution: "@types/sizzle@npm:2.3.8" @@ -7281,20 +7020,18 @@ __metadata: languageName: node linkType: hard -"@typescript-eslint/eslint-plugin@npm:^7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/eslint-plugin@npm:7.8.0" +"@typescript-eslint/eslint-plugin@npm:^7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/eslint-plugin@npm:7.9.0" dependencies: "@eslint-community/regexpp": "npm:^4.10.0" - "@typescript-eslint/scope-manager": "npm:7.8.0" - "@typescript-eslint/type-utils": "npm:7.8.0" - "@typescript-eslint/utils": "npm:7.8.0" - "@typescript-eslint/visitor-keys": "npm:7.8.0" - debug: "npm:^4.3.4" + "@typescript-eslint/scope-manager": "npm:7.9.0" + "@typescript-eslint/type-utils": "npm:7.9.0" + "@typescript-eslint/utils": "npm:7.9.0" + "@typescript-eslint/visitor-keys": "npm:7.9.0" graphemer: "npm:^1.4.0" ignore: "npm:^5.3.1" natural-compare: "npm:^1.4.0" - semver: "npm:^7.6.0" ts-api-utils: "npm:^1.3.0" peerDependencies: "@typescript-eslint/parser": ^7.0.0 @@ -7302,44 +7039,44 @@ __metadata: peerDependenciesMeta: typescript: optional: true - checksum: 10c0/37ca22620d1834ff0baa28fa4b8fd92039a3903cb95748353de32d56bae2a81ce50d1bbaed27487eebc884e0a0f9387fcb0f1647593e4e6df5111ef674afa9f0 + checksum: 10c0/5c0ded9cb2210c141d236075f01a86447bf497a5061773c3c64a90756264776b4c4df100f7588e36d34f727eca55afd52fe6696a3cbe2d1f131250934254603a languageName: node linkType: hard -"@typescript-eslint/parser@npm:^7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/parser@npm:7.8.0" +"@typescript-eslint/parser@npm:^7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/parser@npm:7.9.0" dependencies: - "@typescript-eslint/scope-manager": "npm:7.8.0" - "@typescript-eslint/types": "npm:7.8.0" - "@typescript-eslint/typescript-estree": "npm:7.8.0" - "@typescript-eslint/visitor-keys": "npm:7.8.0" + "@typescript-eslint/scope-manager": "npm:7.9.0" + "@typescript-eslint/types": "npm:7.9.0" + "@typescript-eslint/typescript-estree": "npm:7.9.0" + "@typescript-eslint/visitor-keys": "npm:7.9.0" debug: "npm:^4.3.4" peerDependencies: eslint: ^8.56.0 peerDependenciesMeta: typescript: optional: true - checksum: 10c0/0dd994c1b31b810c25e1b755b8d352debb7bf21a31f9a91acaec34acf4e471320bcceaa67cf64c110c0b8f5fac10a037dbabac6ec423e17adf037e59a7bce9c1 + checksum: 10c0/16ca04645429436d9b7986cddda979ef4d088f4223f4a69e04a369e0fd4852dd5ff3d4b99da2e43cddaa2b421b24ff42f275d87bd110ae2356bdd0e81c2534e7 languageName: node linkType: hard -"@typescript-eslint/scope-manager@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/scope-manager@npm:7.8.0" +"@typescript-eslint/scope-manager@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/scope-manager@npm:7.9.0" dependencies: - "@typescript-eslint/types": "npm:7.8.0" - "@typescript-eslint/visitor-keys": "npm:7.8.0" - checksum: 10c0/c253b98e96d4bf0375f473ca2c4d081726f1fd926cdfa65ee14c9ee99cca8eddb763b2d238ac365daa7246bef21b0af38180d04e56e9df7443c0e6f8474d097c + "@typescript-eslint/types": "npm:7.9.0" + "@typescript-eslint/visitor-keys": "npm:7.9.0" + checksum: 10c0/1ba6fc559a42a9b54e38c3ac2b6669efcff1a30292fb4e5fc8739c890a6c0f37d1a6aee1d115198f57c88e4f1776e95c1d7143de5cb5b970d5eb3023e97789dd languageName: node linkType: hard -"@typescript-eslint/type-utils@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/type-utils@npm:7.8.0" +"@typescript-eslint/type-utils@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/type-utils@npm:7.9.0" dependencies: - "@typescript-eslint/typescript-estree": "npm:7.8.0" - "@typescript-eslint/utils": "npm:7.8.0" + "@typescript-eslint/typescript-estree": "npm:7.9.0" + "@typescript-eslint/utils": "npm:7.9.0" debug: "npm:^4.3.4" ts-api-utils: "npm:^1.3.0" peerDependencies: @@ -7347,23 +7084,23 @@ __metadata: peerDependenciesMeta: typescript: optional: true - checksum: 10c0/00f6315626b64f7dbc1f7fba6f365321bb8d34141ed77545b2a07970e59a81dbdf768c1e024225ea00953750d74409ddd8a16782fc4a39261e507c04192dacab + checksum: 10c0/775280fb179268f8bacd60e684d9d5a1c6a379646b082c7244bf2dfb7dd693053bd9efa473b71e10a86db69322b0a2cecf5598d019684930df50000bf3d70af0 languageName: node linkType: hard -"@typescript-eslint/types@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/types@npm:7.8.0" - checksum: 10c0/b2fdbfc21957bfa46f7d8809b607ad8c8b67c51821d899064d09392edc12f28b2318a044f0cd5d523d782e84e8f0558778877944964cf38e139f88790cf9d466 +"@typescript-eslint/types@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/types@npm:7.9.0" + checksum: 10c0/d5f4a547dba4865ee2391bf06f2b3f8e8592a561976d2be35bb61ce340c7d1b7b4b25ac6ab5b9941813b465b9420bebb7b2179b1d71f6a83069feeb000b3558d languageName: node linkType: hard -"@typescript-eslint/typescript-estree@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/typescript-estree@npm:7.8.0" +"@typescript-eslint/typescript-estree@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/typescript-estree@npm:7.9.0" dependencies: - "@typescript-eslint/types": "npm:7.8.0" - "@typescript-eslint/visitor-keys": "npm:7.8.0" + "@typescript-eslint/types": "npm:7.9.0" + "@typescript-eslint/visitor-keys": "npm:7.9.0" debug: "npm:^4.3.4" globby: "npm:^11.1.0" is-glob: "npm:^4.0.3" @@ -7373,34 +7110,31 @@ __metadata: peerDependenciesMeta: typescript: optional: true - checksum: 10c0/1690b62679685073dcb0f62499f0b52b445b37ae6e12d02aa4acbafe3fb023cf999b01f714b6282e88f84fd934fe3e2eefb21a64455d19c348d22bbc68ca8e47 + checksum: 10c0/cfc3d2b7a5433c9a2989c7289bc72b49786993782801ad8ca5a07c651df457a67fbce13b120c86c34c03d56570a90e5cf4f3b8806349f103a3658f2366ec28ea languageName: node linkType: hard -"@typescript-eslint/utils@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/utils@npm:7.8.0" +"@typescript-eslint/utils@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/utils@npm:7.9.0" dependencies: "@eslint-community/eslint-utils": "npm:^4.4.0" - "@types/json-schema": "npm:^7.0.15" - "@types/semver": "npm:^7.5.8" - "@typescript-eslint/scope-manager": "npm:7.8.0" - "@typescript-eslint/types": "npm:7.8.0" - "@typescript-eslint/typescript-estree": "npm:7.8.0" - semver: "npm:^7.6.0" + "@typescript-eslint/scope-manager": "npm:7.9.0" + "@typescript-eslint/types": "npm:7.9.0" + "@typescript-eslint/typescript-estree": "npm:7.9.0" peerDependencies: eslint: ^8.56.0 - checksum: 10c0/31fb58388d15b082eb7bd5bce889cc11617aa1131dfc6950471541b3df64c82d1c052e2cccc230ca4ae80456d4f63a3e5dccb79899a8f3211ce36c089b7d7640 + checksum: 10c0/cb99d6a950e7da0319bc7b923a82c52c0798a14e837afee51b2295cfbde02e0a2ac8e0b5904cd7bd01d1b376c7a6ad3739101b486feaf2517c8640024deb88c7 languageName: node linkType: hard -"@typescript-eslint/visitor-keys@npm:7.8.0": - version: 7.8.0 - resolution: "@typescript-eslint/visitor-keys@npm:7.8.0" +"@typescript-eslint/visitor-keys@npm:7.9.0": + version: 7.9.0 + resolution: "@typescript-eslint/visitor-keys@npm:7.9.0" dependencies: - "@typescript-eslint/types": "npm:7.8.0" + "@typescript-eslint/types": "npm:7.9.0" eslint-visitor-keys: "npm:^3.4.3" - checksum: 10c0/5892fb5d9c58efaf89adb225f7dbbb77f9363961f2ff420b6b130bdd102dddd7aa8a16c46a5a71c19889d27b781e966119a89270555ea2cb5653a04d8994123d + checksum: 10c0/19181d8b9d2d7bc43d5c8884661cd9a86ac316392b8e590187cc507442093a1ba2bef0cc22181b8298d5dc9f455abb73cffa4663451bdf32b1b7fe12160c5c99 languageName: node linkType: hard @@ -7458,19 +7192,6 @@ __metadata: languageName: node linkType: hard -"@vue/compiler-core@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/compiler-core@npm:3.4.26" - dependencies: - "@babel/parser": "npm:^7.24.4" - "@vue/shared": "npm:3.4.26" - entities: "npm:^4.5.0" - estree-walker: "npm:^2.0.2" - source-map-js: "npm:^1.2.0" - checksum: 10c0/3591fda1b07e929b2b83eaaee67c21075d0c7676ecb8d65c9d800d0b2db5afcf1c44a98fd7ac3d00c737de5462a74fe6788672f739a998769244dfb319bacaae - languageName: node - linkType: hard - "@vue/compiler-core@npm:3.4.27": version: 3.4.27 resolution: "@vue/compiler-core@npm:3.4.27" @@ -7484,16 +7205,6 @@ __metadata: languageName: node linkType: hard -"@vue/compiler-dom@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/compiler-dom@npm:3.4.26" - dependencies: - "@vue/compiler-core": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - checksum: 10c0/3b39c5586ba0029ccef8cd6d7f7b98e1522ca8f52edba056bc46a8ec27764a6fdec2619bbe089484586cf440c07249631c9c18efe91633306cf0047f4d76d72f - languageName: node - linkType: hard - "@vue/compiler-dom@npm:3.4.27": version: 3.4.27 resolution: "@vue/compiler-dom@npm:3.4.27" @@ -7504,23 +7215,6 @@ __metadata: languageName: node linkType: hard -"@vue/compiler-sfc@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/compiler-sfc@npm:3.4.26" - dependencies: - "@babel/parser": "npm:^7.24.4" - "@vue/compiler-core": "npm:3.4.26" - "@vue/compiler-dom": "npm:3.4.26" - "@vue/compiler-ssr": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - estree-walker: "npm:^2.0.2" - magic-string: "npm:^0.30.10" - postcss: "npm:^8.4.38" - source-map-js: "npm:^1.2.0" - checksum: 10c0/1a00f293c5465e23d7915c3a8a98c4d97479e167b00638c5985c2fb7539762affc90d258230f95ee705989ad9a04b001af146b905143a42b52bc5b84a36e9360 - languageName: node - linkType: hard - "@vue/compiler-sfc@npm:3.4.27": version: 3.4.27 resolution: "@vue/compiler-sfc@npm:3.4.27" @@ -7538,16 +7232,6 @@ __metadata: languageName: node linkType: hard -"@vue/compiler-ssr@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/compiler-ssr@npm:3.4.26" - dependencies: - "@vue/compiler-dom": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - checksum: 10c0/ba2844604060743b431ef77652c1d2624689835c08a6dd5e348172efc537ee278b16660b4100a3b88e990c6e10f290d8c878965b7701f3cf5796d849f7a411a5 - languageName: node - linkType: hard - "@vue/compiler-ssr@npm:3.4.27": version: 3.4.27 resolution: "@vue/compiler-ssr@npm:3.4.27" @@ -7558,45 +7242,36 @@ __metadata: languageName: node linkType: hard -"@vue/devtools-api@npm:^7.0.27": - version: 7.0.27 - resolution: "@vue/devtools-api@npm:7.0.27" +"@vue/devtools-api@npm:^7.2.0": + version: 7.2.0 + resolution: "@vue/devtools-api@npm:7.2.0" dependencies: - "@vue/devtools-kit": "npm:^7.0.27" - checksum: 10c0/25db119b41e03018e9ad570cc884396280732e9104e282dd9b78c055aaa1135b993f2e12d738670f6ac340356a06de9b850b5783e7f2e8712b7f8101d42dbe6b + "@vue/devtools-kit": "npm:^7.2.0" + checksum: 10c0/c00f9afc66a9910d5694e920d947aa27824771f0a58673efc4c82b239cbe30e56176e448544c89692cffe6cd588066231dae06777993325bf3b44961928fa294 languageName: node linkType: hard -"@vue/devtools-kit@npm:^7.0.27": - version: 7.0.27 - resolution: "@vue/devtools-kit@npm:7.0.27" +"@vue/devtools-kit@npm:^7.2.0": + version: 7.2.0 + resolution: "@vue/devtools-kit@npm:7.2.0" dependencies: - "@vue/devtools-shared": "npm:^7.0.27" + "@vue/devtools-shared": "npm:^7.2.0" hookable: "npm:^5.5.3" mitt: "npm:^3.0.1" perfect-debounce: "npm:^1.0.0" speakingurl: "npm:^14.0.1" peerDependencies: vue: ^3.0.0 - checksum: 10c0/e2520bf15836a23c56ee6d6564d1e79a67a0efc50a962d1cdb9989c91cc3c6417cd65b4962a541e03cc9333e7c6795e5c680b24d375088e53e200ee090e8624b + checksum: 10c0/1434eb3329c8c60ed5c9446ce2064ad671df4fe966280ec6a24c9635262a89622f8bd0c9763d0c33e12bcd08dbe0c2663037a69256f0385b2700c3dca972c17b languageName: node linkType: hard -"@vue/devtools-shared@npm:^7.0.27": - version: 7.0.27 - resolution: "@vue/devtools-shared@npm:7.0.27" +"@vue/devtools-shared@npm:^7.2.0": + version: 7.2.0 + resolution: "@vue/devtools-shared@npm:7.2.0" dependencies: rfdc: "npm:^1.3.1" - checksum: 10c0/0d9e17066bafe2cb4039b0f007892a0773a88549c9cf546f0f99ab0f55435802d6381cf02a598733f9d4e138fb085f6a89c23d905978eab6957abf6f6d4254ca - languageName: node - linkType: hard - -"@vue/reactivity@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/reactivity@npm:3.4.26" - dependencies: - "@vue/shared": "npm:3.4.26" - checksum: 10c0/6aeeced8059f3f508c584263235cd9b56e0bbdc9b72ad8c445476782958734b0e5afa3047be5d99ec1fa3eb20a3bbc9ef3df369e82c0d58ec5777ad47c100339 + checksum: 10c0/d5450ed0d1543d181eeb4453f0fd561cea959838f31dd9449a686accf9d659cec3a227c0a5b6728cedcff819aedf7d85a232b06ba3fcaa86a5d36713a084931b languageName: node linkType: hard @@ -7609,16 +7284,6 @@ __metadata: languageName: node linkType: hard -"@vue/runtime-core@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/runtime-core@npm:3.4.26" - dependencies: - "@vue/reactivity": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - checksum: 10c0/01a25f14c9bbd3272367b90419851f2e037d4b4ccbe732da4a764084bad8b98dc2dbd1f690d93c0610d1bfe0d037c348003c793f58368747cba2e8eddcd09915 - languageName: node - linkType: hard - "@vue/runtime-core@npm:3.4.27": version: 3.4.27 resolution: "@vue/runtime-core@npm:3.4.27" @@ -7629,17 +7294,6 @@ __metadata: languageName: node linkType: hard -"@vue/runtime-dom@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/runtime-dom@npm:3.4.26" - dependencies: - "@vue/runtime-core": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - csstype: "npm:^3.1.3" - checksum: 10c0/79e87825e7b06f114c8682599d1ad220a843eff0529e6ea5fab68f90aa7486eb96aa3f1b0ff2bc0950a3ff4d55503ad481ec163d095aef8067cd5acf02e01fb9 - languageName: node - linkType: hard - "@vue/runtime-dom@npm:3.4.27": version: 3.4.27 resolution: "@vue/runtime-dom@npm:3.4.27" @@ -7651,18 +7305,6 @@ __metadata: languageName: node linkType: hard -"@vue/server-renderer@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/server-renderer@npm:3.4.26" - dependencies: - "@vue/compiler-ssr": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - peerDependencies: - vue: 3.4.26 - checksum: 10c0/d4d2bc9af7e073338208e2fd37199816e7ee816e216cb603ddaad363fd541c27e60c78c63e65c95657df3082aa0b871ee437e034c03a03134653b983c8fe2d30 - languageName: node - linkType: hard - "@vue/server-renderer@npm:3.4.27": version: 3.4.27 resolution: "@vue/server-renderer@npm:3.4.27" @@ -7675,14 +7317,7 @@ __metadata: languageName: node linkType: hard -"@vue/shared@npm:3.4.26": - version: 3.4.26 - resolution: "@vue/shared@npm:3.4.26" - checksum: 10c0/8b002badc842d94b47cf78630ca51403a8a8981bc20c35f5b669c299413f81ed5fd97dbd763d0c00fe4614c615991214bc06d709d0b57d1e5836d0262cfbebdb - languageName: node - linkType: hard - -"@vue/shared@npm:3.4.27": +"@vue/shared@npm:3.4.27, @vue/shared@npm:^3.4.27": version: 3.4.27 resolution: "@vue/shared@npm:3.4.27" checksum: 10c0/4a21918858270bcc654bb94b3429d9acbe95af097ea3063e192b36bd502dc896ca47778fa74a863b01f677ec271b189eb90f8b372943c10e52725a6bdc7f6cd5 @@ -8329,13 +7964,6 @@ __metadata: languageName: node linkType: hard -"base-64@npm:^0.1.0": - version: 0.1.0 - resolution: "base-64@npm:0.1.0" - checksum: 10c0/fe0dcf076e823f04db7ee9b02495be08a91c445fbc6db03cb9913be9680e2fcc0af8b74459041fe08ad16800b1f65a549501d8f08696a8a6d32880789b7de69d - languageName: node - linkType: hard - "base32-encode@npm:^0.1.0 || ^1.0.0": version: 1.2.0 resolution: "base32-encode@npm:1.2.0" @@ -8831,19 +8459,12 @@ __metadata: languageName: node linkType: hard -"charenc@npm:0.0.2": - version: 0.0.2 - resolution: "charenc@npm:0.0.2" - checksum: 10c0/a45ec39363a16799d0f9365c8dd0c78e711415113c6f14787a22462ef451f5013efae8a28f1c058f81fc01f2a6a16955f7a5fd0cd56247ce94a45349c89877d8 - languageName: node - linkType: hard - -"chart.js@npm:^4.4.2": - version: 4.4.2 - resolution: "chart.js@npm:4.4.2" +"chart.js@npm:^4.4.3": + version: 4.4.3 + resolution: "chart.js@npm:4.4.3" dependencies: "@kurkle/color": "npm:^0.3.0" - checksum: 10c0/4233bdc3cce9e4d5aca575f8ef9313f56867491b9eac027f3514ca807aca9a2ba2087f85ca3c68f022974f948ce688a178b0ce21d95706116abca5acf19e072f + checksum: 10c0/910113b07983d1b2f52689afcd88d0c5fecbd30fbbe89f4e122d78397fd4521e1563fd6d80f7742a391abe7cfbfa8deaf12c58b83c9636431fa77e7285d85c2e languageName: node linkType: hard @@ -9414,13 +9035,6 @@ __metadata: languageName: node linkType: hard -"crypt@npm:0.0.2": - version: 0.0.2 - resolution: "crypt@npm:0.0.2" - checksum: 10c0/adbf263441dd801665d5425f044647533f39f4612544071b1471962209d235042fb703c27eea2795c7c53e1dfc242405173003f83cf4f4761a633d11f9653f18 - languageName: node - linkType: hard - "css-select@npm:^4.3.0": version: 4.3.0 resolution: "css-select@npm:4.3.0" @@ -10189,16 +9803,6 @@ __metadata: languageName: node linkType: hard -"digest-fetch@npm:^1.3.0": - version: 1.3.0 - resolution: "digest-fetch@npm:1.3.0" - dependencies: - base-64: "npm:^0.1.0" - md5: "npm:^2.3.0" - checksum: 10c0/0fb389e33b9c6baf5e6a9ed287aa9d0d8b373d59b49d49c62c261e1ab24eaaf1d5aea3a105c1b31ba4a23e29e129365d839ce4c5974fa004a85d1a4568bc3585 - languageName: node - linkType: hard - "dir-compare@npm:^4.2.0": version: 4.2.0 resolution: "dir-compare@npm:4.2.0" @@ -10558,12 +10162,12 @@ __metadata: languageName: node linkType: hard -"electron-squirrel-startup@npm:^1.0.0": - version: 1.0.0 - resolution: "electron-squirrel-startup@npm:1.0.0" +"electron-squirrel-startup@npm:^1.0.1": + version: 1.0.1 + resolution: "electron-squirrel-startup@npm:1.0.1" dependencies: debug: "npm:^2.2.0" - checksum: 10c0/5513a61f98be8d57f3b7f54fe38e2b3c42e0d02f62ee8f172d9d040863ee790b997f030328e79e454f5aaf51aa0b8908da8fb2ab9b18a8fdd29e88b3a0688c8d + checksum: 10c0/eaff42844702321237b23f68ae2bff2eae17cb1251189e9a5397d94d1598a607e3cbedc60e130a0641798739019db44c37b2a8801e7a60f9261810bb8d07e9bf languageName: node linkType: hard @@ -10591,16 +10195,16 @@ __metadata: languageName: node linkType: hard -"electron@npm:^30.0.3": - version: 30.0.3 - resolution: "electron@npm:30.0.3" +"electron@npm:^30.0.6": + version: 30.0.6 + resolution: "electron@npm:30.0.6" dependencies: "@electron/get": "npm:^2.0.0" "@types/node": "npm:^20.9.0" extract-zip: "npm:^2.0.1" bin: electron: cli.js - checksum: 10c0/8fc437e491fd7baa132fd4cbf74c0a6b2f0ff7db373f4e198764b2385266118051e1db6599377340c6a743422f63fc04f5d85c8289498f9fc770c9ea96421405 + checksum: 10c0/92cc9a5013774f3b69c44fc964096e5de4217e297cd5d16cfd1654690d725c6d5acb1111d5637216ee3e12ee814336808208acbc24460d6bd2ac160be41443d1 languageName: node linkType: hard @@ -10698,9 +10302,9 @@ __metadata: "@electron-forge/publisher-github": "npm:^7.4.0" "@electron-forge/publisher-s3": "npm:^7.4.0" "@electron/fuses": "npm:^1.8.0" - "@hookform/resolvers": "npm:^3.3.4" - "@langchain/community": "npm:^0.0.56" - "@langchain/google-genai": "npm:^0.0.12" + "@hookform/resolvers": "npm:^3.4.0" + "@langchain/community": "npm:^0.2.0" + "@langchain/google-genai": "npm:^0.0.14" "@mozilla/readability": "npm:^0.5.0" "@playwright/test": "npm:^1.44.0" "@radix-ui/react-accordion": "npm:^1.1.2" @@ -10734,15 +10338,15 @@ __metadata: "@types/fluent-ffmpeg": "npm:^2.1.24" "@types/html-to-text": "npm:^9.0.4" "@types/intl-tel-input": "npm:^18.1.4" - "@types/lodash": "npm:^4.17.1" + "@types/lodash": "npm:^4.17.4" "@types/mark.js": "npm:^8.11.12" - "@types/node": "npm:^20.12.11" + "@types/node": "npm:^20.12.12" "@types/react": "npm:^18.3.2" "@types/react-dom": "npm:^18.3.0" "@types/validator": "npm:^13.11.10" "@types/wavesurfer.js": "npm:^6.0.12" - "@typescript-eslint/eslint-plugin": "npm:^7.8.0" - "@typescript-eslint/parser": "npm:^7.8.0" + "@typescript-eslint/eslint-plugin": "npm:^7.9.0" + "@typescript-eslint/parser": "npm:^7.9.0" "@uidotdev/usehooks": "npm:^2.4.1" "@vidstack/react": "npm:^1.11.21" "@vitejs/plugin-react": "npm:^4.2.1" @@ -10752,7 +10356,7 @@ __metadata: axios: "npm:^1.6.8" camelcase: "npm:^8.0.0" camelcase-keys: "npm:^9.1.3" - chart.js: "npm:^4.4.2" + chart.js: "npm:^4.4.3" cheerio: "npm:^1.0.0-rc.12" class-variance-authority: "npm:^0.7.0" clsx: "npm:^2.1.1" @@ -10765,13 +10369,13 @@ __metadata: decamelize: "npm:^6.0.0" decamelize-keys: "npm:^2.0.1" echogarden: "npm:^1.4.4" - electron: "npm:^30.0.3" + electron: "npm:^30.0.6" electron-context-menu: "npm:^4.0.0" electron-log: "npm:^5.1.4" electron-playwright-helpers: "npm:^1.7.1" electron-settings: "npm:^4.0.4" - electron-squirrel-startup: "npm:^1.0.0" - eslint: "npm:^9.2.0" + electron-squirrel-startup: "npm:^1.0.1" + eslint: "npm:^9.3.0" eslint-import-resolver-typescript: "npm:^3.6.1" eslint-plugin-import: "npm:^2.29.1" ffmpeg-static: "npm:^5.2.0" @@ -10781,16 +10385,16 @@ __metadata: html-to-text: "npm:^9.0.5" https-proxy-agent: "npm:^7.0.4" i18next: "npm:^23.11.4" - intl-tel-input: "npm:^23.0.4" + intl-tel-input: "npm:^23.0.7" js-md5: "npm:^0.8.3" - langchain: "npm:^0.1.37" + langchain: "npm:^0.2.0" lodash: "npm:^4.17.21" lucide-react: "npm:^0.378.0" mark.js: "npm:^8.11.1" microsoft-cognitiveservices-speech-sdk: "npm:^1.36.0" next-themes: "npm:^0.3.0" octokit: "npm:^4.0.2" - openai: "npm:^4.45.0" + openai: "npm:^4.47.1" pitchfinder: "npm:^2.3.2" postcss: "npm:^8.4.38" progress: "npm:^2.0.3" @@ -10822,10 +10426,10 @@ __metadata: update-electron-app: "npm:^3.0.0" vite: "npm:^5.2.11" vite-plugin-static-copy: "npm:^1.0.4" - wavesurfer.js: "npm:^7.7.14" + wavesurfer.js: "npm:^7.7.15" zod: "npm:^3.23.8" zod-to-json-schema: "npm:^3.23.0" - zx: "npm:^8.0.2" + zx: "npm:^8.1.0" languageName: unknown linkType: soft @@ -11215,17 +10819,17 @@ __metadata: languageName: node linkType: hard -"eslint@npm:^9.2.0": - version: 9.2.0 - resolution: "eslint@npm:9.2.0" +"eslint@npm:^9.3.0": + version: 9.3.0 + resolution: "eslint@npm:9.3.0" dependencies: "@eslint-community/eslint-utils": "npm:^4.2.0" "@eslint-community/regexpp": "npm:^4.6.1" - "@eslint/eslintrc": "npm:^3.0.2" - "@eslint/js": "npm:9.2.0" + "@eslint/eslintrc": "npm:^3.1.0" + "@eslint/js": "npm:9.3.0" "@humanwhocodes/config-array": "npm:^0.13.0" "@humanwhocodes/module-importer": "npm:^1.0.1" - "@humanwhocodes/retry": "npm:^0.2.3" + "@humanwhocodes/retry": "npm:^0.3.0" "@nodelib/fs.walk": "npm:^1.2.8" ajv: "npm:^6.12.4" chalk: "npm:^4.0.0" @@ -11255,7 +10859,7 @@ __metadata: text-table: "npm:^0.2.0" bin: eslint: bin/eslint.js - checksum: 10c0/eab3265100a359a486e40e1d9d4d3ecff936d2f4d952f4ab107d404e0684fffbe186ecd0fb41791af5bcb13570a27032ddf9a2e628927ed33473f64255b0037b + checksum: 10c0/d0cf1923408ce720f2fa0dde39094dbee6b03a71d5e6466402bbeb29c08a618713f7a1e8cfc4b4d50dbe3750ce68adf614a0e973dea6fa36ddc428dfb89d4cac languageName: node linkType: hard @@ -12928,10 +12532,10 @@ __metadata: languageName: node linkType: hard -"intl-tel-input@npm:^23.0.4": - version: 23.0.4 - resolution: "intl-tel-input@npm:23.0.4" - checksum: 10c0/bac64c7cefb9da808c0a74479b3ce8839852300120b90ae2aa7598dc787bb36585fbdab098ae9b092d90863774ede0660eff7499acac60c9f0c40a0ba1e66ee4 +"intl-tel-input@npm:^23.0.7": + version: 23.0.7 + resolution: "intl-tel-input@npm:23.0.7" + checksum: 10c0/c97830dd2a13e931ce0a9dd7f4969a6f850a4e3814a8a38dea4fe0cf67b533bc2b0b597a5b8a8112b1b3c4b063272afb76135c638f8f8a86b7b68aaa0da22955 languageName: node linkType: hard @@ -13030,13 +12634,6 @@ __metadata: languageName: node linkType: hard -"is-buffer@npm:~1.1.6": - version: 1.1.6 - resolution: "is-buffer@npm:1.1.6" - checksum: 10c0/ae18aa0b6e113d6c490ad1db5e8df9bdb57758382b313f5a22c9c61084875c6396d50bbf49315f5b1926d142d74dfb8d31b40d993a383e0a158b15fea7a82234 - languageName: node - linkType: hard - "is-callable@npm:^1.1.3, is-callable@npm:^1.1.4, is-callable@npm:^1.2.7": version: 1.2.7 resolution: "is-callable@npm:1.2.7" @@ -13399,6 +12996,15 @@ __metadata: languageName: node linkType: hard +"js-tiktoken@npm:^1.0.12": + version: 1.0.12 + resolution: "js-tiktoken@npm:1.0.12" + dependencies: + base64-js: "npm:^1.5.1" + checksum: 10c0/7afb4826e21342386a1884754fbc1c1828f948c4dd0ab093bf778d1323e65343bd5343d15f7cda46af396f1fe4a0297739936149b7c40a0601eefe3fcaef8727 + languageName: node + linkType: hard + "js-tiktoken@npm:^1.0.7, js-tiktoken@npm:^1.0.8": version: 1.0.10 resolution: "js-tiktoken@npm:1.0.10" @@ -13637,17 +13243,15 @@ __metadata: languageName: node linkType: hard -"langchain@npm:^0.1.37": - version: 0.1.37 - resolution: "langchain@npm:0.1.37" +"langchain@npm:^0.2.0, langchain@npm:~0.2.0": + version: 0.2.0 + resolution: "langchain@npm:0.2.0" dependencies: - "@anthropic-ai/sdk": "npm:^0.9.1" - "@langchain/community": "npm:~0.0.47" - "@langchain/core": "npm:~0.1.60" + "@langchain/core": "npm:~0.2.0" "@langchain/openai": "npm:~0.0.28" "@langchain/textsplitters": "npm:~0.0.0" binary-extensions: "npm:^2.2.0" - js-tiktoken: "npm:^1.0.7" + js-tiktoken: "npm:^1.0.12" js-yaml: "npm:^4.1.0" jsonpointer: "npm:^5.0.1" langchainhub: "npm:~0.0.8" @@ -13670,7 +13274,6 @@ __metadata: "@gomomento/sdk-core": ^1.51.1 "@gomomento/sdk-web": ^1.51.1 "@google-ai/generativelanguage": ^0.2.1 - "@google-cloud/storage": ^6.10.1 || ^7.7.0 "@mendable/firecrawl-js": ^0.0.13 "@notionhq/client": ^2.2.10 "@pinecone-database/pinecone": "*" @@ -13687,7 +13290,6 @@ __metadata: d3-dsv: ^2.0.0 epub2: ^3.0.1 fast-xml-parser: "*" - google-auth-library: ^8.9.0 handlebars: ^4.7.8 html-to-text: ^9.0.5 ignore: ^5.2.0 @@ -13733,8 +13335,6 @@ __metadata: optional: true "@google-ai/generativelanguage": optional: true - "@google-cloud/storage": - optional: true "@mendable/firecrawl-js": optional: true "@notionhq/client": @@ -13769,8 +13369,6 @@ __metadata: optional: true fast-xml-parser: optional: true - google-auth-library: - optional: true handlebars: optional: true html-to-text: @@ -13819,7 +13417,7 @@ __metadata: optional: true youtubei.js: optional: true - checksum: 10c0/ce0d2093fe3e624f5c872dd494b58a958e1ba79238dd22508591a0f207c1caf0e7ca16f4cc2af4e447f261fe50682da18d3c891710d702d8057d9b3ee33598c7 + checksum: 10c0/d60214cc6386f1208bf9e1f780758e66d6f57f1697dced9dcd529cc93e56acb32fdca574678441aaeffe88bcd7c8d11b3a41d26bee5d7e2e0eccfe81c316fb83 languageName: node linkType: hard @@ -14329,17 +13927,6 @@ __metadata: languageName: node linkType: hard -"md5@npm:^2.3.0": - version: 2.3.0 - resolution: "md5@npm:2.3.0" - dependencies: - charenc: "npm:0.0.2" - crypt: "npm:0.0.2" - is-buffer: "npm:~1.1.6" - checksum: 10c0/14a21d597d92e5b738255fbe7fe379905b8cb97e0a49d44a20b58526a646ec5518c337b817ce0094ca94d3e81a3313879c4c7b510d250c282d53afbbdede9110 - languageName: node - linkType: hard - "mdast-util-from-markdown@npm:^1.3.0": version: 1.3.1 resolution: "mdast-util-from-markdown@npm:1.3.1" @@ -14556,9 +14143,9 @@ __metadata: languageName: node linkType: hard -"mermaid@npm:^10.9.0": - version: 10.9.0 - resolution: "mermaid@npm:10.9.0" +"mermaid@npm:^10.9.1": + version: 10.9.1 + resolution: "mermaid@npm:10.9.1" dependencies: "@braintree/sanitize-url": "npm:^6.0.1" "@types/d3-scale": "npm:^4.0.3" @@ -14580,7 +14167,7 @@ __metadata: ts-dedent: "npm:^2.2.0" uuid: "npm:^9.0.0" web-worker: "npm:^1.2.0" - checksum: 10c0/ff1a96c8cd3384f64c9254d18f795f7fb2e454c47595c3deadbf6468544c607c5968c4a4bff8220bf27ccf0294f5527bc93660eea9958b4bc03e5e9871eaf57e + checksum: 10c0/034f326682e3e478e4bd85e418cfef00773db4432301b858247c8d4bf813e67fa1901e8548fc490eafe4c9c215c9fb96dead73007ff317ee99973cf4f63c8791 languageName: node linkType: hard @@ -16013,9 +15600,9 @@ __metadata: languageName: node linkType: hard -"openai@npm:^4.45.0": - version: 4.45.0 - resolution: "openai@npm:4.45.0" +"openai@npm:^4.47.1": + version: 4.47.1 + resolution: "openai@npm:4.47.1" dependencies: "@types/node": "npm:^18.11.18" "@types/node-fetch": "npm:^2.6.4" @@ -16027,7 +15614,7 @@ __metadata: web-streams-polyfill: "npm:^3.2.1" bin: openai: bin/cli - checksum: 10c0/55e8810a1b1671fc0da1f017df1c1dc8703527996c2f398e73404660c1687c800976e0afd90535149220d4b709461066c9fa489ded93c9ad2c10ef0c64907df5 + checksum: 10c0/10f5c7cd388dca758a31ccfd572d0ff732324842bc41f8cda80db043ad59035dbb9f903d0273bc14d1565f680901b7d65993b5e4c27dcf5deee2235a544ac0e4 languageName: node linkType: hard @@ -17636,16 +17223,16 @@ __metadata: languageName: node linkType: hard -"sass@npm:^1.77.1": - version: 1.77.1 - resolution: "sass@npm:1.77.1" +"sass@npm:^1.77.2": + version: 1.77.2 + resolution: "sass@npm:1.77.2" dependencies: chokidar: "npm:>=3.0.0 <4.0.0" immutable: "npm:^4.0.0" source-map-js: "npm:>=0.6.2 <2.0.0" bin: sass: sass.js - checksum: 10c0/edcfc7d038234b1198c3ddcac5963fcd1e17a9c1ee0f9bd09784ab5353b60ff50b189b4c9154b34f5da9ca0eaab8b189fd3e83a4b43a494151ad4735f8e5f364 + checksum: 10c0/0d292339064de3c902e209d41de9c4eb2038cff326476aeebbb5be3eee1d23400d975face2b8e124ae617b10af3e93bec01580f61912f34e4c517fe137a118b6 languageName: node linkType: hard @@ -17901,12 +17488,12 @@ __metadata: languageName: node linkType: hard -"shiki@npm:1.3.0, shiki@npm:^1.3.0": - version: 1.3.0 - resolution: "shiki@npm:1.3.0" +"shiki@npm:1.6.0, shiki@npm:^1.5.2": + version: 1.6.0 + resolution: "shiki@npm:1.6.0" dependencies: - "@shikijs/core": "npm:1.3.0" - checksum: 10c0/a57f9d86567e86e9f003ccc42aaee594a47dde0f034047c251cf8d4b8971ec790c1ec2a3563c8126408ac2274ab7717201a24a6b70b1ef9dc1ab8f55f64d5879 + "@shikijs/core": "npm:1.6.0" + checksum: 10c0/9a3a5a123029b52cf881be3e14007b50c5dc7f2222c809904323b1a9b415ec1a12009b0cf28c0b00370c22372d5b445e3ef7d60d14db932f069abc38c246cfa0 languageName: node linkType: hard @@ -19506,46 +19093,6 @@ __metadata: languageName: node linkType: hard -"vite@npm:^5.2.10": - version: 5.2.10 - resolution: "vite@npm:5.2.10" - dependencies: - esbuild: "npm:^0.20.1" - fsevents: "npm:~2.3.3" - postcss: "npm:^8.4.38" - rollup: "npm:^4.13.0" - peerDependencies: - "@types/node": ^18.0.0 || >=20.0.0 - less: "*" - lightningcss: ^1.21.0 - sass: "*" - stylus: "*" - sugarss: "*" - terser: ^5.4.0 - dependenciesMeta: - fsevents: - optional: true - peerDependenciesMeta: - "@types/node": - optional: true - less: - optional: true - lightningcss: - optional: true - sass: - optional: true - stylus: - optional: true - sugarss: - optional: true - terser: - optional: true - bin: - vite: bin/vite.js - checksum: 10c0/d50630ac8de807a6185cd9b5763b3969b2950a454cf6a4482f3780f183865e8d6f7e3aa57dd70ede1c493aaa861efb25b43562287efbcf8b471b7f3b88857a33 - languageName: node - linkType: hard - "vite@npm:^5.2.11": version: 5.2.11 resolution: "vite@npm:5.2.11" @@ -19601,25 +19148,26 @@ __metadata: languageName: node linkType: hard -"vitepress@npm:^1.1.4": - version: 1.1.4 - resolution: "vitepress@npm:1.1.4" +"vitepress@npm:^1.2.0": + version: 1.2.0 + resolution: "vitepress@npm:1.2.0" dependencies: "@docsearch/css": "npm:^3.6.0" "@docsearch/js": "npm:^3.6.0" - "@shikijs/core": "npm:^1.3.0" - "@shikijs/transformers": "npm:^1.3.0" - "@types/markdown-it": "npm:^14.0.1" + "@shikijs/core": "npm:^1.5.2" + "@shikijs/transformers": "npm:^1.5.2" + "@types/markdown-it": "npm:^14.1.1" "@vitejs/plugin-vue": "npm:^5.0.4" - "@vue/devtools-api": "npm:^7.0.27" + "@vue/devtools-api": "npm:^7.2.0" + "@vue/shared": "npm:^3.4.27" "@vueuse/core": "npm:^10.9.0" "@vueuse/integrations": "npm:^10.9.0" focus-trap: "npm:^7.5.4" mark.js: "npm:8.11.1" minisearch: "npm:^6.3.0" - shiki: "npm:^1.3.0" - vite: "npm:^5.2.10" - vue: "npm:^3.4.25" + shiki: "npm:^1.5.2" + vite: "npm:^5.2.11" + vue: "npm:^3.4.27" peerDependencies: markdown-it-mathjax3: ^4 postcss: ^8 @@ -19630,7 +19178,7 @@ __metadata: optional: true bin: vitepress: bin/vitepress.js - checksum: 10c0/c8ae9aebd9c780362f7258023aaff0345c5afd4194bf0963a7530dade3bfc7b7028233a9ea73a3d714d0d2618dd3492424e0fa3b68bf562a4045a4bf33ca2175 + checksum: 10c0/bcc944e3cadca3bf12e8ff325e9f766fe7a164001c3ceb403256f1edfa4f7b3a8df165a4e2e27fbd70c00844f1887b1944647ce76e77d84d9de31fe1d6ebd597 languageName: node linkType: hard @@ -19657,24 +19205,6 @@ __metadata: languageName: node linkType: hard -"vue@npm:^3.4.25": - version: 3.4.26 - resolution: "vue@npm:3.4.26" - dependencies: - "@vue/compiler-dom": "npm:3.4.26" - "@vue/compiler-sfc": "npm:3.4.26" - "@vue/runtime-dom": "npm:3.4.26" - "@vue/server-renderer": "npm:3.4.26" - "@vue/shared": "npm:3.4.26" - peerDependencies: - typescript: "*" - peerDependenciesMeta: - typescript: - optional: true - checksum: 10c0/888cea9ff79d6401872de85b381e4ff7a87df23d37b43363fdfea9892935c2925bd4fc29a278ff92ee8137d4baf819c4319b9c5d5240bddc7437b8865092d2f6 - languageName: node - linkType: hard - "vue@npm:^3.4.27": version: 3.4.27 resolution: "vue@npm:3.4.27" @@ -19702,10 +19232,10 @@ __metadata: languageName: node linkType: hard -"wavesurfer.js@npm:^7.7.14": - version: 7.7.14 - resolution: "wavesurfer.js@npm:7.7.14" - checksum: 10c0/5e9ec86481e97ee56063ff54d8b5430ab1e3256f19e5b7923045957cb8e3dd4c1c2963ae359d5b4f197342cac79c5b7c45c76c1f11650712b18b7f4e1d8f8350 +"wavesurfer.js@npm:^7.7.15": + version: 7.7.15 + resolution: "wavesurfer.js@npm:7.7.15" + checksum: 10c0/20e7830c3288cf4e966851008e1c568109b2a363bf025f73ee6745085ca741986cf17dd36b825646c714c505fe2360d7505cf0faabbab13947cc750debb2fc40 languageName: node linkType: hard @@ -20242,12 +19772,12 @@ __metadata: languageName: node linkType: hard -"zx@npm:^8.0.2": - version: 8.0.2 - resolution: "zx@npm:8.0.2" +"zx@npm:^8.1.0": + version: 8.1.0 + resolution: "zx@npm:8.1.0" dependencies: "@types/fs-extra": "npm:^11.0.4" - "@types/node": "npm:>=20.12.5" + "@types/node": "npm:>=20.12.11" dependenciesMeta: "@types/fs-extra": optional: true @@ -20255,6 +19785,6 @@ __metadata: optional: true bin: zx: build/cli.js - checksum: 10c0/fd1dd3c7b6b0d9a3c692f40eba6a47665bc8e9b7030192227ada8a73b0c83fd7059452b87417a18c154746621e5fd67225bb39f51ccb7f8f7213924641f7983f + checksum: 10c0/a833ccfb94889c900f127bb89cd537fc3a32eb5c0bda71ddba6e24d5f268926234c48855af8da4550646fa5cff876fe18ca835f24f2d4d8548b3ec21c253d372 languageName: node linkType: hard