diff --git a/.github/workflows/test-enjoy-app.yml b/.github/workflows/test-enjoy-app.yml index fc4bbe3f..155656fd 100644 --- a/.github/workflows/test-enjoy-app.yml +++ b/.github/workflows/test-enjoy-app.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-12, macos-14, windows-latest] + os: [macos-12, macos-14, windows-latest, ubuntu-latest] steps: - uses: actions/checkout@v4 @@ -47,17 +47,24 @@ jobs: run: yarn enjoy:package - name: Run main tests with xvfb-run on Ubuntu + continue-on-error: true if: contains(matrix.os, 'ubuntu') env: DEBUG: pw:browser* run: | xvfb-run --auto-servernum --server-args="-screen 0 1280x960x24" -- yarn enjoy:test:main - - name: Run main tests on macOS or Windows - if: contains(matrix.os, 'macos') || contains(matrix.os, 'windows') + - name: Run main tests on Windows + continue-on-error: true + if: contains(matrix.os, 'windows') + run: yarn enjoy:test:main + + - name: Run main tests on macOS + if: contains(matrix.os, 'macos') run: yarn enjoy:test:main - name: Run renderer tests with xvfb-run on Ubuntu + continue-on-error: true if: contains(matrix.os, 'ubuntu') run: | xvfb-run --auto-servernum --server-args="-screen 0 1280x960x24" -- yarn enjoy:test:renderer diff --git a/1000h-portal/package.json b/1000h-portal/package.json index f196c9fb..5cc769a0 100644 --- a/1000h-portal/package.json +++ b/1000h-portal/package.json @@ -12,7 +12,7 @@ "dependencies": { "@nuxtjs/seo": "^2.0.0-rc.23", "nuxt": "^3.13.2", - "nuxt-og-image": "^3.0.4", + "nuxt-og-image": "^3.0.6", "vue": "^3.5.12", "vue-router": "^4.4.5" }, @@ -20,6 +20,6 @@ "autoprefixer": "^10.4.20", "postcss": "^8.4.47", "sass": "^1.79.5", - "tailwindcss": "^3.4.13" + "tailwindcss": "^3.4.14" } } diff --git a/enjoy/e2e/main.spec.ts b/enjoy/e2e/main.spec.ts index 5c464a42..e2f250e5 100644 --- a/enjoy/e2e/main.spec.ts +++ b/enjoy/e2e/main.spec.ts @@ -60,9 +60,17 @@ test.afterAll(async () => { await electronApp.close(); }); -test("validate whisper command", async () => { +test("validate echogarden command", async () => { const res = await page.evaluate(() => { - return window.__ENJOY_APP__.whisper.check(); + return window.__ENJOY_APP__.echogarden.check({ + engine: "whisper", + whisper: { + model: "tiny.en", + language: "en", + encoderProvider: "cpu", + decoderProvider: "cpu", + }, + }); }); console.info(res.log); expect(res.success).toBeTruthy(); @@ -75,13 +83,6 @@ test("valid ffmpeg command", async () => { expect(res).toBeTruthy(); }); -test("validate echogarden align command", async () => { - const res = await page.evaluate(() => { - return window.__ENJOY_APP__.echogarden.check(); - }); - expect(res).toBeTruthy(); -}); - test("should setup default library path", async () => { const settings = fs.readJsonSync(path.join(resultDir, "settings.json")); expect(settings.library).not.toBeNull(); diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal b/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal deleted file mode 100644 index 57fdf564..00000000 --- a/enjoy/lib/whisper.cpp/arm64/darwin/ggml-metal.metal +++ /dev/null @@ -1,6854 +0,0 @@ -#define GGML_COMMON_DECL_METAL -#define GGML_COMMON_IMPL_METAL -#include "ggml-common.h" - -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) -#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } - -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 - -enum ggml_sort_order { - GGML_SORT_ORDER_ASC, - GGML_SORT_ORDER_DESC, -}; - -// general-purpose kernel for addition, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_mul( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_div( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_div_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] / src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float * src0, - device float * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_scale_4( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_clamp( - device const float * src0, - device float * dst, - constant float & min, - constant float & max, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_silu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sum_rows( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tpig[[thread_position_in_grid]]) { - int64_t i3 = tpig.z; - int64_t i2 = tpig.y; - int64_t i1 = tpig.x; - - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { - return; - } - - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int64_t i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; - } - - dst_row[0] = row_sum; -} - -template -kernel void kernel_soft_max( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - - // find the max value in the block - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); - lsum += exp_psrc0; - pdst[i00] = exp_psrc0; - } - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - pdst[i00] *= inv_sum; - } -} - -template -kernel void kernel_soft_max_4( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - - float slope = 1.0f; - - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); - } - - const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - - const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - pdst4[i00] *= inv_sum; - } -} - -typedef decltype(kernel_soft_max) kernel_soft_max_t; -typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; - -template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } - - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } -} - -kernel void kernel_group_norm( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); - - int start = tgpig * gs; - int end = start + gs; - - start += tpitg; - - if (end >= ne) { - end = ne; - } - - float tmp = 0.0f; // partial sum for thread in warp - - for (int j = start; j < end; j += ntg) { - tmp += src0[j]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float mean = tmp / gs; - tmp = 0.0f; - - for (int j = start; j < end; j += ntg) { - float xi = src0[j] - mean; - dst[j] = xi; - tmp += xi * xi; - } - - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); - for (int j = start; j < end; j += ntg) { - dst[j] *= scale; - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (sumy * -16.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_1/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - - -#define NB_Q8_0 8 - -void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q8_0_f32")]] -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -#define N_F32_F32 4 - -void kernel_mul_mv_f32_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const float * x = (device const float *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F16 4 - -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -void kernel_mul_mv_f16_f32_1row_impl( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const half4 * x4 = (device const half4 *) x; - device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F32 4 - -void kernel_mul_mv_f16_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half4 * x4 = (device const half4 *) (src0 + offset0); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - *cos_theta = cos(theta) * mscale; - *sin_theta = sin(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); -} - -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - -template -kernel void kernel_rope( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - const bool is_neox = mode & 2; - - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); - - device const int32_t * pos = src1; - - const int64_t p = pos[i2]; - - const float theta_0 = (float)p; - const float inv_ndims = -1.f/n_dims; - - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t ib = 0; - - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; - - const float theta = theta_0 * pow(freq_base, cur_rot); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } -} - -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; - -typedef void (im2col_t)( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; - - const int32_t offset_dst = - (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - pdst[offset_dst] = 0.0f; - } else { - const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; - } -} - -template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; - -kernel void kernel_upscale_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & sf, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1/sf; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = src0_ptr[i0/sf]; - } -} - -kernel void kernel_pad_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } - } - - return; - } - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; - } -} - -kernel void kernel_arange_f32( - device char * dst, - constant int64_t & ne0, - constant float & start, - constant float & step, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - device float * dst_ptr = (device float *) dst; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = start + step * i0; - } -} - -kernel void kernel_timestep_embedding_f32( - device const char * src0, - device char * dst, - constant uint64_t & nb1, - constant int & dim, - constant int & max_period, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - int i = tgpig.x; - device float * embed_data = (device float *)(dst + i*nb1); - - int half_ = dim / 2; - for (int j = tpitg.x; j < half_; j += ntg.x) { - float timestep = ((device float *)src0)[i]; - float freq = (float)exp(-log((float)max_period) * j / half_); - float arg = timestep * freq; - embed_data[j ] = cos(arg); - embed_data[j + half_] = sin(arg); - } - - if (dim % 2 != 0 && tpitg.x == 0) { - embed_data[dim] = 0.f; - } -} - -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { - // bitonic sort - int col = tpitg[0]; - int row = tgpig[1]; - - if (col >= ncols_pad) return; - - device const float * x_row = x + row * ncols; - threadgroup int32_t * dst_row = shared_values; - - // initialize indices - dst_row[col] = col; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int k = 2; k <= ncols_pad; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - } - - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; - } -} - -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; - -kernel void kernel_leaky_relu_f32( - device const float * src0, - device float * dst, - constant float & slope, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; -} - -typedef void (flash_attn_ext_f16_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); - -// ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; - - const short D4 = D/4; - const short D8 = D/8; - //const short Q8 = Q/8; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) - const short T4 = T/4; // shared memory size per query in (half4) - - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; - - // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { - sq4[j*T4 + i] = (half4) q4[i]; - } else { - sq4[j*T4 + i] = 0.0h; - } - } - } - - // zero out lo - for (short i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); - } - - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S[Q] = { [0 ... Q-1] = 0.0h }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2/rk2; - const short ik3 = iq3/rk3; - - // v indices - const short iv2 = iq2/rv2; - const short iv3 = iq3/rv3; - - // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); - } - - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); - - // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); - - // prepare diagonal slope matrix - simdgroup_float8x8 mslope(1.0f); - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - mslope = simdgroup_float8x8(pow(base, exph)); - } - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); - - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); - } - - if (mask != q) { - // mqk = mqk*scale + mask*slope - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply(mm, mslope, mm); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); - } else { - // mqk = mqk*scale - simdgroup_multiply(mqk, mscale, mqk); - } - - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); - } - } - - // used to detect blocks full of -INF - float smax = -INFINITY; - - // online softmax - { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const short p = tiisg; - - const float m = M[j]; - const float s = ss[j*TF + p]; - - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); - - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms[j] + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; - } - } - - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - - // O = diag(ms)*O - { - simdgroup_float8x8 mm; - simdgroup_load(mm, ss + C, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); - } - } - - // O = O + (Q*K^T)*V - { - for (short cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); - - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); - } - } - } - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = 0; j < Q; ++j) { - if (tiisg == 0) { - ss[j*TF + 0] = S[j]; - ss[j*TF + 1] = M[j]; - } - } - } - - // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq - if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TF + 0]; - const float S1 = ss[j*TF + sg*SH + 0]; - - const float M0 = ss[j*TF + 1]; - const float M1 = ss[j*TF + sg*SH + 1]; - - M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*TF + 0] = S; - ss[j*TF + 1] = M; - - ss[j*TF + C + j ] = ms0; - ss[j*TF + C + j + sg*SH] = ms1; - } - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; - - simdgroup_load(ms0, ss + C, TF, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); - simdgroup_multiply(t, ms1, t); - - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); - } - } - } - } - - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = ss[j*TF + 0]; - - for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; - } - } - } -} - -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; - -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; - - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; - - // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; - } else { - sq4[i] = 0.0h; - } - } - - // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; - } - - // zero out shared memory SH - for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; - - // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; - - // load the queries from shared memory into local memory - half4 mq[D4]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[i] = sq4[i]; - } - - // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; - - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; - - mqk += (float4) (mq[i] * mk); - } - - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); - - // mqk = mqk*scale + mask*slope - if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); - - ss4[cc] = mqk; - } - } - } - - // online softmax - { - const short p = tiisg; - - const float m = M; - const float s = ss[p]; - - M = simd_max(max(M, s)); - - const float ms = exp(m - M); - const float vs = exp(s - M); - - S = S*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; - - // O = diag(ms)*O -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; - } - } - - // O = O + (Q*K^T)*V - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; - } - } - } - - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - } - - // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = lo[ii/NW]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { - if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; - - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; - - const float M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - const float S = ss[0]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; - } - } -} - -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; - -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - device const half * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - // TODO: is there a better way to handle -INFINITY? - dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_q8_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; - - device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } - } -} - -kernel void kernel_cpy_f32_q4_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q4_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q5_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK5_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_0].d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK5_0/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_0].qh[j] = qh8[j]; - } - } -} - -kernel void kernel_cpy_f32_q5_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float max = src[0]; - float min = src[0]; - - for (int j = 1; j < QK5_1; j++) { - const float v = src[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_1].d = d; - dst_data[i00/QK5_1].m = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_1].qh[j] = qh8[j]; - } - } -} - -static inline int best_index_int8(int n, constant float * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - -constexpr constant static float kvalues_iq4nl_f[16] = { - -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f -}; - -kernel void kernel_cpy_f32_iq4_nl( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / kvalues_iq4nl_f[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_NL/2 + j]*id; - - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); - - dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); - - const float v0 = kvalues_iq4nl_f[xi0]; - const float v1 = kvalues_iq4nl_f[xi1]; - const float w0 = src[0 + j]*src[0 + j]; - const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; - sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - - } - - dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; - - } -} - -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; - } - dst_ptr += ntg.x*nb0; - } -} - -void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_q2_K_f32")]] -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - - //const uint16_t kmask1 = 0x3030; - //const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - // One would think that the Metal compiler would figure out that ip and il can only have - // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it - // with these two tales. - // - // Possible masks for the high bit - const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 - {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 - {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 - {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 - - // Possible masks for the low 2 bits - const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; - - const ushort4 hm = mm[2*ip + il/2]; - - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; - - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + il; - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; - - uint32_t scales32, aux32; - thread uint16_t * scales16 = (thread uint16_t *)&scales32; - thread const int8_t * scales = (thread const int8_t *)&scales32; - - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 4) { - - for (int l = 0; l < 8; ++l) { - yl[l+ 0] = y1[l+ 0]; - yl[l+ 8] = y1[l+16]; - yl[l+16] = y1[l+32]; - yl[l+24] = y1[l+48]; - } - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; - - for (int row = 0; row < 2; ++row) { - - const float d_all = (float)dh[0]; - - scales16[0] = a[4]; - scales16[1] = a[5]; - aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; - scales16[0] = a[il+0]; - scales16[1] = a[il+1]; - scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - - float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2]; - s1 += yl[l+0] * (qs & qm[il/2][0]); - s2 += yl[l+1] * (qs & qm[il/2][1]); - s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); - s4 += yl[l+16] * (qs & qm[il/2][2]); - s5 += yl[l+17] * (qs & qm[il/2][3]); - s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); - } - float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[0] - 32); - sumf2[row] += d2 * (scales[2] - 32); - - s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2+8]; - s1 += yl[l+8] * (qs & qm[il/2][0]); - s2 += yl[l+9] * (qs & qm[il/2][1]); - s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); - s4 += yl[l+24] * (qs & qm[il/2][2]); - s5 += yl[l+25] * (qs & qm[il/2][3]); - s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); - } - d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[1] - 32); - sumf2[row] += d2 * (scales[3] - 32); - - q += step; - h += step; - a += step; - dh += step; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); - sumf1[row] = simd_sum(sumf); - } - if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; - } - } -} -#else -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> iq; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } - -} -#endif - -[[host_name("kernel_mul_mv_q3_K_f32")]] -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#else -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -[[host_name("kernel_mul_mv_q4_K_f32")]] -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; - -#if QK_K == 256 -# - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; - - const uint8_t hm1 = 1u << (2*iq); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; - } - - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc1 = {0.f}; - float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; - } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - - } - - y1 += 4 * QK_K; - - } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> iq; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q5_K_f32")]] -kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - -#if QK_K == 256 - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} - -[[host_name("kernel_mul_mv_q6_K_f32")]] -kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -// ======================= "True" 2-bit - -void kernel_mul_mv_iq2_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xxs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - device const uint8_t * aux8 = (device const uint8_t *)q2; - const uint32_t aux32 = q2[2] | (q2[3] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float sum = 0; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d * sum; - - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xxs_f32")]] -kernel void kernel_mul_mv_iq2_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const uint8_t * sc = xr->scales + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint8_t ls1 = sc[0] & 0xf; - const uint8_t ls2 = sc[0] >> 4; - const float d1 = db * (0.5f + ls1); - const float d2 = db * (0.5f + ls2); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d1 * sum1 + d2 * sum2; - - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xs_f32")]] -kernel void kernel_mul_mv_iq2_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_xxs * xr = x + ibl; - device const uint8_t * q3 = xr->qs + 8 * ib; - device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); - sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; - } - } -} - -[[host_name("kernel_mul_mv_iq3_xxs_f32")]] -kernel void kernel_mul_mv_iq3_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 8 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + (ib/2); - device const uint8_t * signs = xr->signs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); - sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_iq3_s_f32")]] -kernel void kernel_mul_mv_iq3_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - //{ - // int nval = 32; - // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + ib; - device const uint8_t * signs = qs + QK_K/8; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d1 = db * (0.5f + (sc[0] & 0xf)); - const float d2 = db * (0.5f + (sc[0] >> 4)); - - float2 sum = {0}; - for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); - sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); - } - } - sumf[row] += d1 * sum[0] + d2 * sum[1]; - - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_s_f32")]] -kernel void kernel_mul_mv_iq2_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq1_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float sumy = 0; - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - sumy += yl[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint16_t * qh = xr->qh + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); - - float sum = 0; - for (int j = 0; j < 4; ++j) { - sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) - + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq1_m_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - -#if QK_K != 64 - iq1m_scale_t scale; -#endif - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_m * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + 2 * ib; - device const uint16_t * sc = (device const uint16_t *)xr->scales; - - for (int row = 0; row < N_DST; row++) { - -#if QK_K != 64 - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); - - float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); - sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); -#if QK_K == 64 - const float d = (float) *((device const half *)(sc - 1)); - sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + - (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); -#else - sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + - (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); -#endif - - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq4_nl_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK4_NL + it * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ib = ix; ib < nb; ib += 16) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2; ++row) { - - device const block_iq4_nl & xb = x[row*nb + ib]; - device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] | (q4[1] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[2] | (q4[3] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 16 * QK4_NL; - } - - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -#if QK_K != 64 -void kernel_mul_mv_iq4_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ibl = ix; ibl < nb; ibl += 2) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2; ++row) { - - device const block_iq4_xs & xb = x[row*nb + ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[1] & 0x0f0f0f0f; - aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; - sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 2 * QK_K; - } - - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_xs_f32")]] -kernel void kernel_mul_mv_iq4_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - -#if QK_K == 64 - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#else - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#endif -} - -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} - -template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 3); - const float d = xb->d; - const float md = -16.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + md; - reg[i/2][2*(i%2)+1] = d * x1 + md; - } -} - -template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 4); - const float d = xb->d; - const float m = xb->m; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + m; - reg[i/2][2*(i%2)+1] = d * x1 + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const float d = xb->d; - const float min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - float dl, ml; - uint8_t sc = xb->scales[il]; - -#if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; -#endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); - const float ml = 4.f * dl; - - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); - } -#endif -} - -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} - -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; -#else - (void) get_scale_min_k4_just2; - - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.f; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif -} - -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - float sc = scales[il]; -#endif - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; - const float ml = d_all * sc * 32.f; - const float dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } -} - -template -void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. - device const uint16_t * q2 = xb->qs + 4*ib32; - const uint32_t aux32_g = q2[0] | (q2[1] << 16); - const uint32_t aux32_s = q2[2] | (q2[3] << 16); - thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; - const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); - uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); - signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint16_t * q2 = xb->qs + 4*ib32; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); - uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); - signs = ksigns_iq2xs[q2[2*il+1] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * q3 = xb->qs + 8*ib32; - device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); - uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } - grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); - grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); - signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 8*ib32; - device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); - constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); - } - grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); - } -} - -template -void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * signs = qs + QK_K/8; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); - reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); - } -} - -template -void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - const float d = xb->d; - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint16_t * qh = xb->qh; - const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); - const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); - const uint16_t h = qh[ib32] >> 6*il; - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml; - reg[1][i] = dl * (grid1[i] >> 4) + ml; - reg[2][i] = dl * (grid2[i] & 0xf) + ml; - reg[3][i] = dl * (grid2[i] >> 4) + ml; - } -} - -template -void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - device const uint16_t * sc = (device const uint16_t *)xb->scales; -#if QK_K == 64 - const float d = xb->d; -#else - iq1m_scale_t scale; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const float d = scale.f16; -#endif - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * qh = xb->qh + 2*ib32 + il; -#if QK_K == 64 - const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); -#else - const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); -#endif - const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml1; - reg[1][i] = dl * (grid1[i] >> 4) + ml1; - reg[2][i] = dl * (grid2[i] & 0xf) + ml2; - reg[3][i] = dl * (grid2[i] >> 4) + ml2; - } -} - -template -void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { - device const uint16_t * q4 = (device const uint16_t *)xb->qs; - const float d = xb->d; - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -} - -template -void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { -#if QK_K == 64 - dequantize_iq4_nl(xb, il, reg); -#else - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; - const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); - const float d = (float)xb->d * (ls - 32); - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -#endif -} - -template -kernel void kernel_get_rows( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; - - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { - float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -kernel void kernel_get_rows_f32( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_i32( - device const void * src0, - device const char * src1, - device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - - -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -template -void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, - threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - - if (r1 * BLOCK_SIZE_N >= ne1) return; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - short il = (tiitg % THREAD_PER_ROW); - - ushort offset1 = il/nl; - - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - { - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0); - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -template -kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - tgpig.z = 0; - - device const uchar * src0 = src0s + i02*nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - - // TODO: parallelize this loop - int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - if (id == i02) { - //if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - //} - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - src0, - src1, - rowids, - dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif - -// -// get rows -// - -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); - -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; -#if QK_K == 64 -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#else -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#endif - -// -// matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm) mat_mm_t; - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#else -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#endif - -// -// indirect matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; - -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#else -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#endif - -// -// matrix-vector multiplication -// - -typedef void (kernel_mul_mv_impl_t)( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); - -typedef void (kernel_mul_mv2_impl_t)( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg); - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); -} - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); -} - -typedef decltype(mmv_fn) mul_mv_impl_fn_t; - -template -kernel void kernel_mul_mv_id( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; - - tgpig.z = 0; - - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - - const int64_t i11 = idx % ne11; - const int64_t i12 = iid1; - - const int64_t i1 = idx; - const int64_t i2 = i12; - - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; - - impl_fn( - /* src0 */ src0_cur, - /* src1 */ src1_cur, - /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1,//ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne0 */ ne0, - /* ne1 */ 1,//ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, - shared_values, - tgpig, - tiitg, - tiisg, - sgitg); -} - -typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; - -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#if QK_K != 64 -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#endif - diff --git a/enjoy/lib/whisper.cpp/arm64/darwin/main b/enjoy/lib/whisper.cpp/arm64/darwin/main deleted file mode 100755 index a64d3fe6..00000000 Binary files a/enjoy/lib/whisper.cpp/arm64/darwin/main and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal b/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal deleted file mode 100644 index 57fdf564..00000000 --- a/enjoy/lib/whisper.cpp/x64/darwin/ggml-metal.metal +++ /dev/null @@ -1,6854 +0,0 @@ -#define GGML_COMMON_DECL_METAL -#define GGML_COMMON_IMPL_METAL -#include "ggml-common.h" - -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) -#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } - -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 - -enum ggml_sort_order { - GGML_SORT_ORDER_ASC, - GGML_SORT_ORDER_DESC, -}; - -// general-purpose kernel for addition, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_mul( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_div( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_div_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] / src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float * src0, - device float * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_scale_4( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_clamp( - device const float * src0, - device float * dst, - constant float & min, - constant float & max, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_silu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sum_rows( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tpig[[thread_position_in_grid]]) { - int64_t i3 = tpig.z; - int64_t i2 = tpig.y; - int64_t i1 = tpig.x; - - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { - return; - } - - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int64_t i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; - } - - dst_row[0] = row_sum; -} - -template -kernel void kernel_soft_max( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - - // find the max value in the block - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); - lsum += exp_psrc0; - pdst[i00] = exp_psrc0; - } - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - pdst[i00] *= inv_sum; - } -} - -template -kernel void kernel_soft_max_4( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - - float slope = 1.0f; - - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); - } - - const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - - const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - pdst4[i00] *= inv_sum; - } -} - -typedef decltype(kernel_soft_max) kernel_soft_max_t; -typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; - -template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } - - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } -} - -kernel void kernel_group_norm( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); - - int start = tgpig * gs; - int end = start + gs; - - start += tpitg; - - if (end >= ne) { - end = ne; - } - - float tmp = 0.0f; // partial sum for thread in warp - - for (int j = start; j < end; j += ntg) { - tmp += src0[j]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float mean = tmp / gs; - tmp = 0.0f; - - for (int j = start; j < end; j += ntg) { - float xi = src0[j] - mean; - dst[j] = xi; - tmp += xi * xi; - } - - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); - for (int j = start; j < end; j += ntg) { - dst[j] *= scale; - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (sumy * -16.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_1/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - - -#define NB_Q8_0 8 - -void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q8_0_f32")]] -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -#define N_F32_F32 4 - -void kernel_mul_mv_f32_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const float * x = (device const float *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F16 4 - -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -void kernel_mul_mv_f16_f32_1row_impl( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const half4 * x4 = (device const half4 *) x; - device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F32 4 - -void kernel_mul_mv_f16_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half4 * x4 = (device const half4 *) (src0 + offset0); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - *cos_theta = cos(theta) * mscale; - *sin_theta = sin(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); -} - -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - -template -kernel void kernel_rope( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - const bool is_neox = mode & 2; - - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); - - device const int32_t * pos = src1; - - const int64_t p = pos[i2]; - - const float theta_0 = (float)p; - const float inv_ndims = -1.f/n_dims; - - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t ib = 0; - - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; - - const float theta = theta_0 * pow(freq_base, cur_rot); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } -} - -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; - -typedef void (im2col_t)( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; - - const int32_t offset_dst = - (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - pdst[offset_dst] = 0.0f; - } else { - const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; - } -} - -template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; - -kernel void kernel_upscale_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & sf, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1/sf; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = src0_ptr[i0/sf]; - } -} - -kernel void kernel_pad_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } - } - - return; - } - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; - } -} - -kernel void kernel_arange_f32( - device char * dst, - constant int64_t & ne0, - constant float & start, - constant float & step, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - device float * dst_ptr = (device float *) dst; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = start + step * i0; - } -} - -kernel void kernel_timestep_embedding_f32( - device const char * src0, - device char * dst, - constant uint64_t & nb1, - constant int & dim, - constant int & max_period, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - int i = tgpig.x; - device float * embed_data = (device float *)(dst + i*nb1); - - int half_ = dim / 2; - for (int j = tpitg.x; j < half_; j += ntg.x) { - float timestep = ((device float *)src0)[i]; - float freq = (float)exp(-log((float)max_period) * j / half_); - float arg = timestep * freq; - embed_data[j ] = cos(arg); - embed_data[j + half_] = sin(arg); - } - - if (dim % 2 != 0 && tpitg.x == 0) { - embed_data[dim] = 0.f; - } -} - -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { - // bitonic sort - int col = tpitg[0]; - int row = tgpig[1]; - - if (col >= ncols_pad) return; - - device const float * x_row = x + row * ncols; - threadgroup int32_t * dst_row = shared_values; - - // initialize indices - dst_row[col] = col; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int k = 2; k <= ncols_pad; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - } - - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; - } -} - -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; - -kernel void kernel_leaky_relu_f32( - device const float * src0, - device float * dst, - constant float & slope, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; -} - -typedef void (flash_attn_ext_f16_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); - -// ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; - - const short D4 = D/4; - const short D8 = D/8; - //const short Q8 = Q/8; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) - const short T4 = T/4; // shared memory size per query in (half4) - - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; - - // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { - sq4[j*T4 + i] = (half4) q4[i]; - } else { - sq4[j*T4 + i] = 0.0h; - } - } - } - - // zero out lo - for (short i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); - } - - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S[Q] = { [0 ... Q-1] = 0.0h }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2/rk2; - const short ik3 = iq3/rk3; - - // v indices - const short iv2 = iq2/rv2; - const short iv3 = iq3/rv3; - - // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); - } - - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); - - // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); - - // prepare diagonal slope matrix - simdgroup_float8x8 mslope(1.0f); - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - mslope = simdgroup_float8x8(pow(base, exph)); - } - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); - - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); - } - - if (mask != q) { - // mqk = mqk*scale + mask*slope - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply(mm, mslope, mm); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); - } else { - // mqk = mqk*scale - simdgroup_multiply(mqk, mscale, mqk); - } - - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); - } - } - - // used to detect blocks full of -INF - float smax = -INFINITY; - - // online softmax - { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const short p = tiisg; - - const float m = M[j]; - const float s = ss[j*TF + p]; - - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); - - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms[j] + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; - } - } - - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - - // O = diag(ms)*O - { - simdgroup_float8x8 mm; - simdgroup_load(mm, ss + C, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); - } - } - - // O = O + (Q*K^T)*V - { - for (short cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); - - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); - } - } - } - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = 0; j < Q; ++j) { - if (tiisg == 0) { - ss[j*TF + 0] = S[j]; - ss[j*TF + 1] = M[j]; - } - } - } - - // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq - if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TF + 0]; - const float S1 = ss[j*TF + sg*SH + 0]; - - const float M0 = ss[j*TF + 1]; - const float M1 = ss[j*TF + sg*SH + 1]; - - M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*TF + 0] = S; - ss[j*TF + 1] = M; - - ss[j*TF + C + j ] = ms0; - ss[j*TF + C + j + sg*SH] = ms1; - } - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; - - simdgroup_load(ms0, ss + C, TF, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); - simdgroup_multiply(t, ms1, t); - - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); - } - } - } - } - - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = ss[j*TF + 0]; - - for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; - } - } - } -} - -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; - -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; - - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; - - // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; - } else { - sq4[i] = 0.0h; - } - } - - // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; - } - - // zero out shared memory SH - for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; - - // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; - - // load the queries from shared memory into local memory - half4 mq[D4]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[i] = sq4[i]; - } - - // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; - - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; - - mqk += (float4) (mq[i] * mk); - } - - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); - - // mqk = mqk*scale + mask*slope - if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); - - ss4[cc] = mqk; - } - } - } - - // online softmax - { - const short p = tiisg; - - const float m = M; - const float s = ss[p]; - - M = simd_max(max(M, s)); - - const float ms = exp(m - M); - const float vs = exp(s - M); - - S = S*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; - - // O = diag(ms)*O -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; - } - } - - // O = O + (Q*K^T)*V - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; - } - } - } - - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - } - - // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = lo[ii/NW]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { - if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; - - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; - - const float M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - const float S = ss[0]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; - } - } -} - -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; - -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - device const half * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - // TODO: is there a better way to handle -INFINITY? - dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_q8_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; - - device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } - } -} - -kernel void kernel_cpy_f32_q4_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q4_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q5_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK5_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_0].d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK5_0/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_0].qh[j] = qh8[j]; - } - } -} - -kernel void kernel_cpy_f32_q5_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float max = src[0]; - float min = src[0]; - - for (int j = 1; j < QK5_1; j++) { - const float v = src[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_1].d = d; - dst_data[i00/QK5_1].m = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_1].qh[j] = qh8[j]; - } - } -} - -static inline int best_index_int8(int n, constant float * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - -constexpr constant static float kvalues_iq4nl_f[16] = { - -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f -}; - -kernel void kernel_cpy_f32_iq4_nl( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / kvalues_iq4nl_f[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_NL/2 + j]*id; - - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); - - dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); - - const float v0 = kvalues_iq4nl_f[xi0]; - const float v1 = kvalues_iq4nl_f[xi1]; - const float w0 = src[0 + j]*src[0 + j]; - const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; - sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - - } - - dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; - - } -} - -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; - } - dst_ptr += ntg.x*nb0; - } -} - -void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_q2_K_f32")]] -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - - //const uint16_t kmask1 = 0x3030; - //const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - // One would think that the Metal compiler would figure out that ip and il can only have - // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it - // with these two tales. - // - // Possible masks for the high bit - const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 - {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 - {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 - {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 - - // Possible masks for the low 2 bits - const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; - - const ushort4 hm = mm[2*ip + il/2]; - - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; - - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + il; - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; - - uint32_t scales32, aux32; - thread uint16_t * scales16 = (thread uint16_t *)&scales32; - thread const int8_t * scales = (thread const int8_t *)&scales32; - - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 4) { - - for (int l = 0; l < 8; ++l) { - yl[l+ 0] = y1[l+ 0]; - yl[l+ 8] = y1[l+16]; - yl[l+16] = y1[l+32]; - yl[l+24] = y1[l+48]; - } - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; - - for (int row = 0; row < 2; ++row) { - - const float d_all = (float)dh[0]; - - scales16[0] = a[4]; - scales16[1] = a[5]; - aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; - scales16[0] = a[il+0]; - scales16[1] = a[il+1]; - scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - - float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2]; - s1 += yl[l+0] * (qs & qm[il/2][0]); - s2 += yl[l+1] * (qs & qm[il/2][1]); - s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); - s4 += yl[l+16] * (qs & qm[il/2][2]); - s5 += yl[l+17] * (qs & qm[il/2][3]); - s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); - } - float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[0] - 32); - sumf2[row] += d2 * (scales[2] - 32); - - s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2+8]; - s1 += yl[l+8] * (qs & qm[il/2][0]); - s2 += yl[l+9] * (qs & qm[il/2][1]); - s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); - s4 += yl[l+24] * (qs & qm[il/2][2]); - s5 += yl[l+25] * (qs & qm[il/2][3]); - s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); - } - d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[1] - 32); - sumf2[row] += d2 * (scales[3] - 32); - - q += step; - h += step; - a += step; - dh += step; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); - sumf1[row] = simd_sum(sumf); - } - if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; - } - } -} -#else -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> iq; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } - -} -#endif - -[[host_name("kernel_mul_mv_q3_K_f32")]] -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#else -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -[[host_name("kernel_mul_mv_q4_K_f32")]] -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; - -#if QK_K == 256 -# - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; - - const uint8_t hm1 = 1u << (2*iq); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; - } - - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc1 = {0.f}; - float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; - } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - - } - - y1 += 4 * QK_K; - - } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> iq; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q5_K_f32")]] -kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - -#if QK_K == 256 - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} - -[[host_name("kernel_mul_mv_q6_K_f32")]] -kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -// ======================= "True" 2-bit - -void kernel_mul_mv_iq2_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xxs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - device const uint8_t * aux8 = (device const uint8_t *)q2; - const uint32_t aux32 = q2[2] | (q2[3] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float sum = 0; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d * sum; - - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xxs_f32")]] -kernel void kernel_mul_mv_iq2_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const uint8_t * sc = xr->scales + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint8_t ls1 = sc[0] & 0xf; - const uint8_t ls2 = sc[0] >> 4; - const float d1 = db * (0.5f + ls1); - const float d2 = db * (0.5f + ls2); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d1 * sum1 + d2 * sum2; - - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xs_f32")]] -kernel void kernel_mul_mv_iq2_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_xxs * xr = x + ibl; - device const uint8_t * q3 = xr->qs + 8 * ib; - device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); - sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; - } - } -} - -[[host_name("kernel_mul_mv_iq3_xxs_f32")]] -kernel void kernel_mul_mv_iq3_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 8 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + (ib/2); - device const uint8_t * signs = xr->signs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); - sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_iq3_s_f32")]] -kernel void kernel_mul_mv_iq3_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - //{ - // int nval = 32; - // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + ib; - device const uint8_t * signs = qs + QK_K/8; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d1 = db * (0.5f + (sc[0] & 0xf)); - const float d2 = db * (0.5f + (sc[0] >> 4)); - - float2 sum = {0}; - for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); - sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); - } - } - sumf[row] += d1 * sum[0] + d2 * sum[1]; - - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_s_f32")]] -kernel void kernel_mul_mv_iq2_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq1_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float sumy = 0; - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - sumy += yl[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint16_t * qh = xr->qh + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); - - float sum = 0; - for (int j = 0; j < 4; ++j) { - sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) - + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq1_m_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - -#if QK_K != 64 - iq1m_scale_t scale; -#endif - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_m * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + 2 * ib; - device const uint16_t * sc = (device const uint16_t *)xr->scales; - - for (int row = 0; row < N_DST; row++) { - -#if QK_K != 64 - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); - - float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); - sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); -#if QK_K == 64 - const float d = (float) *((device const half *)(sc - 1)); - sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + - (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); -#else - sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + - (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); -#endif - - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq4_nl_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK4_NL + it * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ib = ix; ib < nb; ib += 16) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2; ++row) { - - device const block_iq4_nl & xb = x[row*nb + ib]; - device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] | (q4[1] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[2] | (q4[3] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 16 * QK4_NL; - } - - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -#if QK_K != 64 -void kernel_mul_mv_iq4_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ibl = ix; ibl < nb; ibl += 2) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2; ++row) { - - device const block_iq4_xs & xb = x[row*nb + ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[1] & 0x0f0f0f0f; - aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; - sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 2 * QK_K; - } - - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_xs_f32")]] -kernel void kernel_mul_mv_iq4_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - -#if QK_K == 64 - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#else - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#endif -} - -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} - -template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 3); - const float d = xb->d; - const float md = -16.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + md; - reg[i/2][2*(i%2)+1] = d * x1 + md; - } -} - -template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 4); - const float d = xb->d; - const float m = xb->m; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + m; - reg[i/2][2*(i%2)+1] = d * x1 + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const float d = xb->d; - const float min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - float dl, ml; - uint8_t sc = xb->scales[il]; - -#if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; -#endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); - const float ml = 4.f * dl; - - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); - } -#endif -} - -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} - -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; -#else - (void) get_scale_min_k4_just2; - - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.f; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif -} - -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - float sc = scales[il]; -#endif - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; - const float ml = d_all * sc * 32.f; - const float dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } -} - -template -void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. - device const uint16_t * q2 = xb->qs + 4*ib32; - const uint32_t aux32_g = q2[0] | (q2[1] << 16); - const uint32_t aux32_s = q2[2] | (q2[3] << 16); - thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; - const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); - uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); - signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint16_t * q2 = xb->qs + 4*ib32; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); - uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); - signs = ksigns_iq2xs[q2[2*il+1] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * q3 = xb->qs + 8*ib32; - device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); - uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } - grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); - grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); - signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 8*ib32; - device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); - constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); - } - grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); - } -} - -template -void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * signs = qs + QK_K/8; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); - reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); - } -} - -template -void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - const float d = xb->d; - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint16_t * qh = xb->qh; - const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); - const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); - const uint16_t h = qh[ib32] >> 6*il; - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml; - reg[1][i] = dl * (grid1[i] >> 4) + ml; - reg[2][i] = dl * (grid2[i] & 0xf) + ml; - reg[3][i] = dl * (grid2[i] >> 4) + ml; - } -} - -template -void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - device const uint16_t * sc = (device const uint16_t *)xb->scales; -#if QK_K == 64 - const float d = xb->d; -#else - iq1m_scale_t scale; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const float d = scale.f16; -#endif - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * qh = xb->qh + 2*ib32 + il; -#if QK_K == 64 - const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); -#else - const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); -#endif - const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml1; - reg[1][i] = dl * (grid1[i] >> 4) + ml1; - reg[2][i] = dl * (grid2[i] & 0xf) + ml2; - reg[3][i] = dl * (grid2[i] >> 4) + ml2; - } -} - -template -void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { - device const uint16_t * q4 = (device const uint16_t *)xb->qs; - const float d = xb->d; - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -} - -template -void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { -#if QK_K == 64 - dequantize_iq4_nl(xb, il, reg); -#else - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; - const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); - const float d = (float)xb->d * (ls - 32); - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -#endif -} - -template -kernel void kernel_get_rows( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; - - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { - float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -kernel void kernel_get_rows_f32( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_i32( - device const void * src0, - device const char * src1, - device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - - -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -template -void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, - threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - - if (r1 * BLOCK_SIZE_N >= ne1) return; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - short il = (tiitg % THREAD_PER_ROW); - - ushort offset1 = il/nl; - - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - { - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0); - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -template -kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - tgpig.z = 0; - - device const uchar * src0 = src0s + i02*nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - - // TODO: parallelize this loop - int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - if (id == i02) { - //if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - //} - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - src0, - src1, - rowids, - dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif - -// -// get rows -// - -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); - -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; -#if QK_K == 64 -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#else -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#endif - -// -// matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm) mat_mm_t; - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#else -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#endif - -// -// indirect matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; - -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#else -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#endif - -// -// matrix-vector multiplication -// - -typedef void (kernel_mul_mv_impl_t)( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); - -typedef void (kernel_mul_mv2_impl_t)( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg); - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); -} - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); -} - -typedef decltype(mmv_fn) mul_mv_impl_fn_t; - -template -kernel void kernel_mul_mv_id( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; - - tgpig.z = 0; - - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - - const int64_t i11 = idx % ne11; - const int64_t i12 = iid1; - - const int64_t i1 = idx; - const int64_t i2 = i12; - - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; - - impl_fn( - /* src0 */ src0_cur, - /* src1 */ src1_cur, - /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1,//ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne0 */ ne0, - /* ne1 */ 1,//ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, - shared_values, - tgpig, - tiitg, - tiisg, - sgitg); -} - -typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; - -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#if QK_K != 64 -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#endif - diff --git a/enjoy/lib/whisper.cpp/x64/darwin/main b/enjoy/lib/whisper.cpp/x64/darwin/main deleted file mode 100755 index f937f800..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/darwin/main and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/linux/main b/enjoy/lib/whisper.cpp/x64/linux/main deleted file mode 100755 index a3c25087..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/linux/main and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/SDL2.dll b/enjoy/lib/whisper.cpp/x64/win32/SDL2.dll deleted file mode 100644 index e26bcb1c..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/win32/SDL2.dll and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/main.exe b/enjoy/lib/whisper.cpp/x64/win32/main.exe deleted file mode 100644 index 65c45db5..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/win32/main.exe and /dev/null differ diff --git a/enjoy/lib/whisper.cpp/x64/win32/whisper.dll b/enjoy/lib/whisper.cpp/x64/win32/whisper.dll deleted file mode 100644 index 6f63a22e..00000000 Binary files a/enjoy/lib/whisper.cpp/x64/win32/whisper.dll and /dev/null differ diff --git a/enjoy/package.json b/enjoy/package.json index 4e66fb02..b774453d 100644 --- a/enjoy/package.json +++ b/enjoy/package.json @@ -19,9 +19,8 @@ "test:main": "yarn run playwright test e2e/main.spec.ts", "test:renderer": "yarn run playwright test e2e/renderer.spec.ts", "create-migration": "zx ./src/main/db/create-migration.mjs", - "download-whisper-model": "zx ./scripts/download-whisper-model.mjs", "download-dictionaries": "zx ./scripts/download-dictionaries.mjs", - "download": "yarn run download-whisper-model && yarn run download-dictionaries" + "download": "yarn run download-dictionaries" }, "keywords": [], "author": { @@ -41,7 +40,7 @@ "@electron-forge/plugin-vite": "<7.5.0", "@electron-forge/publisher-github": "<7.5.0", "@electron/fuses": "^1.8.0", - "@playwright/test": "^1.48.0", + "@playwright/test": "^1.48.1", "@tailwindcss/typography": "^0.5.15", "@types/ahoy.js": "^0.4.2", "@types/autosize": "^4.0.3", @@ -66,7 +65,7 @@ "@typescript-eslint/parser": "^8.9.0", "@vitejs/plugin-react": "^4.3.2", "autoprefixer": "^10.4.20", - "electron": "^32.2.0", + "electron": "^33.0.1", "electron-devtools-installer": "^3.2.0", "electron-playwright-helpers": "^1.7.1", "eslint": "^9.12.0", @@ -78,10 +77,10 @@ "prop-types": "^15.8.1", "tailwind-merge": "^2.5.4", "tailwind-scrollbar": "^3.1.0", - "tailwindcss": "^3.4.13", + "tailwindcss": "^3.4.14", "tailwindcss-animate": "^1.0.7", "ts-node": "^10.9.2", - "tslib": "^2.7.0", + "tslib": "^2.8.0", "typescript": "^5.6.3", "vite": "^5.4.9", "vite-plugin-static-copy": "^2.0.0", @@ -120,7 +119,7 @@ "@radix-ui/react-toast": "^1.2.2", "@radix-ui/react-toggle": "^1.1.0", "@radix-ui/react-tooltip": "^1.1.3", - "@rails/actioncable": "7.2.100", + "@rails/actioncable": "7.2.101", "@uidotdev/usehooks": "^2.4.1", "@vidstack/react": "^1.12.11", "ahoy.js": "^0.4.4", @@ -128,7 +127,7 @@ "axios": "^1.7.7", "camelcase": "^8.0.0", "camelcase-keys": "^9.1.3", - "chart.js": "^4.4.4", + "chart.js": "^4.4.5", "cheerio": "^1.0.0", "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", @@ -140,7 +139,7 @@ "dayjs": "^1.11.13", "decamelize": "^6.0.0", "decamelize-keys": "^2.0.1", - "echogarden": "^1.8.1", + "echogarden": "^1.8.4", "electron-context-menu": "^4.0.4", "electron-log": "^5.2.0", "electron-settings": "^4.0.4", @@ -156,9 +155,9 @@ "langchain": "^0.3.2", "lodash": "^4.17.21", "lru-cache": "^11.0.1", - "lucide-react": "^0.452.0", + "lucide-react": "^0.453.0", "mark.js": "^8.11.1", - "microsoft-cognitiveservices-speech-sdk": "^1.40.0", + "microsoft-cognitiveservices-speech-sdk": "^1.41.0", "mustache": "^4.2.0", "next-themes": "^0.3.0", "openai": "^4.67.3", diff --git a/enjoy/scripts/download-whisper-model.mjs b/enjoy/scripts/download-whisper-model.mjs deleted file mode 100755 index befec9b9..00000000 --- a/enjoy/scripts/download-whisper-model.mjs +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env zx - -import axios from "axios"; -import progress from "progress"; -import { createHash } from "crypto"; - -const model = "ggml-tiny.en.bin"; -const sha = "c78c86eb1a8faa21b369bcd33207cc90d64ae9df"; - -const dir = path.join(process.cwd(), "lib/whisper.cpp/models"); - -console.info(chalk.blue(`=> Download whisper model ${model}`)); - -fs.ensureDirSync(dir); -try { - if (fs.statSync(path.join(dir, model)).isFile()) { - console.info(chalk.green(`✅ Model ${model} already exists`)); - const hash = await hashFile(path.join(dir, model), { algo: "sha1" }); - if (hash === sha) { - console.info(chalk.green(`✅ Model ${model} valid`)); - process.exit(0); - } else { - console.error( - chalk.red(`❌ Model ${model} not valid, start to redownload`) - ); - fs.removeSync(path.join(dir, model)); - } - } -} catch (err) { - if (err && err.code !== "ENOENT") { - console.error(chalk.red(`❌ Error: ${err}`)); - process.exit(1); - } else { - console.info(chalk.blue(`=> Start to download model ${model}`)); - } -} - -const proxyUrl = - process.env.HTTPS_PROXY || - process.env.https_proxy || - process.env.HTTP_PROXY || - process.env.http_proxy; - -if (proxyUrl) { - const { hostname, port, protocol } = new URL(proxyUrl); - axios.defaults.proxy = { - host: hostname, - port: port, - protocol: protocol, - }; -} - -// const modelUrlPrefix = -// "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main"; -const modelUrlPrefix = "https://enjoy-storage.baizhiheizi.com"; - -function hashFile(path, options) { - const algo = options.algo || "sha1"; - return new Promise((resolve, reject) => { - const hash = createHash(algo); - const stream = fs.createReadStream(path); - stream.on("error", reject); - stream.on("data", (chunk) => hash.update(chunk)); - stream.on("end", () => resolve(hash.digest("hex"))); - }); -} - -const download = async (url, dest) => { - console.info(chalk.blue(`=> Start to download from ${url} to ${dest}`)); - return axios - .get(url, { responseType: "stream" }) - .then((response) => { - const totalLength = response.headers["content-length"]; - - const progressBar = new progress(`-> downloading [:bar] :percent :etas`, { - width: 40, - complete: "=", - incomplete: " ", - renderThrottle: 1, - total: parseInt(totalLength), - }); - - response.data.on("data", (chunk) => { - progressBar.tick(chunk.length); - }); - - response.data.pipe(fs.createWriteStream(dest)).on("close", async () => { - console.info(chalk.green(`✅ Model ${model} downloaded successfully`)); - const hash = await hashFile(path.join(dir, model), { algo: "sha1" }); - if (hash === sha) { - console.info(chalk.green(`✅ Model ${model} valid`)); - process.exit(0); - } else { - console.error( - chalk.red( - `❌ Model ${model} not valid, please try again using command \`yarn workspace enjoy download-whisper-model\`` - ) - ); - process.exit(1); - } - }); - }) - .catch((err) => { - console.error( - chalk.red( - `❌ Failed to download ${url}: ${err}.\nPlease try again using command \`yarn workspace enjoy download-whisper-model\`` - ) - ); - process.exit(1); - }); -}; - -await download(`${modelUrlPrefix}/${model}`, path.join(dir, model)); diff --git a/enjoy/src/constants/index.ts b/enjoy/src/constants/index.ts index a211f206..50bb6a5b 100644 --- a/enjoy/src/constants/index.ts +++ b/enjoy/src/constants/index.ts @@ -2,9 +2,20 @@ export * from "./chat-agent-templates"; export * from "./gpt-presets"; export * from "./ipa"; -// https://hf-mirror.com/ggerganov/whisper.cpp/tree/main -import whisperModels from "./whisper-models.json"; -export const WHISPER_MODELS_OPTIONS = whisperModels; +export const WHISPER_MODELS = [ + "tiny", + "tiny.en", + "base", + "base.en", + "small", + "small.en", + "medium", + "medium.en", + "large-v1", + "large-v2", + "large-v3", + "large-v3-turbo", +]; import languages from "./languages.json"; export const LANGUAGES = languages; diff --git a/enjoy/src/constants/whisper-models.json b/enjoy/src/constants/whisper-models.json deleted file mode 100644 index 6b778aa6..00000000 --- a/enjoy/src/constants/whisper-models.json +++ /dev/null @@ -1,79 +0,0 @@ -[ - { - "type": "tiny", - "name": "ggml-tiny.bin", - "size": "75 MB", - "sha": "bd577a113a864445d4c299885e0cb97d4ba92b5f", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin" - }, - { - "type": "tiny.en", - "name": "ggml-tiny.en.bin", - "size": "75 MB", - "sha": "c78c86eb1a8faa21b369bcd33207cc90d64ae9df", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin" - }, - { - "type": "base", - "name": "ggml-base.bin", - "size": "142 MB", - "sha": "465707469ff3a37a2b9b8d8f89f2f99de7299dac", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-base.bin" - }, - { - "type": "base.en", - "name": "ggml-base.en.bin", - "size": "142 MB", - "sha": "137c40403d78fd54d454da0f9bd998f78703390c", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin" - }, - { - "type": "small", - "name": "ggml-small.bin", - "size": "466 MB", - "sha": "55356645c2b361a969dfd0ef2c5a50d530afd8d5", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-small.bin" - }, - { - "type": "small.en", - "name": "ggml-small.en.bin", - "size": "466 MB", - "sha": "db8a495a91d927739e50b3fc1cc4c6b8f6c2d022", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin" - }, - { - "type": "medium", - "name": "ggml-medium.bin", - "size": "1.5 GB", - "sha": "fd9727b6e1217c2f614f9b698455c4ffd82463b4", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin" - }, - { - "type": "medium.en", - "name": "ggml-medium.en.bin", - "size": "1.5 GB", - "sha": "8c30f0e44ce9560643ebd10bbe50cd20eafd3723", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin" - }, - { - "type": "large-v1", - "name": "ggml-large-v1.bin", - "size": "2.9 GB", - "sha": "b1caaf735c4cc1429223d5a74f0f4d0b9b59a299", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin" - }, - { - "type": "large-v2", - "name": "ggml-large-v2.bin", - "size": "2.9 GB", - "sha": "0f4c8e34f21cf1a914c59d8b3ce882345ad349d6", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-large-v2.bin" - }, - { - "type": "large", - "name": "ggml-large-v3.bin", - "size": "2.9 GB", - "sha": "ad82bf6a9043ceed055076d0fd39f5f186ff8062", - "url": "https://hf-mirror.com/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin" - } -] diff --git a/enjoy/src/i18n/en.json b/enjoy/src/i18n/en.json index 8a8d587b..8628260d 100644 --- a/enjoy/src/i18n/en.json +++ b/enjoy/src/i18n/en.json @@ -450,6 +450,10 @@ "whisperIsWorkingGood": "Whisper is working good", "whisperIsNotWorking": "Whisper is not working", "whisperTranscribeFailed": "Whisper transcribe failed. Maybe your machine is not compatbile with it. Please use other STT service. {{error}}", + "openaiTranscribeFailed": "OpenAI transcribe failed. {{error}}", + "cloudflareTranscribeFailed": "Cloudflare transcribe failed. {{error}}", + "azureTranscribeFailed": "Azure transcribe failed. {{error}}", + "transcribeFailed": "Transcribe failed.", "checkingWhisperModel": "Checking whisper model", "whisperModelIsWorkingGood": "Whisper model is working good", "whisperModelIsNotWorking": "Whisper model is not working", diff --git a/enjoy/src/i18n/zh-CN.json b/enjoy/src/i18n/zh-CN.json index 5d724b7f..f9fe7574 100644 --- a/enjoy/src/i18n/zh-CN.json +++ b/enjoy/src/i18n/zh-CN.json @@ -450,6 +450,10 @@ "whisperIsWorkingGood": "Whisper 正常工作", "whisperIsNotWorking": "Whisper 无法正常工作,请尝试更换模型后重试,或联系开发者", "whisperTranscribeFailed": "Whisper 语音转文本失败。您的设备可能不兼容该功能,请选择其他语音转文本服务。{{error}}", + "openaiTranscribeFailed": "OpenAI 语音转文本失败。{{error}}", + "cloudflareTranscribeFailed": "Cloudflare 语音转文本失败。{{error}}", + "azureTranscribeFailed": "Azure 语音转文本失败。{{error}}", + "transcribeFailed": "语音转文本失败。", "checkingWhisperModel": "正在检查 Whisper 模型", "whisperModelIsWorkingGood": "Whisper 模型正常工作", "whisperModelIsNotWorking": "Whisper 模型无法正常工作,请尝试更换模型后重试,或联系开发者", diff --git a/enjoy/src/main/echogarden.ts b/enjoy/src/main/echogarden.ts index 48f9ab11..18272483 100644 --- a/enjoy/src/main/echogarden.ts +++ b/enjoy/src/main/echogarden.ts @@ -1,6 +1,6 @@ import { ipcMain } from "electron"; import * as Echogarden from "echogarden/dist/api/API.js"; -import { AlignmentOptions } from "echogarden/dist/api/API"; +import { AlignmentOptions, RecognitionOptions } from "echogarden/dist/api/API"; import { encodeRawAudioToWave, decodeWaveToRawAudio, @@ -22,11 +22,19 @@ import settings from "@main/settings"; import fs from "fs-extra"; import ffmpegPath from "ffmpeg-static"; import { enjoyUrlToPath, pathToEnjoyUrl } from "./utils"; +import { UserSetting } from "./db/models"; +import { UserSettingKeyEnum } from "@/types/enums"; +import { WHISPER_MODELS } from "@/constants"; +import { WhisperOptions } from "echogarden/dist/recognition/WhisperSTT.js"; Echogarden.setGlobalOption( "ffmpegPath", ffmpegPath.replace("app.asar", "app.asar.unpacked") ); +Echogarden.setGlobalOption( + "packageBaseURL", + "https://hf-mirror.com/echogarden/echogarden-packages/resolve/main/" +); const __filename = url.fileURLToPath(import.meta.url); /* @@ -38,6 +46,7 @@ const __dirname = path const logger = log.scope("echogarden"); class EchogardenWrapper { + public recognize: typeof Echogarden.recognize; public align: typeof Echogarden.align; public alignSegments: typeof Echogarden.alignSegments; public denoise: typeof Echogarden.denoise; @@ -50,6 +59,7 @@ class EchogardenWrapper { public wordTimelineToSegmentSentenceTimeline: typeof wordTimelineToSegmentSentenceTimeline; constructor() { + this.recognize = Echogarden.recognize; this.align = Echogarden.align; this.alignSegments = Echogarden.alignSegments; this.denoise = Echogarden.denoise; @@ -63,14 +73,27 @@ class EchogardenWrapper { wordTimelineToSegmentSentenceTimeline; } - async check() { + async check( + options: RecognitionOptions = { + engine: "whisper", + whisper: { + model: "tiny.en", + language: "en", + } as WhisperOptions, + } + ) { const sampleFile = path.join(__dirname, "samples", "jfk.wav"); try { - const result = await this.align( - sampleFile, - "And so my fellow Americans ask not what your country can do for you", - {} - ); + const whisperModel = await UserSetting.get(UserSettingKeyEnum.WHISPER); + if (WHISPER_MODELS.includes(whisperModel)) { + options.whisper.model = whisperModel; + } + } catch (e) { + logger.error(e); + } + + try { + const result = await this.recognize(sampleFile, options); logger.info(result); fs.writeJsonSync( path.join(settings.cachePath(), "echogarden-check.json"), @@ -78,10 +101,10 @@ class EchogardenWrapper { { spaces: 2 } ); - return true; + return { success: true, log: "" }; } catch (e) { logger.error(e); - return false; + return { success: false, log: e.message }; } } @@ -102,6 +125,20 @@ class EchogardenWrapper { } registerIpcHandlers() { + ipcMain.handle( + "echogarden-recognize", + async (_event, url: string, options: RecognitionOptions) => { + logger.debug("echogarden-recognize:", options); + try { + const input = enjoyUrlToPath(url); + return await this.recognize(input, options); + } catch (err) { + logger.error(err); + throw err; + } + } + ); + ipcMain.handle( "echogarden-align", async ( @@ -129,6 +166,9 @@ class EchogardenWrapper { options: AlignmentOptions ) => { logger.debug("echogarden-align-segments:", timeline, options); + if (typeof input === "string") { + input = enjoyUrlToPath(input); + } try { const rawAudio = await this.ensureRawAudio(input, 16000); return await this.alignSegments(rawAudio, timeline, options); @@ -182,8 +222,8 @@ class EchogardenWrapper { } ); - ipcMain.handle("echogarden-check", async (_event) => { - return this.check(); + ipcMain.handle("echogarden-check", async (_event, options: any) => { + return this.check(options); }); } } diff --git a/enjoy/src/main/whisper.ts b/enjoy/src/main/whisper.ts deleted file mode 100644 index d5022114..00000000 --- a/enjoy/src/main/whisper.ts +++ /dev/null @@ -1,330 +0,0 @@ -import { ipcMain } from "electron"; -import settings from "@main/settings"; -import path from "path"; -import { WHISPER_MODELS_OPTIONS, PROCESS_TIMEOUT } from "@/constants"; -import { exec, spawn } from "child_process"; -import fs from "fs-extra"; -import log from "@main/logger"; -import url from "url"; -import { enjoyUrlToPath } from "./utils"; -import { t } from "i18next"; -import { UserSetting } from "@main/db/models"; -import db from "@main/db"; -import { UserSettingKeyEnum } from "@/types/enums"; - -const __filename = url.fileURLToPath(import.meta.url); -/* - * whipser bin file will be in /app.asar.unpacked instead of /app.asar - */ -const __dirname = path - .dirname(__filename) - .replace("app.asar", "app.asar.unpacked"); - -const logger = log.scope("whisper"); - -class Whipser { - private binMain: string; - private bundledModelsDir: string; - public config: WhisperConfigType; - private abortController: AbortController; - - constructor() { - const customWhisperPath = path.join( - settings.libraryPath(), - "whisper", - "main" - ); - this.bundledModelsDir = path.join(__dirname, "lib", "whisper", "models"); - if (fs.existsSync(customWhisperPath)) { - this.binMain = customWhisperPath; - } else { - this.binMain = path.join(__dirname, "lib", "whisper", "main"); - } - this.initialize(); - } - - async initialize() { - const models = []; - - const bundledModels = fs.readdirSync(this.bundledModelsDir); - for (const file of bundledModels) { - const model = WHISPER_MODELS_OPTIONS.find((m) => m.name == file); - if (!model) continue; - - models.push({ - ...model, - savePath: path.join(this.bundledModelsDir, file), - }); - } - - const dir = path.join(settings.libraryPath(), "whisper", "models"); - fs.ensureDirSync(dir); - const files = fs.readdirSync(dir); - for (const file of files) { - const model = WHISPER_MODELS_OPTIONS.find((m) => m.name == file); - if (!model) continue; - - models.push({ - ...model, - savePath: path.join(dir, file), - }); - } - - if (db.connection) { - const whisperConfig = (await UserSetting.get( - UserSettingKeyEnum.WHISPER - )) as string; - this.config = { - model: whisperConfig || models[0].name, - availableModels: models, - modelsPath: dir, - }; - } else { - this.config = { - model: models[0].name, - availableModels: models, - modelsPath: dir, - }; - } - } - - currentModel() { - if (!this.config.availableModels) return; - - let model: WhisperConfigType["availableModels"][0]; - if (this.config.model) { - model = (this.config.availableModels || []).find( - (m) => m.name === this.config.model - ); - } - if (!model) { - model = this.config.availableModels[0]; - this.config = Object.assign({}, this.config, { model: model.name }); - UserSetting.set(UserSettingKeyEnum.WHISPER, model.name); - } - - return model; - } - - async check() { - this.abortController?.abort(); - this.abortController = new AbortController(); - - const model = this.currentModel(); - logger.debug(`Checking whisper model: ${model.savePath}`); - - const sampleFile = path.join(__dirname, "samples", "jfk.wav"); - const tmpDir = settings.cachePath(); - const outputFile = path.join(tmpDir, "jfk.json"); - fs.rmSync(outputFile, { force: true }); - return new Promise((resolve, _reject) => { - const commands = [ - `"${this.binMain}"`, - `--file "${sampleFile}"`, - `--model "${model.savePath}"`, - "--output-json", - `--output-file "${path.join(tmpDir, "jfk")}"`, - ]; - logger.debug(`Checking whisper command: ${commands.join(" ")}`); - exec( - commands.join(" "), - { - timeout: PROCESS_TIMEOUT, - signal: this.abortController.signal, - }, - (error, stdout, stderr) => { - if (error) { - logger.error("error", error); - } - - if (stderr) { - logger.info("stderr", stderr); - } - - if (stdout) { - logger.debug(stdout); - } - - resolve({ - success: fs.existsSync(outputFile), - log: `${error?.message || ""}\n${stderr}\n${stdout}`, - }); - } - ); - }); - } - - /* Ensure the file is in wav format - * and 16kHz sample rate - */ - async transcribe( - params: { - file?: string; - blob?: { - type: string; - arrayBuffer: ArrayBuffer; - }; - }, - options?: { - language?: string; - force?: boolean; - extra?: string[]; - onProgress?: (progress: number) => void; - } - ): Promise> { - logger.debug("transcribing from local"); - - this.abortController?.abort(); - this.abortController = new AbortController(); - - const { blob } = params; - let { file } = params; - - if (file) { - file = enjoyUrlToPath(file); - } else if (blob) { - const format = blob.type.split("/")[1]; - if (format !== "wav") { - throw new Error("Only wav format is supported"); - } - - file = path.join(settings.cachePath(), `${Date.now()}.${format}`); - await fs.outputFile(file, Buffer.from(blob.arrayBuffer)); - } else { - throw new Error("No file or blob provided"); - } - - const { force = false, extra = [], language, onProgress } = options || {}; - - const model = this.currentModel(); - if (language && !language.startsWith("en") && model.name.includes("en")) { - throw new Error(`Model ${model.name} does not support ${language}`); - } - - const filename = path.basename(file, path.extname(file)); - const tmpDir = settings.cachePath(); - const outputFile = path.join(tmpDir, filename + ".json"); - - logger.info(`Trying to transcribe ${file} to ${outputFile}`); - if (fs.pathExistsSync(outputFile) && !force) { - logger.info(`File ${outputFile} already exists`); - return fs.readJson(outputFile); - } - - const commandArguments = [ - "--file", - file, - "--model", - model.savePath, - "--output-json", - "--output-file", - path.join(tmpDir, filename), - "--print-progress", - "--language", - model.name.includes("en") ? "en" : language?.split("-")?.[0] || "auto", - // `--split-on-word`, - // `--max-len`, - // "1", - ...extra, - ]; - - logger.info( - `Running command: ${this.binMain} ${commandArguments.join(" ")}` - ); - - const command = spawn(this.binMain, commandArguments, { - timeout: PROCESS_TIMEOUT, - signal: this.abortController.signal, - }); - - return new Promise((resolve, reject) => { - command.stdout.on("data", (data) => { - logger.debug(`stdout: ${data}`); - }); - - command.stderr.on("data", (data) => { - const output = data.toString(); - logger.info(`stderr: ${output}`); - if (output.startsWith("whisper_print_progress_callback")) { - const progress = parseInt(output.match(/\d+%/)?.[0] || "0"); - if (typeof progress === "number" && onProgress) onProgress(progress); - } - }); - - command.on("exit", (code) => { - logger.info(`transcribe process exited with code ${code}`); - }); - - command.on("error", (err) => { - logger.error("transcribe error", err.message); - reject(err); - }); - - command.on("close", (code) => { - if (code === 0 && fs.pathExistsSync(outputFile)) { - resolve(fs.readJson(outputFile)); - } else { - reject(new Error("Transcription failed")); - } - }); - }); - } - - abort() { - this.abortController?.abort(); - } - - registerIpcHandlers() { - ipcMain.handle("whisper-config", async () => { - await this.initialize(); - return this.config; - }); - - ipcMain.handle("whisper-set-model", async (_event, model) => { - const originalModel = this.config.model; - this.config.model = model; - - return this.check() - .then(({ success, log }) => { - if (success) { - return Object.assign({}, this.config, { ready: true }); - } else { - throw new Error(log); - } - }) - .catch((err) => { - this.config.model = originalModel; - throw err; - }) - .finally(() => { - UserSetting.set(UserSettingKeyEnum.WHISPER, this.config.model); - }); - }); - - ipcMain.handle("whisper-check", async (_event) => { - return await this.check(); - }); - - ipcMain.handle("whisper-transcribe", async (event, params, options) => { - return this.transcribe(params, { - ...options, - onProgress: (progress) => { - event.sender.send("whisper-on-progress", progress); - }, - }) - .then((result) => { - return result; - }) - .catch((err) => { - logger.error(err); - throw t("whisperTranscribeFailed", { error: err.message }); - }); - }); - - ipcMain.handle("whisper-abort", async (_event) => { - return await this.abort(); - }); - } -} - -export default new Whipser(); diff --git a/enjoy/src/main/window.ts b/enjoy/src/main/window.ts index a421000f..5c4e0067 100644 --- a/enjoy/src/main/window.ts +++ b/enjoy/src/main/window.ts @@ -12,7 +12,6 @@ import path from "path"; import db from "@main/db"; import settings from "@main/settings"; import downloader from "@main/downloader"; -import whisper from "@main/whisper"; import fs from "fs-extra"; import log from "@main/logger"; import { REPO_URL, WS_URL } from "@/constants"; @@ -62,9 +61,6 @@ main.init = async () => { // echogarden echogarden.registerIpcHandlers(); - // Whisper - whisper.registerIpcHandlers(); - // Waveform waveform.registerIpcHandlers(); @@ -411,7 +407,6 @@ ${log} segments: path.join(settings.userDataPath(), "segments"), speeches: path.join(settings.userDataPath(), "speeches"), recordings: path.join(settings.userDataPath(), "recordings"), - whisper: path.join(settings.libraryPath(), "whisper"), waveforms: path.join(settings.libraryPath(), "waveforms"), logs: path.join(settings.libraryPath(), "logs"), cache: settings.cachePath(), diff --git a/enjoy/src/preload.ts b/enjoy/src/preload.ts index 5bda6216..cc520d18 100644 --- a/enjoy/src/preload.ts +++ b/enjoy/src/preload.ts @@ -3,6 +3,7 @@ import { contextBridge, ipcRenderer, IpcRendererEvent } from "electron"; import { version } from "../package.json"; import { Timeline } from "echogarden/dist/utilities/Timeline"; +import { RecognitionOptions } from "echogarden/dist/api/API"; contextBridge.exposeInMainWorld("__ENJOY_APP__", { app: { @@ -472,6 +473,9 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { }, }, echogarden: { + recognize: (input: string, options: RecognitionOptions) => { + return ipcRenderer.invoke("echogarden-recognize", input, options); + }, align: (input: string, transcript: string, options: any) => { return ipcRenderer.invoke("echogarden-align", input, transcript, options); }, @@ -502,42 +506,6 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { return ipcRenderer.invoke("echogarden-check"); }, }, - whisper: { - config: () => { - return ipcRenderer.invoke("whisper-config"); - }, - setModel: (model: string) => { - return ipcRenderer.invoke("whisper-set-model", model); - }, - check: () => { - return ipcRenderer.invoke("whisper-check"); - }, - transcribe: ( - params: { - file?: string; - blob?: { - type: string; - arrayBuffer: ArrayBuffer; - }; - }, - options?: { - language?: string; - force?: boolean; - extra?: string[]; - } - ) => { - return ipcRenderer.invoke("whisper-transcribe", params, options); - }, - onProgress: ( - callback: (event: IpcRendererEvent, progress: number) => void - ) => ipcRenderer.on("whisper-on-progress", callback), - abort: () => { - return ipcRenderer.invoke("whisper-abort"); - }, - removeProgressListeners: () => { - ipcRenderer.removeAllListeners("whisper-on-progress"); - }, - }, ffmpeg: { check: () => { return ipcRenderer.invoke("ffmpeg-check-command"); diff --git a/enjoy/src/renderer/components/medias/media-left-panel/media-transcription-generate-button.tsx b/enjoy/src/renderer/components/medias/media-left-panel/media-transcription-generate-button.tsx index e2ebec4a..d3fc1337 100644 --- a/enjoy/src/renderer/components/medias/media-left-panel/media-transcription-generate-button.tsx +++ b/enjoy/src/renderer/components/medias/media-left-panel/media-transcription-generate-button.tsx @@ -79,6 +79,7 @@ export const MediaTranscriptionGenerateButton = (props: { generateTranscription({ originalText: data.text, language: data.language, + model: data.model, service: data.service as SttEngineOptionEnum | "upload", isolate: data.isolate, }) diff --git a/enjoy/src/renderer/components/medias/media-loading-modal.tsx b/enjoy/src/renderer/components/medias/media-loading-modal.tsx index 2801356e..3cc3fa83 100644 --- a/enjoy/src/renderer/components/medias/media-loading-modal.tsx +++ b/enjoy/src/renderer/components/medias/media-loading-modal.tsx @@ -69,6 +69,7 @@ const LoadingContent = () => { generateTranscription({ originalText: data.text, language: data.language, + model: data.model, service: data.service as SttEngineOptionEnum | "upload", isolate: data.isolate, }); diff --git a/enjoy/src/renderer/components/preferences/index.ts b/enjoy/src/renderer/components/preferences/index.ts index d3f41da6..f591f0ac 100644 --- a/enjoy/src/renderer/components/preferences/index.ts +++ b/enjoy/src/renderer/components/preferences/index.ts @@ -28,7 +28,6 @@ export * from "./theme-settings"; export * from "./proxy-settings"; -export * from "./whisper-model-options"; export * from "./network-state"; export * from "./recorder-settings"; diff --git a/enjoy/src/renderer/components/preferences/stt-settings.tsx b/enjoy/src/renderer/components/preferences/stt-settings.tsx index eb6ef654..b99b151d 100644 --- a/enjoy/src/renderer/components/preferences/stt-settings.tsx +++ b/enjoy/src/renderer/components/preferences/stt-settings.tsx @@ -6,7 +6,6 @@ import { DialogContent, DialogHeader, DialogDescription, - DialogFooter, toast, Select, SelectTrigger, @@ -14,29 +13,26 @@ import { SelectItem, SelectValue, } from "@renderer/components/ui"; -import { WhisperModelOptions } from "@renderer/components"; import { AppSettingsProviderContext, AISettingsProviderContext, } from "@renderer/context"; -import { useContext, useEffect, useState } from "react"; -import { InfoIcon, AlertCircleIcon } from "lucide-react"; +import { useContext, useState } from "react"; +import { AlertCircleIcon } from "lucide-react"; import { SttEngineOptionEnum } from "@/types/enums"; +import { WHISPER_MODELS } from "@/constants"; export const WhisperSettings = () => { - const { sttEngine, whisperConfig, refreshWhisperConfig, setSttEngine } = - useContext(AISettingsProviderContext); + const { sttEngine, whisperModel, setWhisperModel, setSttEngine } = useContext( + AISettingsProviderContext + ); const { EnjoyApp } = useContext(AppSettingsProviderContext); const [stderr, setStderr] = useState(""); - useEffect(() => { - refreshWhisperConfig(); - }, []); - const handleCheck = async () => { toast.promise( async () => { - const { success, log } = await EnjoyApp.whisper.check(); + const { success, log } = await EnjoyApp.echogarden.check(); if (success) { setStderr(""); return Promise.resolve(); @@ -98,46 +94,28 @@ export const WhisperSettings = () => { - {sttEngine === "local" && ( + {sttEngine === SttEngineOptionEnum.LOCAL && ( <> + - - - - - - {t("sttAiService")} - - {t("chooseAIModelDependingOnYourHardware")} - - - - - -
- - - {t("yourModelsWillBeDownloadedTo", { - path: whisperConfig.modelsPath, - })} - - -
-
-
-
)} diff --git a/enjoy/src/renderer/components/preferences/whisper-model-options.tsx b/enjoy/src/renderer/components/preferences/whisper-model-options.tsx deleted file mode 100644 index 9bb7791c..00000000 --- a/enjoy/src/renderer/components/preferences/whisper-model-options.tsx +++ /dev/null @@ -1,196 +0,0 @@ -import { - AlertDialog, - AlertDialogTitle, - AlertDialogHeader, - AlertDialogFooter, - AlertDialogDescription, - AlertDialogContent, - AlertDialogAction, - AlertDialogCancel, - Button, - ScrollArea, - toast, - Progress, -} from "@renderer/components/ui"; -import { t } from "i18next"; -import { CheckCircle, DownloadIcon, XCircleIcon } from "lucide-react"; -import { WHISPER_MODELS_OPTIONS } from "@/constants"; -import { useState, useContext, useEffect } from "react"; -import { - AppSettingsProviderContext, - AISettingsProviderContext, -} from "@renderer/context"; - -type ModelType = { - type: string; - name: string; - size: string; - url: string; - downloaded?: boolean; - downloadState?: DownloadStateType; -}; - -export const WhisperModelOptions = () => { - const [selectingModel, setSelectingModel] = useState(null); - const [availableModels, setAvailableModels] = useState([]); - const { EnjoyApp } = useContext(AppSettingsProviderContext); - const { whisperConfig, setWhisperModel } = useContext( - AISettingsProviderContext - ); - - useEffect(() => { - updateAvailableModels(); - }, []); - - useEffect(() => { - listenToDownloadState(); - - return () => { - EnjoyApp.download.removeAllListeners(); - }; - }, [selectingModel]); - - const updateAvailableModels = async () => { - const models = whisperConfig.availableModels; - const options: ModelType[] = WHISPER_MODELS_OPTIONS; - - options.forEach((o) => { - o.downloaded = models.findIndex((m) => m.name === o.name) > -1; - }); - setAvailableModels(options); - }; - - const downloadModel = async () => { - if (!selectingModel) return; - - const model = WHISPER_MODELS_OPTIONS.find( - (m) => m.name === selectingModel.name - ); - - EnjoyApp.download.start(model.url, whisperConfig.modelsPath); - - setSelectingModel(null); - }; - - const listenToDownloadState = () => { - EnjoyApp.download.onState((_event, state) => { - const model = availableModels.find((m) => m.name === state.name); - if (!model) return; - - if (model) { - model.downloadState = state; - } - if (state.state === "completed") { - model.downloaded = true; - setWhisperModel(model.name); - } else if (state.state === "cancelled" || state.state === "interrupted") { - model.downloaded = false; - model.downloadState = null; - } - - setAvailableModels([...availableModels]); - }); - }; - - return ( - <> - - {availableModels.map((option) => { - return ( -
-
{ - if (option.downloaded) { - toast.promise(setWhisperModel(option.name), { - loading: t("checkingWhisperModel"), - success: t("whisperModelIsWorkingGood"), - error: t("whisperModelIsNotWorking"), - }); - } else if (!option.downloadState) { - setSelectingModel(option); - } - }} - > -
-
{option.type}
- {option.downloaded ? ( - - ) : ( - - )} -
-
- {option.name} - ~{option.size} -
-
- - {!option.downloaded && option.downloadState && ( -
- - -
- )} -
- ); - })} -
- - - - - {t("download")} - - {t("areYouSureToDownload", { - name: `${selectingModel?.name}(${selectingModel?.size})`, - })} - - - - { - setSelectingModel(null); - }} - > - {t("cancel")} - - downloadModel()}> - {t("download")} - - - - - - ); -}; diff --git a/enjoy/src/renderer/components/transcriptions/transcription-create-form.tsx b/enjoy/src/renderer/components/transcriptions/transcription-create-form.tsx index befd3085..8ee52f77 100644 --- a/enjoy/src/renderer/components/transcriptions/transcription-create-form.tsx +++ b/enjoy/src/renderer/components/transcriptions/transcription-create-form.tsx @@ -30,7 +30,7 @@ import { toast, } from "@renderer/components/ui"; import { t } from "i18next"; -import { LANGUAGES } from "@/constants"; +import { LANGUAGES, WHISPER_MODELS } from "@/constants"; import { ChevronDownIcon, ChevronUpIcon, LoaderIcon } from "lucide-react"; import { parseText } from "media-captions"; import { milisecondsToTimestamp } from "@/utils"; @@ -39,6 +39,7 @@ import { SttEngineOptionEnum } from "@/types/enums"; const transcriptionSchema = z.object({ language: z.string(), service: z.union([z.nativeEnum(SttEngineOptionEnum), z.literal("upload")]), + model: z.string().optional(), text: z.string().optional(), isolate: z.boolean().optional(), }); @@ -60,13 +61,14 @@ export const TranscriptionCreateForm = (props: { originalText, } = props; const { learningLanguage } = useContext(AppSettingsProviderContext); - const { sttEngine } = useContext(AISettingsProviderContext); + const { sttEngine, whisperModel } = useContext(AISettingsProviderContext); const form = useForm>({ resolver: zodResolver(transcriptionSchema), values: { language: learningLanguage, service: originalText ? "upload" : sttEngine, + model: sttEngine === SttEngineOptionEnum.LOCAL ? whisperModel : "", text: originalText, isolate: false, }, @@ -197,6 +199,35 @@ export const TranscriptionCreateForm = (props: { )} /> + + {form.watch("service") === SttEngineOptionEnum.LOCAL && ( + ( + + {t("model")} + + + )} + /> + )} + Promise; sttEngine?: SttEngineOptionEnum; setSttEngine?: (name: string) => Promise; - whisperConfig?: WhisperConfigType; - refreshWhisperConfig?: () => void; + whisperModel?: string; + setWhisperModel?: (name: string) => void; openai?: LlmProviderType; setOpenai?: (config: LlmProviderType) => void; setGptEngine?: (engine: GptEngineSettingType) => void; @@ -38,7 +38,7 @@ export const AISettingsProvider = ({ }, }); const [openai, setOpenai] = useState(null); - const [whisperConfig, setWhisperConfig] = useState(null); + const [whisperModel, setWhisperModel] = useState(null); const [sttEngine, setSttEngine] = useState( SttEngineOptionEnum.ENJOY_AZURE ); @@ -87,6 +87,44 @@ export const AISettingsProvider = ({ setTtsProviders({ ...providers }); }; + const refreshWhisperModel = async () => { + const whisperModel = await EnjoyApp.userSettings.get( + UserSettingKeyEnum.WHISPER + ); + if (WHISPER_MODELS.includes(whisperModel)) { + setWhisperModel(whisperModel); + } else { + let model = "tiny"; + if (whisperModel.match(/tiny/)) { + model = "tiny"; + } else if (whisperModel.match(/base/)) { + model = "base"; + } else if (whisperModel.match(/small/)) { + model = "small"; + } else if (whisperModel.match(/medium/)) { + model = "medium"; + } else if (whisperModel.match(/large/)) { + model = "large-v3-turbo"; + } + + if ( + learningLanguage.match(/en/) && + model.match(/tiny|base|small|medium/) + ) { + model = `${model}.en`; + } + + setWhisperModel(model); + } + }; + + const handleSetWhisperModel = async (name: string) => { + if (WHISPER_MODELS.includes(name)) { + setWhisperModel(name); + EnjoyApp.userSettings.set(UserSettingKeyEnum.WHISPER, name); + } + }; + useEffect(() => { refreshGptProviders(); refreshTtsProviders(); @@ -101,22 +139,8 @@ export const AISettingsProvider = ({ useEffect(() => { if (db.state !== "connected") return; if (!libraryPath) return; - - refreshWhisperConfig(); }, [db.state, libraryPath]); - const refreshWhisperConfig = async () => { - const config = await EnjoyApp.whisper.config(); - setWhisperConfig(config); - }; - - const setWhisperModel = async (name: string) => { - return EnjoyApp.whisper.setModel(name).then((config) => { - if (!config) return; - setWhisperConfig(config); - }); - }; - const handleSetSttEngine = async (name: SttEngineOptionEnum) => { setSttEngine(name); return EnjoyApp.userSettings.set(UserSettingKeyEnum.STT_ENGINE, name); @@ -165,6 +189,8 @@ export const AISettingsProvider = ({ setGptEngine(engine); }); } + + refreshWhisperModel(); }; const handleSetOpenai = async (config: LlmProviderType) => { @@ -208,9 +234,8 @@ export const AISettingsProvider = ({ }, openai, setOpenai: (config: LlmProviderType) => handleSetOpenai(config), - whisperConfig, - refreshWhisperConfig, - setWhisperModel, + whisperModel, + setWhisperModel: handleSetWhisperModel, sttEngine, setSttEngine: (name: SttEngineOptionEnum) => handleSetSttEngine(name), gptProviders, diff --git a/enjoy/src/renderer/context/media-shadow-provider.tsx b/enjoy/src/renderer/context/media-shadow-provider.tsx index f4e22046..b1e291d5 100644 --- a/enjoy/src/renderer/context/media-shadow-provider.tsx +++ b/enjoy/src/renderer/context/media-shadow-provider.tsx @@ -62,6 +62,7 @@ type MediaShadowContextType = { generateTranscription: (params?: { originalText?: string; language?: string; + model?: string; service?: SttEngineOptionEnum | "upload"; isolate?: boolean; }) => Promise; diff --git a/enjoy/src/renderer/hooks/use-transcribe.tsx b/enjoy/src/renderer/hooks/use-transcribe.tsx index 68e0c110..1213be6f 100644 --- a/enjoy/src/renderer/hooks/use-transcribe.tsx +++ b/enjoy/src/renderer/hooks/use-transcribe.tsx @@ -11,14 +11,14 @@ import axios from "axios"; import { useAiCommand } from "./use-ai-command"; import { toast } from "@renderer/components/ui"; import { - Timeline, TimelineEntry, type TimelineEntryType, } from "echogarden/dist/utilities/Timeline"; -import take from "lodash/take"; -import sortedUniqBy from "lodash/sortedUniqBy"; import { parseText } from "media-captions"; import { SttEngineOptionEnum } from "@/types/enums"; +import { RecognitionResult } from "echogarden/dist/api/API.js"; +import take from "lodash/take"; +import sortedUniqBy from "lodash/sortedUniqBy"; // test a text string has any punctuations or not // some transcribed text may not have any punctuations @@ -26,7 +26,7 @@ const punctuationsPattern = /\w[.,!?](\s|$)/g; export const useTranscribe = () => { const { EnjoyApp, user, webApi } = useContext(AppSettingsProviderContext); - const { openai } = useContext(AISettingsProviderContext); + const { openai, whisperModel } = useContext(AISettingsProviderContext); const { punctuateText } = useAiCommand(); const [output, setOutput] = useState(""); @@ -47,6 +47,7 @@ export const useTranscribe = () => { params?: { targetId?: string; targetType?: string; + model?: string; originalText?: string; language: string; service: SttEngineOptionEnum | "upload"; @@ -64,6 +65,7 @@ export const useTranscribe = () => { }> => { const url = await transcode(mediaSrc); const { + model, targetId, targetType, originalText, @@ -75,34 +77,11 @@ export const useTranscribe = () => { const blob = await (await fetch(url)).blob(); let result: any; - let timeline: Timeline = []; + if (service === "upload" && originalText) { - const caption = await parseText(originalText, { type: "srt" }); - if (caption.cues.length > 0) { - timeline = caption.cues.map((cue) => { - return { - type: "sentence", - text: cue.text, - startTime: cue.startTime, - endTime: cue.endTime, - timeline: [], - }; - }); - result = { - engine: "upload", - model: "-", - text: timeline.map((entry) => entry.text).join(" "), - timeline, - }; - } else { - result = { - engine: "upload", - model: "-", - text: originalText, - }; - } + result = await alignText(originalText); } else if (service === SttEngineOptionEnum.LOCAL) { - result = await transcribeByLocal(url, language); + result = await transcribeByLocal(url, { language, model }); } else if (service === SttEngineOptionEnum.ENJOY_CLOUDFLARE) { result = await transcribeByCloudflareAi(blob); } else if (service === SttEngineOptionEnum.OPENAI) { @@ -121,67 +100,47 @@ export const useTranscribe = () => { ); } - let transcript = result.text; - if (!align) { + const { segmentTimeline, transcript } = result; + + if (!align && transcript) { return { ...result, - transcript, - url, + timeline: [], }; } - /* - * if timeline is available and the transcript contains punctuations - * use `alignSegments` to align each sentence with the timeline - * otherwise, use `align` to align the whole transcript - * if the transcript does not contain any punctuation, use AI command to add punctuation - */ - if (result.timeline?.length && transcript.match(punctuationsPattern)) { - timeline = [...result.timeline]; - setOutput("Aligning the transcript..."); + if (segmentTimeline && segmentTimeline.length > 0) { const wordTimeline = await EnjoyApp.echogarden.alignSegments( new Uint8Array(await blob.arrayBuffer()), - timeline, + segmentTimeline, { - language, + language: language.split("-")[0], isolate, } ); - timeline = await EnjoyApp.echogarden.wordToSentenceTimeline( + + const timeline = await EnjoyApp.echogarden.wordToSentenceTimeline( wordTimeline, transcript, language.split("-")[0] ); - } else { - // Remove all content inside `()`, `[]`, `{}` and trim the text - // remove all markdown formatting - transcript = transcript - .replace(/\(.*?\)/g, "") - .replace(/\[.*?\]/g, "") - .replace(/\{.*?\}/g, "") - .replace(/[*_`]/g, "") - .trim(); - - // if the transcript does not contain any punctuation, use AI command to add punctuation - if (!transcript.match(punctuationsPattern)) { - try { - transcript = await punctuateText(transcript); - } catch (err) { - toast.error(err.message); - console.warn(err.message); - } - } + return { + ...result, + timeline, + }; + } else if (transcript) { setOutput("Aligning the transcript..."); const alignmentResult = await EnjoyApp.echogarden.align( new Uint8Array(await blob.arrayBuffer()), transcript, { - language, + language: language.split("-")[0], isolate, } ); + const timeline: TimelineEntry[] = []; alignmentResult.timeline.forEach((t: TimelineEntry) => { if (t.type === "sentence") { timeline.push(t); @@ -191,74 +150,120 @@ export const useTranscribe = () => { }); } }); - } - return { - ...result, - originalText, - transcript, - timeline, - url, - }; + return { + ...result, + timeline, + }; + } else { + throw new Error(t("transcribeFailed")); + } + }; + + const alignText = async ( + originalText: string + ): Promise<{ + engine: string; + model: string; + transcript: string; + segmentTimeline: TimelineEntry[]; + }> => { + const caption = await parseText(originalText, { type: "srt" }); + + if (caption.cues.length > 0) { + // valid srt file + const segmentTimeline = caption.cues.map((cue) => { + return { + type: "segment", + text: cue.text, + startTime: cue.startTime, + endTime: cue.endTime, + timeline: [], + } as TimelineEntry; + }); + + return { + engine: "upload", + model: "-", + transcript: segmentTimeline.map((entry) => entry.text).join(" "), + segmentTimeline, + }; + } else { + // Remove all content inside `()`, `[]`, `{}` and trim the text + // remove all markdown formatting + let transcript = originalText + .replace(/\(.*?\)/g, "") + .replace(/\[.*?\]/g, "") + .replace(/\{.*?\}/g, "") + .replace(/[*_`]/g, "") + .trim(); + + // if the transcript does not contain any punctuation, use AI command to add punctuation + if (!transcript.match(punctuationsPattern)) { + try { + const punctuatedText = await punctuateText(transcript); + transcript = punctuatedText; + } catch (err) { + toast.error(err.message); + console.warn(err); + } + } + + return { + engine: "upload", + model: "-", + transcript, + segmentTimeline: [], + }; + } }; const transcribeByLocal = async ( url: string, - language?: string + options: { language: string; model?: string } ): Promise<{ engine: string; model: string; - text: string; - timeline: TimelineEntry[]; + transcript: string; + segmentTimeline: TimelineEntry[]; }> => { - const res = await EnjoyApp.whisper.transcribe( - { - file: url, - }, - { - language, - force: true, - extra: ["--prompt", `"Hello! Welcome to listen to this audio."`], - } - ); - - if (!res) { - throw new Error(t("whisperTranscribeFailed", { error: "" })); + let { language, model = whisperModel } = options || {}; + const languageCode = language.split("-")[0]; + if (model.match(/en/) && languageCode !== "en") { + model = model.replace(".en", ""); } - const timeline: TimelineEntry[] = res.transcription - .map((segment) => { - // ignore the word if it is empty or in the format of `[xxx]` or `(xxx)` - if ( - !segment.text.trim() || - segment.text.trim().match(/^[\[\(].+[\]\)]$/) - ) { - return null; - } + let res: RecognitionResult; + try { + res = await EnjoyApp.echogarden.recognize(url, { + engine: "whisper", + language: languageCode, + whisper: { + model, + }, + }); + } catch (err) { + throw new Error(t("whisperTranscribeFailed", { error: err.message })); + } - return { - type: "segment" as TimelineEntryType, - text: segment.text.trim(), - startTime: segment.offsets.from / 1000.0, - endTime: segment.offsets.to / 1000.0, - }; - }) - .filter((s) => Boolean(s?.text)); - - const transcript = timeline - .map((segment) => segment.text) - .join(" ") - .trim(); + const { transcript, timeline } = res; return { engine: "whisper", - model: res.model.type, - text: transcript, - timeline, + model, + transcript, + segmentTimeline: timeline, }; }; - const transcribeByOpenAi = async (file: File) => { + const transcribeByOpenAi = async ( + file: File + ): Promise<{ + engine: string; + model: string; + transcript: string; + segmentTimeline: TimelineEntry[]; + }> => { if (!openai?.key) { throw new Error(t("openaiKeyRequired")); } @@ -270,38 +275,39 @@ export const useTranscribe = () => { maxRetries: 0, }); - const res: { - text: string; - words?: { word: string; start: number; end: number }[]; - segments?: { text: string; start: number; end: number }[]; - } = (await client.audio.transcriptions.create({ - file, - model: "whisper-1", - response_format: "verbose_json", - timestamp_granularities: ["segment"], - })) as any; + setOutput("Transcribing from OpenAI..."); + try { + const res: { + text: string; + words?: { word: string; start: number; end: number }[]; + segments?: { text: string; start: number; end: number }[]; + } = (await client.audio.transcriptions.create({ + file, + model: "whisper-1", + response_format: "verbose_json", + timestamp_granularities: ["word", "segment"], + })) as any; - let timeline: TimelineEntry[] = []; - if (res.segments) { - res.segments.forEach((segment) => { - const segmentTimeline = { + setOutput("Aligning the transcript..."); + const segmentTimeline = (res.segments || []).map((segment) => { + return { type: "segment" as TimelineEntryType, text: segment.text, startTime: segment.start, endTime: segment.end, - timeline: [] as Timeline, + timeline: [] as TimelineEntry[], }; - - timeline.push(segmentTimeline); }); - } - return { - engine: "openai", - model: "whisper-1", - text: res.text, - timeline, - }; + return { + engine: "openai", + model: "whisper-1", + transcript: res.text, + segmentTimeline, + }; + } catch (err) { + throw new Error(t("openaiTranscribeFailed", { error: err.message })); + } }; const transcribeByCloudflareAi = async ( @@ -309,35 +315,47 @@ export const useTranscribe = () => { ): Promise<{ engine: string; model: string; - text: string; - timeline?: TimelineEntry[]; + transcript: string; + segmentTimeline: TimelineEntry[]; }> => { - const res: CfWhipserOutputType = ( - await axios.postForm(`${AI_WORKER_ENDPOINT}/audio/transcriptions`, blob, { - headers: { - Authorization: `Bearer ${user.accessToken}`, - }, - timeout: 1000 * 60 * 5, - }) - ).data; + setOutput("Transcribing from Cloudflare..."); + try { + const res: CfWhipserOutputType = ( + await axios.postForm( + `${AI_WORKER_ENDPOINT}/audio/transcriptions`, + blob, + { + headers: { + Authorization: `Bearer ${user.accessToken}`, + }, + timeout: 1000 * 60 * 5, + } + ) + ).data; + + const segmentTimeline: TimelineEntry[] = []; + if (res.vtt) { + const caption = await parseText(res.vtt, { type: "vtt" }); + for (const cue of caption.cues) { + segmentTimeline.push({ + type: "segment", + text: cue.text, + startTime: cue.startTime, + endTime: cue.endTime, + timeline: [], + }); + } + } - const caption = await parseText(res.vtt, { type: "vtt" }); - const timeline: Timeline = caption.cues.map((cue) => { return { - type: "segment", - text: cue.text, - startTime: cue.startTime, - endTime: cue.endTime, - timeline: [], + engine: "cloudflare", + model: "@cf/openai/whisper", + transcript: res.text, + segmentTimeline, }; - }); - - return { - engine: "cloudflare", - model: "@cf/openai/whisper", - text: res.text, - timeline, - }; + } catch (err) { + throw new Error(t("cloudflareTranscribeFailed", { error: err.message })); + } }; const transcribeByAzureAi = async ( @@ -350,9 +368,9 @@ export const useTranscribe = () => { ): Promise<{ engine: string; model: string; - text: string; + transcript: string; + segmentTimeline: TimelineEntry[]; tokenId: number; - timeline?: TimelineEntry[]; }> => { const { id, token, region } = await webApi.generateSpeechToken({ ...params, @@ -364,90 +382,92 @@ export const useTranscribe = () => { config.speechRecognitionLanguage = language; config.requestWordLevelTimestamps(); config.outputFormat = sdk.OutputFormat.Detailed; + config.setProfanity(sdk.ProfanityOption.Raw); // create the speech recognizer. const reco = new sdk.SpeechRecognizer(config, audioConfig); + setOutput("Transcribing from Azure..."); let results: SpeechRecognitionResultType[] = []; - const res: { - engine: string; - model: string; - text: string; - tokenId: number; - timeline?: TimelineEntry[]; - } = await new Promise((resolve, reject) => { - reco.recognizing = (_s, e) => { - setOutput(e.result.text); - }; + const { transcript, segmentTimeline }: any = await new Promise( + (resolve, reject) => { + reco.recognizing = (_s, e) => { + setOutput((prev) => prev + e.result.text); + }; - reco.recognized = (_s, e) => { - const json = e.result.properties.getProperty( - sdk.PropertyId.SpeechServiceResponse_JsonResult - ); - const result = JSON.parse(json); - results = results.concat(result); - }; + reco.recognized = (_s, e) => { + const json = e.result.properties.getProperty( + sdk.PropertyId.SpeechServiceResponse_JsonResult + ); + const result = JSON.parse(json); + results = results.concat(result); + }; - reco.canceled = (_s, e) => { - if (e.reason === sdk.CancellationReason.Error) { - return reject(new Error(e.errorDetails)); - } + reco.canceled = (_s, e) => { + if (e.reason === sdk.CancellationReason.Error) { + return reject(new Error(e.errorDetails)); + } - reco.stopContinuousRecognitionAsync(); - console.log("CANCELED: Reason=" + e.reason); - }; + reco.stopContinuousRecognitionAsync(); + console.log("CANCELED: Reason=" + e.reason); + }; - reco.sessionStopped = async (_s, e) => { - console.log( - "Session stopped. Stop continuous recognition.", - e.sessionId, - results - ); - reco.stopContinuousRecognitionAsync(); + reco.sessionStopped = async (_s, e) => { + console.log( + "Session stopped. Stop continuous recognition.", + e.sessionId, + results + ); + reco.stopContinuousRecognitionAsync(); - try { - const timeline: Timeline = []; - results.forEach((result) => { - if (!result.DisplayText) return; + if (results.length === 0) { + return reject(t("azureTranscribeFailed", { error: "" })); + } - const best = take(sortedUniqBy(result.NBest, "Confidence"), 1)[0]; - if (!best.Words) return; - if (!best.Confidence || best.Confidence < 0.5) return; + try { + const transcript = results + .map((result) => result.DisplayText) + .join(" "); + const segmentTimeline: TimelineEntry[] = []; + results.forEach((result) => { + if (!result.DisplayText) return; - const firstWord = best.Words[0]; - const lastWord = best.Words[best.Words.length - 1]; + const best = take(sortedUniqBy(result.NBest, "Confidence"), 1)[0]; + if (!best.Words) return; + if (!best.Confidence || best.Confidence < 0.5) return; - timeline.push({ - type: "sentence", - text: best.Display, - startTime: firstWord.Offset / 10000000.0, - endTime: (lastWord.Offset + lastWord.Duration) / 10000000.0, - timeline: [], + const firstWord = best.Words[0]; + const lastWord = best.Words[best.Words.length - 1]; + + segmentTimeline.push({ + type: "segment", + text: best.Display, + startTime: firstWord.Offset / 10000000.0, + endTime: (lastWord.Offset + lastWord.Duration) / 10000000.0, + timeline: [], + }); }); - }); - const transcript = timeline - .map((result) => result.text) - .join(" ") - .trim(); + resolve({ + transcript, + segmentTimeline, + }); + } catch (err) { + reject(t("azureTranscribeFailed", { error: err.message })); + } + }; + reco.startContinuousRecognitionAsync(); + } + ); - resolve({ - engine: "azure", - model: "whisper", - text: transcript, - timeline, - tokenId: id, - }); - } catch (err) { - reject(err); - } - }; - - reco.startContinuousRecognitionAsync(); - }); - - return res; + return { + engine: "azure", + model: "whisper", + transcript, + segmentTimeline, + tokenId: id, + }; }; return { diff --git a/enjoy/src/renderer/hooks/use-transcriptions.tsx b/enjoy/src/renderer/hooks/use-transcriptions.tsx index c7cae9be..94dad013 100644 --- a/enjoy/src/renderer/hooks/use-transcriptions.tsx +++ b/enjoy/src/renderer/hooks/use-transcriptions.tsx @@ -107,6 +107,7 @@ export const useTranscriptions = (media: AudioType | VideoType) => { const generateTranscription = async (params?: { originalText?: string; language?: string; + model?: string; service?: SttEngineOptionEnum | "upload"; isolate?: boolean; }) => { @@ -114,6 +115,7 @@ export const useTranscriptions = (media: AudioType | VideoType) => { originalText, language = learningLanguage, service = sttEngine, + model, isolate = false, } = params || {}; setService(service); @@ -279,26 +281,17 @@ export const useTranscriptions = (media: AudioType | VideoType) => { useEffect(() => { if (!transcribing) return; - if (service === "local") { - EnjoyApp.whisper.onProgress((_, p: number) => { - if (p > 100) p = 100; - setTranscribingProgress(p); - }); - } - EnjoyApp.app.onCmdOutput((_, output) => { setTranscribingOutput(output); }); return () => { - EnjoyApp.whisper.removeProgressListeners(); EnjoyApp.app.removeCmdOutputListeners(); setTranscribingOutput(null); }; }, [media, service, transcribing]); const abortGenerateTranscription = () => { - EnjoyApp.whisper.abort(); setTranscribing(false); }; diff --git a/enjoy/src/types/enjoy-app.d.ts b/enjoy/src/types/enjoy-app.d.ts index a20bf846..36077cc0 100644 --- a/enjoy/src/types/enjoy-app.d.ts +++ b/enjoy/src/types/enjoy-app.d.ts @@ -273,6 +273,10 @@ type EnjoyAppType = { ) => Promise; }; echogarden: { + recognize: ( + input: string, + options: RecognitionOptions + ) => Promise; align: ( input: string | Uint8Array, transcript: string, @@ -289,26 +293,7 @@ type EnjoyAppType = { language: string ) => Promise; transcode: (input: string) => Promise; - check: () => Promise; - }; - whisper: { - config: () => Promise; - check: () => Promise<{ success: boolean; log: string }>; - setModel: (model: string) => Promise; - transcribe: ( - params: { - file?: string; - blob?: { type: string; arrayBuffer: ArrayBuffer }; - }, - options?: { - language?: string; - force?: boolean; - extra?: string[]; - } - ) => Promise>; - onProgress: (callback: (event, progress: number) => void) => void; - abort: () => Promise; - removeProgressListeners: () => Promise; + check: (options?: any) => Promise<{ success: boolean; log: string }>; }; ffmpeg: { check: () => Promise; diff --git a/enjoy/src/types/index.d.ts b/enjoy/src/types/index.d.ts index 9080f147..57ce468d 100644 --- a/enjoy/src/types/index.d.ts +++ b/enjoy/src/types/index.d.ts @@ -235,3 +235,26 @@ type DictSettingType = { removing: string[]; mdicts: MDict[]; }; + +type TranscribeParamsType = { + mediaSrc: string | Blob; + params?: { + targetId?: string; + targetType?: string; + originalText?: string; + language: string; + service: SttEngineOptionEnum | "upload"; + isolate?: boolean; + align?: boolean; + }; +}; + +type TranscribeResultType = { + engine: string; + model: string; + transcript: string; + timeline: TimelineEntry[]; + originalText?: string; + tokenId?: number; + url: string; +}; diff --git a/enjoy/vite.main.config.ts b/enjoy/vite.main.config.ts index df684736..97bed2e3 100644 --- a/enjoy/vite.main.config.ts +++ b/enjoy/vite.main.config.ts @@ -44,16 +44,6 @@ export default defineConfig((env) => { pluginHotRestart("restart"), viteStaticCopy({ targets: [ - { - src: `lib/whisper.cpp/${ - process.env.PACKAGE_OS_ARCH || os.arch() - }/${os.platform()}/*`, - dest: "lib/whisper", - }, - { - src: `lib/whisper.cpp/models/*`, - dest: "lib/whisper/models", - }, { src: `lib/youtubedr/${ process.env.PACKAGE_OS_ARCH || os.arch() diff --git a/yarn.lock b/yarn.lock index bb67f2c4..018f9aa3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -32,10 +32,10 @@ __metadata: "@nuxtjs/seo": "npm:^2.0.0-rc.23" autoprefixer: "npm:^10.4.20" nuxt: "npm:^3.13.2" - nuxt-og-image: "npm:^3.0.4" + nuxt-og-image: "npm:^3.0.6" postcss: "npm:^8.4.47" sass: "npm:^1.79.5" - tailwindcss: "npm:^3.4.13" + tailwindcss: "npm:^3.4.14" vue: "npm:^3.5.12" vue-router: "npm:^4.4.5" languageName: unknown @@ -363,7 +363,7 @@ __metadata: languageName: node linkType: hard -"@aws-sdk/client-polly@npm:^3.669.0": +"@aws-sdk/client-polly@npm:^3.670.0": version: 3.670.0 resolution: "@aws-sdk/client-polly@npm:3.670.0" dependencies: @@ -622,9 +622,9 @@ __metadata: languageName: node linkType: hard -"@aws-sdk/client-transcribe-streaming@npm:^3.669.0": - version: 3.670.0 - resolution: "@aws-sdk/client-transcribe-streaming@npm:3.670.0" +"@aws-sdk/client-transcribe-streaming@npm:^3.672.0": + version: 3.672.0 + resolution: "@aws-sdk/client-transcribe-streaming@npm:3.672.0" dependencies: "@aws-crypto/sha256-browser": "npm:5.2.0" "@aws-crypto/sha256-js": "npm:5.2.0" @@ -674,7 +674,7 @@ __metadata: "@smithy/util-retry": "npm:^3.0.7" "@smithy/util-utf8": "npm:^3.0.0" tslib: "npm:^2.6.2" - checksum: 10c0/24952fb07dddce9c82cb8082fd7382681e980d01a072b64278aa9c4d3e20b36b345ad813d4f890d902b3f34fcfc9c252c37f7af165e0e2b40c520cb797ad1852 + checksum: 10c0/8470177bcf36036a76db9567e01d747c7db5e4977cacee4577a1864da64cda069fd96971829e61a28fb98da34e07de55c0cb6e202443edb7d1504e72dcc53782 languageName: node linkType: hard @@ -2380,17 +2380,6 @@ __metadata: languageName: node linkType: hard -"@es-joy/jsdoccomment@npm:^0.46.0": - version: 0.46.0 - resolution: "@es-joy/jsdoccomment@npm:0.46.0" - dependencies: - comment-parser: "npm:1.4.1" - esquery: "npm:^1.6.0" - jsdoc-type-pratt-parser: "npm:~4.0.0" - checksum: 10c0/a7a67936ebf6d9aaf74af018c3ac744769af3552b05ad9b88fca96b2ffdca16e724b0ff497f53634ec4cca81e98d8c471b6b6bde0fa5b725af4222ad9a0707f0 - languageName: node - linkType: hard - "@esbuild/aix-ppc64@npm:0.20.2": version: 0.20.2 resolution: "@esbuild/aix-ppc64@npm:0.20.2" @@ -3839,7 +3828,7 @@ __metadata: languageName: node linkType: hard -"@nuxt/devtools-kit@npm:1.6.0, @nuxt/devtools-kit@npm:^1.4.1, @nuxt/devtools-kit@npm:^1.4.2, @nuxt/devtools-kit@npm:^1.5.1": +"@nuxt/devtools-kit@npm:1.6.0, @nuxt/devtools-kit@npm:^1.4.1, @nuxt/devtools-kit@npm:^1.4.2, @nuxt/devtools-kit@npm:^1.5.1, @nuxt/devtools-kit@npm:^1.6.0": version: 1.6.0 resolution: "@nuxt/devtools-kit@npm:1.6.0" dependencies: @@ -4670,14 +4659,14 @@ __metadata: languageName: node linkType: hard -"@playwright/test@npm:^1.48.0": - version: 1.48.0 - resolution: "@playwright/test@npm:1.48.0" +"@playwright/test@npm:^1.48.1": + version: 1.48.1 + resolution: "@playwright/test@npm:1.48.1" dependencies: - playwright: "npm:1.48.0" + playwright: "npm:1.48.1" bin: playwright: cli.js - checksum: 10c0/595a5db99f75d1ffce11df2980ca0dd20523a0a20dd962e4256e536dddccca0138883ae2e9c82aaa26094fb50def73062badc14c001d20fb62a3a0b2df18ebca + checksum: 10c0/32cedc3b2d375cb8f4a830bc820d7726b0235be7a6202e1d6ee46e739b83666271c47c100c11311cf5a916468c18e6a4dc526accf9ef090786e7614c2633b2b8 languageName: node linkType: hard @@ -6056,10 +6045,10 @@ __metadata: languageName: node linkType: hard -"@rails/actioncable@npm:7.2.100": - version: 7.2.100 - resolution: "@rails/actioncable@npm:7.2.100" - checksum: 10c0/dc24fcd1a9f153e5550c691ddff5ea7d3d7589a3be39cf14991d3fb6e77d1d952dbf9ab5f00c9781f85e750d94f7a80ade3231ab284b008d6de311ed33b21cb9 +"@rails/actioncable@npm:7.2.101": + version: 7.2.101 + resolution: "@rails/actioncable@npm:7.2.101" + checksum: 10c0/5917b0d5c51c8b38bb33532fd6483937d6bf030add46e2bca0272aed5b5b43ceec881f8c89a9784b416cb86a2655198f0faac4929a67caa4812ed622346cdc90 languageName: node linkType: hard @@ -8048,7 +8037,7 @@ __metadata: languageName: node linkType: hard -"@unocss/preset-wind@npm:^0.63.1": +"@unocss/preset-wind@npm:^0.63.1, @unocss/preset-wind@npm:^0.63.4": version: 0.63.4 resolution: "@unocss/preset-wind@npm:0.63.4" dependencies: @@ -9748,12 +9737,12 @@ __metadata: languageName: node linkType: hard -"chart.js@npm:^4.4.4": - version: 4.4.4 - resolution: "chart.js@npm:4.4.4" +"chart.js@npm:^4.4.5": + version: 4.4.5 + resolution: "chart.js@npm:4.4.5" dependencies: "@kurkle/color": "npm:^0.3.0" - checksum: 10c0/9fa3206403a6103916f7762c2665d322c42b0cc07fba91526b1d033ddb887c1ba74b3ebc0bd0748a9e55abd1017f25fdb2292cdd6579d8c2d3bcb1c58f71281c + checksum: 10c0/042393725d1dc6a7fc55757abc5db88cbce0e84ad115c08d552af56d72197e806eb78e04253d12a5f0cf331a57a059dc30fc2dab9324e4f695806b3a087d1b0c languageName: node linkType: hard @@ -10251,13 +10240,6 @@ __metadata: languageName: node linkType: hard -"comment-parser@npm:1.4.1": - version: 1.4.1 - resolution: "comment-parser@npm:1.4.1" - checksum: 10c0/d6c4be3f5be058f98b24f2d557f745d8fe1cc9eb75bebbdccabd404a0e1ed41563171b16285f593011f8b6a5ec81f564fb1f2121418ac5cbf0f49255bf0840dd - languageName: node - linkType: hard - "commondir@npm:^1.0.1": version: 1.0.1 resolution: "commondir@npm:1.0.1" @@ -11725,12 +11707,12 @@ __metadata: languageName: node linkType: hard -"echogarden@npm:^1.8.1": - version: 1.8.1 - resolution: "echogarden@npm:1.8.1" +"echogarden@npm:^1.8.4": + version: 1.8.4 + resolution: "echogarden@npm:1.8.4" dependencies: - "@aws-sdk/client-polly": "npm:^3.669.0" - "@aws-sdk/client-transcribe-streaming": "npm:^3.669.0" + "@aws-sdk/client-polly": "npm:^3.670.0" + "@aws-sdk/client-transcribe-streaming": "npm:^3.672.0" "@echogarden/espeak-ng-emscripten": "npm:^0.1.2" "@echogarden/fasttext-wasm": "npm:^0.1.0" "@echogarden/flite-wasi": "npm:^0.1.1" @@ -11760,7 +11742,7 @@ __metadata: jsdom: "npm:^25.0.1" json5: "npm:^2.2.3" kuromoji: "npm:^0.1.2" - microsoft-cognitiveservices-speech-sdk: "npm:^1.40.0" + microsoft-cognitiveservices-speech-sdk: "npm:^1.41.0" moving-median: "npm:^1.0.0" msgpack-lite: "npm:^0.1.26" onnxruntime-node: "npm:^1.19.2" @@ -11975,16 +11957,16 @@ __metadata: languageName: node linkType: hard -"electron@npm:^32.2.0": - version: 32.2.0 - resolution: "electron@npm:32.2.0" +"electron@npm:^33.0.1": + version: 33.0.1 + resolution: "electron@npm:33.0.1" dependencies: "@electron/get": "npm:^2.0.0" "@types/node": "npm:^20.9.0" extract-zip: "npm:^2.0.1" bin: electron: cli.js - checksum: 10c0/28d988a9d05c89e93d70cc790bd53ecc97135cc3fa9efe3617f10b87cdf85ada468d383afd7858bcf8f064aa189ea7f8987e32c7ebaac70bda64e8f9f85621e8 + checksum: 10c0/cd5e5cef21df2d5e1ffc95cfe90397c984152db54ac1032deb97f1716f778d41f5fa7c83e42bc86ae3c509a6d45f578515a540b6453ebf818310cc303e95f083 languageName: node linkType: hard @@ -12123,7 +12105,7 @@ __metadata: "@langchain/core": "npm:^0.3.11" "@langchain/ollama": "npm:^0.1.0" "@mozilla/readability": "npm:^0.5.0" - "@playwright/test": "npm:^1.48.0" + "@playwright/test": "npm:^1.48.1" "@radix-ui/react-accordion": "npm:^1.2.1" "@radix-ui/react-alert-dialog": "npm:^1.1.2" "@radix-ui/react-aspect-ratio": "npm:^1.1.0" @@ -12148,7 +12130,7 @@ __metadata: "@radix-ui/react-toast": "npm:^1.2.2" "@radix-ui/react-toggle": "npm:^1.1.0" "@radix-ui/react-tooltip": "npm:^1.1.3" - "@rails/actioncable": "npm:7.2.100" + "@rails/actioncable": "npm:7.2.101" "@tailwindcss/typography": "npm:^0.5.15" "@types/ahoy.js": "npm:^0.4.2" "@types/autosize": "npm:^4.0.3" @@ -12180,7 +12162,7 @@ __metadata: axios: "npm:^1.7.7" camelcase: "npm:^8.0.0" camelcase-keys: "npm:^9.1.3" - chart.js: "npm:^4.4.4" + chart.js: "npm:^4.4.5" cheerio: "npm:^1.0.0" class-variance-authority: "npm:^0.7.0" clsx: "npm:^2.1.1" @@ -12192,8 +12174,8 @@ __metadata: dayjs: "npm:^1.11.13" decamelize: "npm:^6.0.0" decamelize-keys: "npm:^2.0.1" - echogarden: "npm:^1.8.1" - electron: "npm:^32.2.0" + echogarden: "npm:^1.8.4" + electron: "npm:^33.0.1" electron-context-menu: "npm:^4.0.4" electron-devtools-installer: "npm:^3.2.0" electron-log: "npm:^5.2.0" @@ -12215,9 +12197,9 @@ __metadata: langchain: "npm:^0.3.2" lodash: "npm:^4.17.21" lru-cache: "npm:^11.0.1" - lucide-react: "npm:^0.452.0" + lucide-react: "npm:^0.453.0" mark.js: "npm:^8.11.1" - microsoft-cognitiveservices-speech-sdk: "npm:^1.40.0" + microsoft-cognitiveservices-speech-sdk: "npm:^1.41.0" mustache: "npm:^4.2.0" next-themes: "npm:^0.3.0" octokit: "npm:^4.0.2" @@ -12251,10 +12233,10 @@ __metadata: tailwind-merge: "npm:^2.5.4" tailwind-scrollbar: "npm:^3.1.0" tailwind-scrollbar-hide: "npm:^1.1.7" - tailwindcss: "npm:^3.4.13" + tailwindcss: "npm:^3.4.14" tailwindcss-animate: "npm:^1.0.7" ts-node: "npm:^10.9.2" - tslib: "npm:^2.7.0" + tslib: "npm:^2.8.0" typescript: "npm:^5.6.3" umzug: "npm:^3.8.2" unzipper: "npm:^0.12.3" @@ -12925,7 +12907,7 @@ __metadata: languageName: node linkType: hard -"esquery@npm:^1.5.0, esquery@npm:^1.6.0": +"esquery@npm:^1.5.0": version: 1.6.0 resolution: "esquery@npm:1.6.0" dependencies: @@ -15652,13 +15634,6 @@ __metadata: languageName: node linkType: hard -"jsdoc-type-pratt-parser@npm:~4.0.0": - version: 4.0.0 - resolution: "jsdoc-type-pratt-parser@npm:4.0.0" - checksum: 10c0/b23ef7bbbe2f56d72630d1c5a233dc9fecaff399063d373c57bef136908c1b05e723dac107177303c03ccf8d75aa51507510b282aa567600477479c5ea0c36d1 - languageName: node - linkType: hard - "jsdom@npm:^25.0.1": version: 25.0.1 resolution: "jsdom@npm:25.0.1" @@ -16367,12 +16342,12 @@ __metadata: languageName: node linkType: hard -"lucide-react@npm:^0.452.0": - version: 0.452.0 - resolution: "lucide-react@npm:0.452.0" +"lucide-react@npm:^0.453.0": + version: 0.453.0 + resolution: "lucide-react@npm:0.453.0" peerDependencies: react: ^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0-rc - checksum: 10c0/8e8dd351dfcdffc65a3a76eba04f08310fb13c8fce3e2cab74246c72df45d37e467d521b36053298fb33fe1c5c4770be4750eefa224ae5cf6258dffe95c72da1 + checksum: 10c0/eec257049609eff958329943da6e129c8db1969aa789c6e922f59418d7b6f6913c22e971ac7885efc80eeb3e41e180c48afdee92738a63a31ef5a39836037b7e languageName: node linkType: hard @@ -16395,7 +16370,7 @@ __metadata: languageName: node linkType: hard -"magic-string@npm:^0.30.0, magic-string@npm:^0.30.10, magic-string@npm:^0.30.11, magic-string@npm:^0.30.3, magic-string@npm:^0.30.4, magic-string@npm:^0.30.8": +"magic-string@npm:^0.30.0, magic-string@npm:^0.30.10, magic-string@npm:^0.30.11, magic-string@npm:^0.30.12, magic-string@npm:^0.30.3, magic-string@npm:^0.30.4, magic-string@npm:^0.30.8": version: 0.30.12 resolution: "magic-string@npm:0.30.12" dependencies: @@ -17112,18 +17087,17 @@ __metadata: languageName: node linkType: hard -"microsoft-cognitiveservices-speech-sdk@npm:^1.40.0": - version: 1.40.0 - resolution: "microsoft-cognitiveservices-speech-sdk@npm:1.40.0" +"microsoft-cognitiveservices-speech-sdk@npm:^1.41.0": + version: 1.41.0 + resolution: "microsoft-cognitiveservices-speech-sdk@npm:1.41.0" dependencies: - "@es-joy/jsdoccomment": "npm:^0.46.0" "@types/webrtc": "npm:^0.0.37" agent-base: "npm:^6.0.1" bent: "npm:^7.3.12" https-proxy-agent: "npm:^4.0.0" uuid: "npm:^9.0.0" ws: "npm:^7.5.6" - checksum: 10c0/a4a1de9a045dd2235dd09ba0df1b1b0d0451fc52ad5cf4f53d2467017d041baef252e00002ff58d9ff8182b4976ce58e4d8b14b17e49c25ebc6de019eefb1be5 + checksum: 10c0/0335660d8ca809a5b0b86d7664482a169e660ef30d28665201b36d12ef0c9e4dbc18dcd9b4b5d9a473d3f29610b264afb367a1c7a6f78171c968a03f113c64e7 languageName: node linkType: hard @@ -18059,7 +18033,7 @@ __metadata: languageName: node linkType: hard -"nuxt-og-image@npm:^3.0.3, nuxt-og-image@npm:^3.0.4": +"nuxt-og-image@npm:^3.0.3": version: 3.0.4 resolution: "nuxt-og-image@npm:3.0.4" dependencies: @@ -18096,6 +18070,43 @@ __metadata: languageName: node linkType: hard +"nuxt-og-image@npm:^3.0.6": + version: 3.0.6 + resolution: "nuxt-og-image@npm:3.0.6" + dependencies: + "@nuxt/devtools-kit": "npm:^1.6.0" + "@nuxt/kit": "npm:^3.13.2" + "@resvg/resvg-js": "npm:^2.6.2" + "@resvg/resvg-wasm": "npm:^2.6.2" + "@unocss/core": "npm:^0.63.4" + "@unocss/preset-wind": "npm:^0.63.4" + chrome-launcher: "npm:^1.1.2" + defu: "npm:^6.1.4" + execa: "npm:^9.4.0" + image-size: "npm:^1.1.1" + magic-string: "npm:^0.30.12" + nuxt-site-config: "npm:^2.2.18" + nuxt-site-config-kit: "npm:^2.2.18" + nypm: "npm:^0.3.12" + ofetch: "npm:^1.4.1" + ohash: "npm:^1.1.4" + pathe: "npm:^1.1.2" + pkg-types: "npm:^1.2.1" + playwright-core: "npm:^1.48.0" + radix3: "npm:^1.1.2" + satori: "npm:^0.11.2" + satori-html: "npm:^0.3.2" + sirv: "npm:^3.0.0" + std-env: "npm:^3.7.0" + strip-literal: "npm:^2.1.0" + ufo: "npm:^1.5.4" + unplugin: "npm:^1.14.1" + unwasm: "npm:^0.3.9" + yoga-wasm-web: "npm:^0.3.3" + checksum: 10c0/d71168e687ab973ab79e3a879520d86155689e5400f1e5241013354a48d3ef5ee2b815412e4fa0cbe129121f22a6d1cf4fb668166d276743f87a9dfc47ee18ca + languageName: node + linkType: hard + "nuxt-schema-org@npm:^3.4.0": version: 3.4.0 resolution: "nuxt-schema-org@npm:3.4.0" @@ -18358,7 +18369,7 @@ __metadata: languageName: node linkType: hard -"ofetch@npm:^1.3.4, ofetch@npm:^1.4.0": +"ofetch@npm:^1.3.4, ofetch@npm:^1.4.0, ofetch@npm:^1.4.1": version: 1.4.1 resolution: "ofetch@npm:1.4.1" dependencies: @@ -19137,7 +19148,7 @@ __metadata: languageName: node linkType: hard -"pkg-types@npm:^1.0.3, pkg-types@npm:^1.1.1, pkg-types@npm:^1.2.0": +"pkg-types@npm:^1.0.3, pkg-types@npm:^1.1.1, pkg-types@npm:^1.2.0, pkg-types@npm:^1.2.1": version: 1.2.1 resolution: "pkg-types@npm:1.2.1" dependencies: @@ -19148,7 +19159,16 @@ __metadata: languageName: node linkType: hard -"playwright-core@npm:1.48.0, playwright-core@npm:^1.47.2": +"playwright-core@npm:1.48.1, playwright-core@npm:^1.48.0": + version: 1.48.1 + resolution: "playwright-core@npm:1.48.1" + bin: + playwright-core: cli.js + checksum: 10c0/2f75532b9b7dfa0e586f5660ac1d8ea729bbdbd28dd2c0711e7cfc1adfe5cf7448d7f15a018ec9851a8f50c0743c3990cb9df23064bed603627baeac4dce3915 + languageName: node + linkType: hard + +"playwright-core@npm:^1.47.2": version: 1.48.0 resolution: "playwright-core@npm:1.48.0" bin: @@ -19157,18 +19177,18 @@ __metadata: languageName: node linkType: hard -"playwright@npm:1.48.0": - version: 1.48.0 - resolution: "playwright@npm:1.48.0" +"playwright@npm:1.48.1": + version: 1.48.1 + resolution: "playwright@npm:1.48.1" dependencies: fsevents: "npm:2.3.2" - playwright-core: "npm:1.48.0" + playwright-core: "npm:1.48.1" dependenciesMeta: fsevents: optional: true bin: playwright: cli.js - checksum: 10c0/4c18b06f7de87442c697c5df5b1e1e4027c5d9f05c07679aca0135dcf18e69a4424ec7a0c91a323487ea4a92fb5146fc80c04329479426722ea742fd5eff1afa + checksum: 10c0/96280ae656226e52015c0c69c4c19e9f594c19353a79012a19bd7b7175d7b409c1aed289a629df49ef897a57ccd24668ad15b86c283db10f76212a4db90a94ac languageName: node linkType: hard @@ -20886,7 +20906,7 @@ __metadata: languageName: node linkType: hard -"satori@npm:^0.11.1": +"satori@npm:^0.11.1, satori@npm:^0.11.2": version: 0.11.2 resolution: "satori@npm:0.11.2" dependencies: @@ -21274,6 +21294,17 @@ __metadata: languageName: node linkType: hard +"sirv@npm:^3.0.0": + version: 3.0.0 + resolution: "sirv@npm:3.0.0" + dependencies: + "@polka/url": "npm:^1.0.0-next.24" + mrmime: "npm:^2.0.0" + totalist: "npm:^3.0.0" + checksum: 10c0/282c52ee5a93cafa297096ad31aa6c3004a21d4c93abe728b701e51e4329acb887f6e92f07696225414fd6bb4a7782fd64a42d0b6b6467ae0f66bd3fde90b865 + languageName: node + linkType: hard + "sisteransi@npm:^1.0.5": version: 1.0.5 resolution: "sisteransi@npm:1.0.5" @@ -22014,9 +22045,9 @@ __metadata: languageName: node linkType: hard -"tailwindcss@npm:^3.4.13": - version: 3.4.13 - resolution: "tailwindcss@npm:3.4.13" +"tailwindcss@npm:^3.4.14": + version: 3.4.14 + resolution: "tailwindcss@npm:3.4.14" dependencies: "@alloc/quick-lru": "npm:^5.2.0" arg: "npm:^5.0.2" @@ -22043,7 +22074,7 @@ __metadata: bin: tailwind: lib/cli.js tailwindcss: lib/cli.js - checksum: 10c0/c6525be3dd26febc4ec5e45e80596bff8b48ade7de258c1ec8704297bf47c1ec7b2b186b13662ebaa6ab4795ad8879fb64064f796756bfc8b46558b542b01a6c + checksum: 10c0/f6c23f8a3293ce3b2511bca1e50008ac94bd8562cb09fec32fe4f8e8a4f54d9e9fc10e567b7f974abdd4b33e550564a2616d4e793c736955432f28448141ce45 languageName: node linkType: hard @@ -22457,13 +22488,20 @@ __metadata: languageName: node linkType: hard -"tslib@npm:^2.0.0, tslib@npm:^2.0.1, tslib@npm:^2.1.0, tslib@npm:^2.2.0, tslib@npm:^2.6.2, tslib@npm:^2.7.0": +"tslib@npm:^2.0.0, tslib@npm:^2.0.1, tslib@npm:^2.1.0, tslib@npm:^2.2.0, tslib@npm:^2.6.2": version: 2.7.0 resolution: "tslib@npm:2.7.0" checksum: 10c0/469e1d5bf1af585742128827000711efa61010b699cb040ab1800bcd3ccdd37f63ec30642c9e07c4439c1db6e46345582614275daca3e0f4abae29b0083f04a6 languageName: node linkType: hard +"tslib@npm:^2.8.0": + version: 2.8.0 + resolution: "tslib@npm:2.8.0" + checksum: 10c0/31e4d14dc1355e9b89e4d3c893a18abb7f90b6886b089c2da91224d0a7752c79f3ddc41bc1aa0a588ac895bd97bb99c5bc2bfdb2f86de849f31caeb3ba79bbe5 + languageName: node + linkType: hard + "tunnel-agent@npm:^0.6.0": version: 0.6.0 resolution: "tunnel-agent@npm:0.6.0"