Jlm
CastingTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2026 Håvard Krogstie <krogstie.havard@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
14 #include <jlm/llvm/ir/print.hpp>
15 #include <jlm/llvm/ir/types.hpp>
17 #include <jlm/rvsdg/graph.hpp>
18 #include <jlm/rvsdg/MatchType.hpp>
20 #include <jlm/util/Statistics.hpp>
21 
22 #include <llvm/IR/DerivedTypes.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/LLVMContext.h>
25 
26 #include <algorithm>
27 
// Value-parameterized fixture. Tests read the int parameter via GetParam() as the
// vectorization width; 0 means plain scalar types are used.
28 class LlvmBackendCastingFixture : public testing::TestWithParam<int>
29 {
30 };
31 
33 {
34  /*
35  * Creates a function equivalent to the following C code in RVSDG:
36  *
37  * <SIZE x i8> f(uint32_t x) {
38  * uint64_t zext = x;
39  * uint32_t trunc = (uint32_t) zext;
40  * int64_t sext = (int32_t) trunc;
41  * void* inttoptr = (void*) sext;
42  * uint64_t ptrtoint = (uint64_t) inttoptr;
43  * <SIZE x i8> bitcast = (<SIZE x i8>) ptrtoint; // where SIZE = sizeof(ptrtoint);
44  * return bitcast;
45  * }
46  *
47  * The test is parameterized with a vectorization width.
48  * When it is non-zero, all scalar types are replaced with vectors of the given width.
49  * E.g., with a width of 4, all uint64_t values become <4 x uint64_t> instead.
50  *
51  * The test converts the above RVSDG to LLVM IR, checking that all casts have been converted,
52  * with the expected input and output types.
53  */
54 
55  // If 0, no vectorization is used
56  int vectorization = GetParam();
57 
58  // Arrange
59  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
60  auto & graph = rvsdgModule->Rvsdg();
61 
62  const auto bits32 = jlm::rvsdg::BitType::Create(32);
63  const auto bits64 = jlm::rvsdg::BitType::Create(64);
64  const auto pointerType = jlm::llvm::PointerType::Create();
65 
66  // Create the vectorized versions of the above types
67  std::shared_ptr<const jlm::rvsdg::Type> bits32V, bits64V, pointerTypeV;
68  if (vectorization == 0)
69  {
70  bits32V = bits32;
71  bits64V = bits64;
72  pointerTypeV = pointerType;
73  }
74  else
75  {
76  bits32V = jlm::llvm::FixedVectorType::Create(bits32, vectorization);
77  bits64V = jlm::llvm::FixedVectorType::Create(bits64, vectorization);
78  pointerTypeV = jlm::llvm::FixedVectorType::Create(pointerType, vectorization);
79  }
80 
81  // Create the byte vector type that ptrtoint is bitcast to: 8 bytes per (possibly vectorized) 64-bit element
82  const size_t sizeofX = 8 * std::max(1, vectorization);
83  const auto byteType = jlm::rvsdg::BitType::Create(8);
84  const auto byteVectorType = jlm::llvm::FixedVectorType::Create(byteType, sizeofX);
85 
86  auto functionType = jlm::rvsdg::FunctionType::Create({ bits32V }, { byteVectorType });
87  auto lambda = jlm::rvsdg::LambdaNode::Create(
88  graph.GetRootRegion(),
90  functionType,
91  "f",
93 
94  auto x = lambda->GetFunctionArguments()[0];
95 
96  // Helper function for wrapping operations in VectorUnaryOperation when vectorization is enabled
97  auto createUnary =
98  [&](const auto & op,
99  jlm::rvsdg::Output & operand,
100  const std::shared_ptr<const jlm::rvsdg::Type> & resultType) -> jlm::rvsdg::Output &
101  {
102  if (vectorization == 0)
103  {
104  return *jlm::rvsdg::CreateOpNode<std::decay_t<decltype(op)>>(
105  { &operand },
106  op.argument(0),
107  op.result(0))
108  .output(0);
109  }
110 
111  return *jlm::rvsdg::CreateOpNode<jlm::llvm::VectorUnaryOperation>(
112  { &operand },
113  op,
114  std::static_pointer_cast<const jlm::llvm::VectorType>(operand.Type()),
115  std::static_pointer_cast<const jlm::llvm::VectorType>(resultType))
116  .output(0);
117  };
118 
119  // Create the function body
120  auto & zext = createUnary(jlm::llvm::ZExtOperation(bits32, bits64), *x, bits64V);
121  auto & trunc = createUnary(jlm::llvm::TruncOperation(bits64, bits32), zext, bits32V);
122  auto & sext = createUnary(jlm::llvm::SExtOperation(bits32, bits64), trunc, bits64V);
123  auto & inttoptr =
124  createUnary(jlm::llvm::IntegerToPointerOperation(bits64, pointerType), sext, pointerTypeV);
125  auto & ptrtoint =
126  createUnary(jlm::llvm::PtrToIntOperation(pointerType, bits64), inttoptr, bits64V);
127  // Bitcasts are never wrapped in VectorUnaryOperation
128  auto & bitcast = *jlm::llvm::BitCastOperation::create(&ptrtoint, byteVectorType);
129 
130  auto lambdaOutput = lambda->finalize({ &bitcast });
131  jlm::rvsdg::GraphExport::Create(*lambdaOutput, "f");
132 
133  // Act
135  auto ipgModule =
137  llvm::LLVMContext context;
138  auto llvmModule = jlm::llvm::IpGraphToLlvmConverter::CreateAndConvertModule(*ipgModule, context);
139 
140  // Assert
141  {
142  llvm::Type * expectedBits32 = llvm::Type::getInt32Ty(context);
143  llvm::Type * expectedBits64 = llvm::Type::getInt64Ty(context);
144  llvm::Type * expectedPointer = llvm::PointerType::getUnqual(context);
145 
146  if (vectorization != 0)
147  {
148  auto ec = llvm::ElementCount::getFixed(vectorization);
149  expectedBits32 = llvm::VectorType::get(expectedBits32, ec);
150  expectedBits64 = llvm::VectorType::get(expectedBits64, ec);
151  expectedPointer = llvm::VectorType::get(expectedPointer, ec);
152  }
153 
154  auto byteEc = llvm::ElementCount::getFixed(sizeofX);
155  llvm::Type * expectedByteVector = llvm::VectorType::get(llvm::Type::getInt8Ty(context), byteEc);
156 
157  auto llvmFunction = llvmModule->getFunction("f");
158  ASSERT_NE(llvmFunction, nullptr);
159  EXPECT_EQ(llvmFunction->getReturnType(), expectedByteVector);
160  ASSERT_EQ(llvmFunction->arg_size(), 1u);
161  EXPECT_EQ(llvmFunction->arg_begin()->getType(), expectedBits32);
162 
163  size_t numZext = 0;
164  size_t numTrunc = 0;
165  size_t numSext = 0;
166  size_t numInttoptr = 0;
167  size_t numPtrtoint = 0;
168  size_t numBitcasts = 0;
169 
170  for (auto & basicBlock : *llvmFunction)
171  {
172  for (auto & instruction : basicBlock)
173  {
174  if (auto * zextInstruction = llvm::dyn_cast<llvm::ZExtInst>(&instruction))
175  {
176  numZext++;
177  EXPECT_EQ(zextInstruction->getSrcTy(), expectedBits32);
178  EXPECT_EQ(zextInstruction->getDestTy(), expectedBits64);
179  }
180  else if (auto * truncInstruction = llvm::dyn_cast<llvm::TruncInst>(&instruction))
181  {
182  numTrunc++;
183  EXPECT_EQ(truncInstruction->getSrcTy(), expectedBits64);
184  EXPECT_EQ(truncInstruction->getDestTy(), expectedBits32);
185  }
186  else if (auto * sextInstruction = llvm::dyn_cast<llvm::SExtInst>(&instruction))
187  {
188  numSext++;
189  EXPECT_EQ(sextInstruction->getSrcTy(), expectedBits32);
190  EXPECT_EQ(sextInstruction->getDestTy(), expectedBits64);
191  }
192  else if (auto * intToPtrInstruction = llvm::dyn_cast<llvm::IntToPtrInst>(&instruction))
193  {
194  numInttoptr++;
195  EXPECT_EQ(intToPtrInstruction->getSrcTy(), expectedBits64);
196  EXPECT_EQ(intToPtrInstruction->getDestTy(), expectedPointer);
197  }
198  else if (auto * ptrToIntInstruction = llvm::dyn_cast<llvm::PtrToIntInst>(&instruction))
199  {
200  numPtrtoint++;
201  EXPECT_EQ(ptrToIntInstruction->getSrcTy(), expectedPointer);
202  EXPECT_EQ(ptrToIntInstruction->getDestTy(), expectedBits64);
203  }
204  else if (auto * bitcastInstruction = llvm::dyn_cast<llvm::BitCastInst>(&instruction))
205  {
206  numBitcasts++;
207  EXPECT_EQ(bitcastInstruction->getSrcTy(), expectedBits64);
208  EXPECT_EQ(bitcastInstruction->getDestTy(), expectedByteVector);
209  }
210  }
211  }
212 
213  auto * returnInstruction =
214  llvm::dyn_cast<llvm::ReturnInst>(llvmFunction->back().getTerminator());
215  ASSERT_NE(returnInstruction, nullptr);
216  auto * bitcastInstruction =
217  llvm::dyn_cast<llvm::BitCastInst>(returnInstruction->getReturnValue());
218  EXPECT_NE(bitcastInstruction, nullptr);
219 
220  EXPECT_EQ(numZext, 1u);
221  EXPECT_EQ(numTrunc, 1u);
222  EXPECT_EQ(numSext, 1u);
223  EXPECT_EQ(numInttoptr, 1u);
224  EXPECT_EQ(numPtrtoint, 1u);
225  EXPECT_EQ(numBitcasts, 1u);
226  }
227 }
228 
230  LlvmBackendCastingTests,
232  testing::Values(0, 1, 2, 4, 8));
static jlm::util::StatisticsCollector statisticsCollector
TEST_P(LlvmBackendCastingFixture, AllIntegerCasts)
INSTANTIATE_TEST_SUITE_P(LlvmBackendCastingTests, LlvmBackendCastingFixture, testing::Values(0, 1, 2, 4, 8))
static std::unique_ptr< llvm::ThreeAddressCode > create(const Variable *operand, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1586
static std::shared_ptr< const FixedVectorType > Create(std::shared_ptr< const rvsdg::Type > type, size_t size)
Definition: types.hpp:413
static std::unique_ptr<::llvm::Module > CreateAndConvertModule(InterProceduralGraphModule &ipGraphModule, ::llvm::LLVMContext &ctx)
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
static std::unique_ptr< LlvmRvsdgModule > Create(const util::FilePath &sourceFileName, const std::string &targetTriple, const std::string &dataLayout)
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
static std::unique_ptr< InterProceduralGraphModule > CreateAndConvertModule(LlvmRvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector)
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
SimpleNode & CreateOpNode(const std::vector< Output * > &operands, OperatorArguments... operatorArguments)
Creates a simple node characterized by its operator.