Skip to content

Commit 1e52e15

Browse files
committed
fix bug
1 parent 6a27bba commit 1e52e15

File tree

4 files changed

+17
-3
lines changed

4 files changed

+17
-3
lines changed

ctmd/linalg/matmul.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ inline constexpr void matmul_impl(const in1_t &in1, const in2_t &in2,
6868
const auto ein2 =
6969
core::eigen::to_eigen(in2).template cast<common_t>();
7070
auto eout = core::eigen::to_eigen(out);
71-
eout = ein1 * ein2;
71+
eout = (ein1 * ein2).template cast<typename out_t::value_type>();
7272
return;
7373
}
7474
}

ctmd/linalg/matvec.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ inline constexpr void matvec_impl(const in1_t &in1, const in2_t &in2,
6565
const auto ein2 =
6666
core::eigen::to_eigen(in2).template cast<common_t>();
6767
auto eout = core::eigen::to_eigen(out);
68-
eout = ein1 * ein2;
68+
eout = (ein1 * ein2).template cast<typename out_t::value_type>();
6969
return;
7070
}
7171
}

ctmd/linalg/vecmat.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ inline constexpr void vecmat_impl(const in1_t &in1, const in2_t &in2,
6565
const auto ein2 =
6666
core::eigen::to_eigen(in2).template cast<common_t>();
6767
auto eout = core::eigen::to_eigen(out);
68-
eout = ein1 * ein2;
68+
eout = (ein1 * ein2).template cast<typename out_t::value_type>();
6969
return;
7070
}
7171
}

tests/linalg/matmul/main.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,17 @@ TEST(heap, matmul) {
3838

3939
ASSERT_TRUE(allclose);
4040
}
41+
42+
TEST(test, mixed) {
43+
using T1 = int;
44+
using T2 = float;
45+
using T3 = double;
46+
47+
const auto a = md::mdarray<T1, md::dims<2>>{std::vector<T1>{1, 2, 3, 4},
48+
md::dims<2>{2, 2}};
49+
const auto b = md::mdarray<T2, md::dims<2>>{std::vector<T2>{5, 6, 7, 8},
50+
md::dims<2>{2, 2}};
51+
auto c = md::mdarray<T3, md::dims<2>>{md::dims<2>{2, 2}};
52+
53+
md::linalg::matmul(a, b, c);
54+
}

0 commit comments

Comments
 (0)