Merged
29 commits
76227e2  Initial commit for vectorized BF16 GEMV. Added GEMM_GEMV_FORWARD_BF16… (Sep 6, 2024)
8541b25  Special case beta is one. (Sep 6, 2024)
39fd29f  Minor improvement and turn off BF16 GEMV forwarding by default. (Sep 8, 2024)
2f142ee  More common code. (Sep 9, 2024)
72216d2  Fix bug with inc_y adding results twice. (Sep 11, 2024)
7947970  Move common code. (Sep 13, 2024)
89a12fa  MMA BF16 GEMV code. (Sep 23, 2024)
c9ce37d  Force vector pairs in clang. (Sep 23, 2024)
05aa63e  More MMA BF16 GEMV code. (Sep 24, 2024)
df19375  Almost final code for MMA. (Sep 24, 2024)
8ab6245  Small change. (Sep 24, 2024)
fb287d1  Common code. (Sep 25, 2024)
eb6f3a0  Common MMA code. (Sep 26, 2024)
d7c0d87  Small changes. (Sep 26, 2024)
c878820  Fixing block issue with transpose version. (Sep 27, 2024)
32095b0  Remove parameter. (Oct 1, 2024)
e238a68  Remove duplicate. (Oct 1, 2024)
7cc00f6  Remove more duplicate. (Oct 1, 2024)
7ec3c16  Remove beta from optimized functions. (Oct 3, 2024)
915a6d6  Add casting. (Oct 3, 2024)
9ac0fb0  Merge branch 'develop' into vectorizeBF16GEMV (Oct 4, 2024)
d6bb8dc  Common code. (Oct 6, 2024)
c8f53b8  Merge remote-tracking branch 'origin/develop' into vectorizeBF16GEMV (Oct 11, 2024)
0082240  Merge branch 'thread_sbgemv' into vectorizeBF16GEMV (Oct 11, 2024)
a53a197  Merge remote-tracking branch 'origin/develop' into vectorizeBF16GEMV (Oct 12, 2024)
f8e113f  Replace types with include file. (Oct 13, 2024)
36bd3ee  Vectorize BF16 GEMV (VSX & MMA). Use GEMM_GEMV_FORWARD_BF16 (for Pow… (Oct 13, 2024)
2391dc1  Merge branch 'vectorizeBF16GEMV' of github.ibm.com:PowerAppLibs/OpenB… (Oct 13, 2024)
ab71a1e  Better VSX. (Oct 17, 2024)
Makefile.system: 5 additions & 1 deletion
@@ -282,15 +282,19 @@ GEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), power)
 GEMM_GEMV_FORWARD = 1
+GEMM_GEMV_FORWARD_BF16 = 1
 endif
 
 ifeq ($(SMALL_MATRIX_OPT), 1)
 CCOMMON_OPT += -DSMALL_MATRIX_OPT
 endif
-ifeq ($(GEMM_GEMV_FORWARD), 1)
 ifneq ($(ONLY_CBLAS), 1)
+ifeq ($(GEMM_GEMV_FORWARD), 1)
 CCOMMON_OPT += -DGEMM_GEMV_FORWARD
 endif
+ifeq ($(GEMM_GEMV_FORWARD_BF16), 1)
+CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16
+endif
 endif
 
 # This operation is expensive, so execution should be once.
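With this change, POWER builds default GEMM_GEMV_FORWARD_BF16 to 1, and the corresponding -DGEMM_GEMV_FORWARD_BF16 compiler define is only emitted when ONLY_CBLAS is not set. Assuming it behaves like the other Makefile.system switches, the forwarding can presumably be disabled at build time with something like make GEMM_GEMV_FORWARD_BF16=0.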
cmake/system.cmake: 3 additions & 0 deletions
@@ -398,6 +398,9 @@ endif ()
 if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
 set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
 endif ()
+if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS)
+set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16")
+endif ()
 if (SMALL_MATRIX_OPT)
 set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
 endif ()
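The CMake build mirrors the Makefile logic: the define is only added when GEMM_GEMV_FORWARD_BF16 is enabled and ONLY_CBLAS is not. When configuring with CMake, passing -DGEMM_GEMV_FORWARD_BF16=ON (or OFF) on the cmake command line should control it, assuming no other cache logic overrides the value.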
interface/gemm.c: 1 addition & 1 deletion
@@ -498,7 +498,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
 args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
 #endif
 
-#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(BFLOAT16)
+#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
 // Check if we can convert GEMM -> GEMV
 if (args.k != 0) {
 if (args.n == 1) {
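To illustrate the forwarding path, here is a minimal, hypothetical test program (not part of the PR): it issues an SBGEMM call whose result matrix has a single column, which the relaxed condition above should now route to the new BF16 GEMV kernels on POWER. It assumes an OpenBLAS build from this branch with GEMM_GEMV_FORWARD_BF16 enabled, uses the existing cblas_sbstobf16 and cblas_sbgemm interfaces, and links with -lopenblas.

#include <stdio.h>
#include <cblas.h>

int main(void) {
  enum { M = 4, K = 8 };
  float a_f32[M * K], x_f32[K], y[M];
  bfloat16 a[M * K], x[K];

  for (int i = 0; i < M * K; i++) a_f32[i] = 0.25f * (float)(i % 5);
  for (int i = 0; i < K; i++) x_f32[i] = 1.0f;
  for (int i = 0; i < M; i++) y[i] = 0.0f;

  /* Truncate the float inputs to bfloat16. */
  cblas_sbstobf16(M * K, a_f32, 1, a, 1);
  cblas_sbstobf16(K, x_f32, 1, x, 1);

  /* Column-major GEMM with n == 1: effectively y = 1.0 * A * x + 0.0 * y,
     so the interface can forward it to SBGEMV. */
  cblas_sbgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
               M, 1, K, 1.0f, a, M, x, K, 0.0f, y, M);

  for (int i = 0; i < M; i++) printf("y[%d] = %f\n", i, y[i]);
  return 0;
}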
kernel/power/KERNEL.POWER10: 2 additions & 0 deletions
@@ -236,11 +236,13 @@ ZSWAPKERNEL = zswap.c
 #
 
 SGEMVNKERNEL = sgemv_n.c
+SBGEMVNKERNEL = sbgemv_n_power10.c
 DGEMVNKERNEL = dgemv_n_power10.c
 CGEMVNKERNEL = cgemv_n.c
 ZGEMVNKERNEL = zgemv_n_power10.c
 #
 SGEMVTKERNEL = sgemv_t.c
+SBGEMVTKERNEL = sbgemv_t_power10.c
 DGEMVTKERNEL = dgemv_t_power10.c
 CGEMVTKERNEL = cgemv_t.c
 ZGEMVTKERNEL = zgemv_t_4.c
kernel/power/KERNEL.POWER8: 2 additions & 0 deletions
@@ -257,11 +257,13 @@ ZSWAPKERNEL = zswap.c
 #
 
 SGEMVNKERNEL = sgemv_n.c
+SBGEMVNKERNEL = sbgemv_n_vsx.c
 DGEMVNKERNEL = dgemv_n.c
 CGEMVNKERNEL = cgemv_n.c
 ZGEMVNKERNEL = zgemv_n_4.c
 #
 SGEMVTKERNEL = sgemv_t.c
+SBGEMVTKERNEL = sbgemv_t_vsx.c
 DGEMVTKERNEL = dgemv_t.c
 CGEMVTKERNEL = cgemv_t.c
 ZGEMVTKERNEL = zgemv_t_4.c
kernel/power/KERNEL.POWER9: 2 additions & 0 deletions
@@ -181,11 +181,13 @@ ZSWAPKERNEL = zswap.c
 #
 
 SGEMVNKERNEL = sgemv_n.c
+SBGEMVNKERNEL = sbgemv_n_vsx.c
 DGEMVNKERNEL = dgemv_n.c
 CGEMVNKERNEL = cgemv_n.c
 ZGEMVNKERNEL = zgemv_n_4.c
 #
 SGEMVTKERNEL = sgemv_t.c
+SBGEMVTKERNEL = sbgemv_t_vsx.c
 DGEMVTKERNEL = dgemv_t.c
 CGEMVTKERNEL = cgemv_t.c
 ZGEMVTKERNEL = zgemv_t_4.c
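Together, these KERNEL file entries register the new kernels with the build: SBGEMVNKERNEL provides the non-transposed BF16 GEMV and SBGEMVTKERNEL the transposed one, with POWER10 using the MMA-based sbgemv_*_power10.c sources and POWER8/POWER9 the VSX-based sbgemv_*_vsx.c sources.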
kernel/power/gemm_common.c: 153 additions & 0 deletions (new file)
@@ -0,0 +1,153 @@
#ifndef GEMM_COMMON_C
#define GEMM_COMMON_C
#include "common.h"

#include <altivec.h>
#include <inttypes.h>

#define NBMAX 4096

#define FORCEINLINE inline __attribute__((always_inline))

#ifdef _ARCH_PWR10
#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif
#endif

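/* Wrapper that hides the endian-dependent operand order of
 * __builtin_vsx_assemble_pair, so the same call site builds the same
 * in-memory vector pair on big- and little-endian targets. */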
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0)
#else
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1)
#endif

#define USE_VECTOR_PAIRS
#endif

typedef __vector IFLOAT vec_bf16;
typedef __vector FLOAT vec_f32;
typedef __vector unsigned char vec_uc8;

FORCEINLINE vec_uc8 vec_load_vec(void *src)
{
  return vec_xl(0, (unsigned char *)(src));
}

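/* Load or store two float vectors (32 bytes) as a unit. With vector-pair
 * support (POWER10) this uses lxvp/stxvp via builtins on clang or a
 * __vector_pair access on GCC; otherwise it falls back to two ordinary
 * vector copies. */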
FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
#ifdef __clang__
  vy0p = __builtin_vsx_lxvp(0L, (const __vector_pair *)(src));
#else
  vy0p = *(__vector_pair *)(src);
#endif
  __builtin_vsx_disassemble_pair((void *)(dst), &vy0p);
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
  __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]);
#ifdef __clang__
  __builtin_vsx_stxvp(vy0p, 0L, (__vector_pair *)(dst));
#else
  *(__vector_pair *)(dst) = vy0p;
#endif
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

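/* Partial load of the first n elements only. POWER9 and newer use the
 * length-controlled vec_xl_len; older targets build the vector through an
 * aligned scratch buffer so no memory past the n-th element is read. */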
FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n)
{
  IFLOAT *src2 = (IFLOAT *)(src);
#ifdef _ARCH_PWR9
  return vec_xl_len(src2, n * sizeof(IFLOAT));
#else
  __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)];
  memset(data, 0, sizeof(vec_bf16));
  if (n & 4) {
    memcpy(data, src2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(data + n4, src2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    data[n6] = src2[n6];
  }
  return (vec_bf16)vec_load_vec(data);
#endif
}

FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    return (vec_f32)vec_load_vec(src);
  }
#endif
  return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
{
  data[0] = src[0];
  data[1] = vec_loadN_f32(&src[1], n);
}

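/* Partial store counterpart: writes exactly the first n elements to dst.
 * POWER9 and newer use vec_xst_len; the fallback spills the vector to an
 * aligned scratch buffer and copies only the first n elements out. */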
FORCEINLINE void vec_storeN(vec_bf16 data, void *dst, BLASLONG n)
{
  IFLOAT *dst2 = (IFLOAT *)(dst);
#ifdef _ARCH_PWR9
  vec_xst_len(data, dst2, n * sizeof(IFLOAT));
#else
  if (n & 8) {
    vec_xst(data, 0, dst2);
    return;
  }
  __attribute__((aligned(16))) IFLOAT data2[sizeof(vec_f32) / sizeof(IFLOAT)];
  vec_xst(data, 0, data2);
  if (n & 4) {
    memcpy(dst2, data2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(dst2 + n4, data2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    dst2[n6] = data2[n6];
  }
#endif
}

FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    vec_xst(data, 0, (FLOAT *)dst);
    return;
  }
#endif
  return vec_storeN((vec_bf16)data, dst, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
{
  dst[0] = data[0];
  vec_storeN_f32(data[1], &dst[1], n);
}
#endif
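A minimal usage sketch (not part of the PR) of the partial load/store helpers above, assuming it is compiled in the same translation unit as this file on a VSX-capable POWER target; the hypothetical helper copies a tail of n < 8 bfloat16 elements without touching memory past the end of either buffer:

static FORCEINLINE void copy_bf16_tail(IFLOAT *dst, IFLOAT *src, BLASLONG n)
{
  vec_bf16 v = vec_loadN(src, n);  /* reads only the first n elements  */
  vec_storeN(v, dst, n);           /* writes only the first n elements */
}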