Skip to content

Commit 1ad05bb

Browse files
committed
Add LLVM patch to fix AVX512 code generation problem
1 parent 8f81490 commit 1ad05bb

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

third_party/xla/llvm_fix.patch

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
diff --git a/third_party/llvm/llvm_jax_fix.patch b/third_party/llvm/llvm_jax_fix.patch
2+
new file mode 100644
3+
index 0000000000..5a2a60205e
4+
--- /dev/null
5+
+++ b/third_party/llvm/llvm_jax_fix.patch
6+
@@ -0,0 +1,14 @@
7+
+diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
8+
+index 96be91256915d..8bcd8670879a9 100644
9+
+--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
10+
++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
11+
+@@ -59383,7 +59383,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
12+
+
13+
+ // We can always convert per-lane vXf64 shuffles into VSHUFPD.
14+
+ if (!IsSplat &&
15+
+- (VT == MVT::v4f64 || (VT == MVT::v8f64 && Subtarget.useAVX512Regs())) &&
16+
++ ((NumOps == 2 && VT == MVT::v4f64) ||
17+
++ (NumOps == 4 && VT == MVT::v8f64 && Subtarget.useAVX512Regs())) &&
18+
+ all_of(Ops, [](SDValue Op) { return Op.hasOneUse(); })) {
19+
+ // Collect the individual per-lane v2f64/v4f64 shuffles.
20+
+ MVT OpVT = Ops[0].getSimpleValueType();
21+
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
22+
index ae0c1b550f..ce408f554a 100644
23+
--- a/third_party/llvm/workspace.bzl
24+
+++ b/third_party/llvm/workspace.bzl
25+
@@ -22,6 +22,7 @@ def repo(name):
26+
"//third_party/llvm:mathextras.patch",
27+
"//third_party/llvm:toolchains.patch",
28+
"//third_party/llvm:zstd.patch",
29+
+ "//third_party/llvm:llvm_jax_fix.patch",
30+
],
31+
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
32+
)

third_party/xla/workspace.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def repo():
3030
sha256 = XLA_SHA256,
3131
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
3232
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
33+
patch_file = ["//third_party/xla:llvm_fix.patch"],
3334
)
3435

3536
# For development, one often wants to make changes to the TF repository as well

0 commit comments

Comments
 (0)