Jlm
CastingTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2026 Håvard Krogstie <krogstie.havard@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
11 #include <jlm/llvm/ir/print.hpp>
12 #include <jlm/llvm/ir/types.hpp>
14 #include <jlm/rvsdg/MatchType.hpp>
15 
16 #include <llvm/IR/BasicBlock.h>
17 #include <llvm/IR/DerivedTypes.h>
18 #include <llvm/IR/IRBuilder.h>
19 #include <llvm/IR/Module.h>
20 
21 #include <algorithm>
22 
23 class LlvmFrontendCastingFixture : public testing::TestWithParam<int>
24 {
25 };
26 
28 {
29  /*
30  * Creates a function equivalent to the following C code in LLVM IR:
31  *
32  * <SIZE x i8> f(uint32_t x) {
33  * uint64_t zext = x;
34  * uint32_t trunc = (uint32_t) zext;
35  * int64_t sext = (int32_t) trunc;
36  * void* inttoptr = (void*) sext;
37  * uint64_t ptrtoint = (uint64_t) inttoptr;
38  * <SIZE x i8> bitcast = (<SIZE x i8>) ptrtoint; // where SIZE = sizeof(x);
39  * return bitcast;
40  * }
41  *
42  * The test is parameterized with a vectorization width.
43  * When it is non-zero, all scalar types are replaced with vectors of the given width.
44  * E.g., with a width of 4, all uint64_t values become <4 x uint64_t> instead.
45  */
46 
47  // If 0, no vectorization is used
48  int vectorization = GetParam();
49 
50  // Arrange
51  llvm::LLVMContext context;
52  llvm::Module llvmModule("module", context);
53 
54  llvm::Type * int64Type = llvm::Type::getInt64Ty(context);
55  llvm::Type * int32Type = llvm::Type::getInt32Ty(context);
56  llvm::Type * pointerType = llvm::PointerType::getUnqual(context);
57 
58  // If vectorization is enabled, make all the types vectors
59  if (vectorization != 0)
60  {
61  auto ec = llvm::ElementCount::getFixed(vectorization);
62  int64Type = llvm::VectorType::get(int64Type, ec);
63  int32Type = llvm::VectorType::get(int32Type, ec);
64  pointerType = llvm::VectorType::get(pointerType, ec);
65  }
66 
67  // Create the final <SIZE x i8> type
68  size_t sizeofX = 8 * std::max(1, vectorization);
69  auto ec = llvm::ElementCount::getFixed(sizeofX);
70  llvm::Type * byteVectorType = llvm::VectorType::get(llvm::Type::getInt8Ty(context), ec);
71 
72  auto functionType =
73  llvm::FunctionType::get(byteVectorType, llvm::ArrayRef<llvm::Type *>({ int32Type }), false);
74  auto function =
75  llvm::Function::Create(functionType, llvm::GlobalValue::ExternalLinkage, "f", &llvmModule);
76 
77  auto basicBlock = llvm::BasicBlock::Create(context, "bb0", function);
78 
79  llvm::IRBuilder<> builder(basicBlock);
80  auto zext = builder.CreateZExt(function->arg_begin(), int64Type);
81  auto trunc = builder.CreateTrunc(zext, int32Type);
82  auto sext = builder.CreateSExt(trunc, int64Type);
83  auto inttoptr = builder.CreateIntToPtr(sext, pointerType);
84  auto ptrtoint = builder.CreatePtrToInt(inttoptr, int64Type);
85  auto bitcast = builder.CreateBitCast(ptrtoint, byteVectorType);
86  builder.CreateRet(bitcast);
87 
88  llvmModule.print(llvm::errs(), nullptr);
89 
90  // Act
91  auto ipgModule = jlm::llvm::ConvertLlvmModule(llvmModule);
92  print(*ipgModule, stdout);
93 
94  // Assert
95  {
96  using namespace jlm::llvm;
97 
98  const auto jlmBits32 = jlm::rvsdg::BitType::Create(32);
99  const auto jlmBits64 = jlm::rvsdg::BitType::Create(64);
100  const auto jlmPointerType = jlm::llvm::PointerType::Create();
101  const auto jlmByteType = jlm::rvsdg::BitType::Create(8);
102  const auto jlmByteVectorType = jlm::llvm::FixedVectorType::Create(jlmByteType, sizeofX);
103 
104  auto controlFlowGraph =
105  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
106  auto basicBlock =
107  dynamic_cast<const jlm::llvm::BasicBlock *>(controlFlowGraph->entry()->OutEdge(0)->sink());
108 
109  size_t numUnaryVector = 0;
110  size_t numZext = 0;
111  size_t numTrunc = 0;
112  size_t numSext = 0;
113  size_t numInttoptr = 0;
114  size_t numPtrtoint = 0;
115  size_t numBitcasts = 0;
116  for (auto it = basicBlock->begin(); it != basicBlock->end(); it++)
117  {
118  auto op = &(*it)->operation();
119 
120  // If the operation is wrapped in a vector unary, unwrap it
121  if (auto vecOp = dynamic_cast<const VectorUnaryOperation *>(op))
122  {
123  numUnaryVector++;
124  op = &vecOp->operation();
125  }
126 
127  std::cout << op->debug_string() << std::endl;
129  *op,
130  [&]([[maybe_unused]] const UndefValueOperation & op)
131  {
132  // Ignore the undef operation created as a default return value
133  },
134  [&]([[maybe_unused]] const AssignmentOperation & op)
135  {
136  // Ignore the assignment to the dummy return variable
137  },
138  [&](const ZExtOperation & op)
139  {
140  numZext++;
141  EXPECT_EQ(*op.argument(0), *jlmBits32);
142  EXPECT_EQ(*op.result(0), *jlmBits64);
143  },
144  [&](const TruncOperation & op)
145  {
146  numTrunc++;
147  EXPECT_EQ(*op.argument(0), *jlmBits64);
148  EXPECT_EQ(*op.result(0), *jlmBits32);
149  },
150  [&](const SExtOperation & op)
151  {
152  numSext++;
153  EXPECT_EQ(*op.argument(0), *jlmBits32);
154  EXPECT_EQ(*op.result(0), *jlmBits64);
155  },
156  [&](const IntegerToPointerOperation & op)
157  {
158  numInttoptr++;
159  EXPECT_EQ(*op.argument(0), *jlmBits64);
160  EXPECT_EQ(*op.result(0), *jlmPointerType);
161  },
162  [&](const PtrToIntOperation & op)
163  {
164  numPtrtoint++;
165  EXPECT_EQ(*op.argument(0), *jlmPointerType);
166  EXPECT_EQ(*op.result(0), *jlmBits64);
167  },
168  [&](const BitCastOperation & op)
169  {
170  numBitcasts++;
171 
172  // BitCasts should never be wrapped in VectorUnaryOperation, so for vectorized
173  // instances of the test we expect the operation to take a vector type.
174  if (vectorization)
175  {
176  const auto jlmBits64Vector =
177  jlm::llvm::FixedVectorType::Create(jlmBits64, vectorization);
178  EXPECT_EQ(*op.argument(0), *jlmBits64Vector);
179  }
180  else
181  {
182  // For the non-vectorized instance of this test, the input is just a uint64
183  EXPECT_EQ(*op.argument(0), *jlmBits64);
184  }
185  EXPECT_EQ(*op.result(0), *jlmByteVectorType);
186  });
187  }
188 
189  EXPECT_EQ(numUnaryVector, vectorization ? 5 : 0u);
190  EXPECT_EQ(numZext, 1u);
191  EXPECT_EQ(numTrunc, 1u);
192  EXPECT_EQ(numSext, 1u);
193  EXPECT_EQ(numInttoptr, 1u);
194  EXPECT_EQ(numPtrtoint, 1u);
195  EXPECT_EQ(numBitcasts, 1u);
196  }
197 }
198 
200  LlvmFrontendCastingTests,
202  testing::Values(0, 1, 2, 4, 8));
TEST_P(LlvmBackendCastingFixture, AllIntegerCasts)
INSTANTIATE_TEST_SUITE_P(LlvmBackendCastingTests, LlvmBackendCastingFixture, testing::Values(0, 1, 2, 4, 8))
static std::shared_ptr< const FixedVectorType > Create(std::shared_ptr< const rvsdg::Type > type, size_t size)
Definition: types.hpp:413
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
UndefValueOperation class.
Definition: operators.hpp:1023
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
Global memory state passed between functions.
void print(const AggregationNode &n, const AnnotationMap &dm, FILE *out)
Definition: print.cpp:120
std::unique_ptr< InterProceduralGraphModule > ConvertLlvmModule(::llvm::Module &llvmModule)
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.