Jlm
LlvmModuleConversionTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2026 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
10 #include <jlm/llvm/ir/cfg.hpp>
12 #include <jlm/llvm/ir/ipgraph.hpp>
17 #include <jlm/llvm/ir/print.hpp>
19 
20 #include <llvm/IR/BasicBlock.h>
21 #include <llvm/IR/IRBuilder.h>
22 #include <llvm/IR/Module.h>
23 
24 TEST(LlvmModuleConversionTests, SwitchConversion)
25 {
26  using namespace llvm;
27 
28  // Arrange
29  LLVMContext context;
30  const std::unique_ptr<Module> llvmModule(new Module("module", context));
31 
32  auto int64Type = Type::getInt64Ty(context);
33 
34  auto functionType = FunctionType::get(int64Type, ArrayRef<Type *>({ int64Type }), false);
35  auto function =
36  Function::Create(functionType, GlobalValue::ExternalLinkage, "f", llvmModule.get());
37 
38  auto bbSplit = BasicBlock::Create(context, "BasicBlockSplit", function);
39  auto bb1 = BasicBlock::Create(context, "BasicBlock1", function);
40  auto bb2 = BasicBlock::Create(context, "BasicBlock2", function);
41  auto bb3 = BasicBlock::Create(context, "BasicBlock4", function);
42  auto bb4 = BasicBlock::Create(context, "BasicBlock4", function);
43  auto bbJoin = BasicBlock::Create(context, "BasicBlockJoin", function);
44 
45  IRBuilder builder(bbSplit);
46  auto switchInstruction = builder.CreateSwitch(function->arg_begin(), bb4);
47  switchInstruction->addCase(ConstantInt::get(int64Type, 1), bb1);
48  switchInstruction->addCase(ConstantInt::get(int64Type, 2), bb2);
49  switchInstruction->addCase(ConstantInt::get(int64Type, 3), bb2);
50  switchInstruction->addCase(ConstantInt::get(int64Type, 4), bb3);
51  switchInstruction->addCase(ConstantInt::get(int64Type, 5), bb3);
52 
53  builder.SetInsertPoint(bb1);
54  builder.CreateBr(bbJoin);
55 
56  builder.SetInsertPoint(bb2);
57  builder.CreateBr(bbJoin);
58 
59  builder.SetInsertPoint(bb3);
60  builder.CreateBr(bbJoin);
61 
62  builder.SetInsertPoint(bb4);
63  builder.CreateBr(bbJoin);
64 
65  builder.SetInsertPoint(bbJoin);
66  auto phiInstruction = builder.CreatePHI(int64Type, 4);
67  phiInstruction->addIncoming(ConstantInt::get(int64Type, 1), bb1);
68  phiInstruction->addIncoming(ConstantInt::get(int64Type, 2), bb2);
69  phiInstruction->addIncoming(ConstantInt::get(int64Type, 3), bb3);
70  phiInstruction->addIncoming(ConstantInt::get(int64Type, 4), bb4);
71  builder.CreateRet(phiInstruction);
72 
73  llvmModule->print(errs(), nullptr);
74 
75  // Act
76  auto ipgModule = jlm::llvm::ConvertLlvmModule(*llvmModule);
77  print(*ipgModule, stdout);
78 
79  // Assert
80  {
81  using namespace jlm::llvm;
82 
83  const auto controlFlowGraph =
84  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
85 
86  EXPECT_EQ(controlFlowGraph->nnodes(), 6u);
87 
88  // We expect the split node to only have 4 outgoing edges. One for each target basic block of
89  // the original LLVM switch statement
90  const auto splitNode = controlFlowGraph->entry()->OutEdge(0)->sink();
91  EXPECT_EQ(splitNode->NumOutEdges(), 4u);
92  }
93 }
94 
95 TEST(LlvmModuleConversionTests, FreezeConversion)
96 {
101  // Arrange
102  ::llvm::LLVMContext context;
103  ::llvm::Module llvmModule("module", context);
104 
105  auto int64Type = ::llvm::Type::getInt64Ty(context);
106  auto functionType =
107  ::llvm::FunctionType::get(int64Type, ::llvm::ArrayRef<::llvm::Type *>({ int64Type }), false);
108  auto function = ::llvm::Function::Create(
109  functionType,
110  ::llvm::GlobalValue::ExternalLinkage,
111  "f",
112  &llvmModule);
113 
114  auto basicBlock = ::llvm::BasicBlock::Create(context, "BasicBlock", function);
115 
116  ::llvm::IRBuilder<> builder(basicBlock);
117  auto freezeInstruction = builder.CreateFreeze(function->arg_begin());
118  builder.CreateRet(freezeInstruction);
119 
120  // Act
121  auto ipgModule = jlm::llvm::ConvertLlvmModule(llvmModule);
122  print(*ipgModule, stdout);
123 
124  // Assert
125  {
126  using namespace jlm::llvm;
127 
128  const auto jlmInt64Type = jlm::rvsdg::BitType::Create(64);
129  const auto controlFlowGraph =
130  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
131  const auto convertedBasicBlock =
132  dynamic_cast<const jlm::llvm::BasicBlock *>(controlFlowGraph->entry()->OutEdge(0)->sink());
133 
134  size_t numFreezeThreeAddressCodes = 0;
135  for (auto tac : *convertedBasicBlock)
136  {
137  if (!is<FreezeOperation>(tac))
138  continue;
139 
140  numFreezeThreeAddressCodes++;
141  const auto freezeOperation =
142  jlm::util::assertedCast<const FreezeOperation>(&tac->operation());
143 
144  EXPECT_EQ(tac->noperands(), 1u);
145  EXPECT_EQ(tac->nresults(), 1u);
146  EXPECT_EQ(tac->operand(0), controlFlowGraph->entry()->argument(0));
147  EXPECT_EQ(*freezeOperation->argument(0), *jlmInt64Type);
148  EXPECT_EQ(freezeOperation->getType(), *jlmInt64Type);
149  EXPECT_EQ(*freezeOperation->result(0), *jlmInt64Type);
150  }
151 
152  EXPECT_EQ(numFreezeThreeAddressCodes, 1);
153  }
154 }
155 
159 TEST(LlvmModuleConversionTests, InsertValueConversion)
160 {
161 
162  // Arrange
163  llvm::LLVMContext context;
164  llvm::Module llvmModule("module", context);
165 
166  auto int64Type = llvm::Type::getInt64Ty(context);
167  auto structType = llvm::StructType::get(int64Type, int64Type);
168  auto functionType = llvm::FunctionType::get(
169  structType,
170  ::llvm::ArrayRef<llvm::Type *>({ int64Type, int64Type }),
171  false);
172  auto function =
173  llvm::Function::Create(functionType, llvm::GlobalValue::ExternalLinkage, "f", &llvmModule);
174 
175  auto basicBlock = llvm::BasicBlock::Create(context, "BasicBlock", function);
176 
177  llvm::IRBuilder builder(basicBlock);
178  auto poison = llvm::PoisonValue::get(structType);
179  auto insertValue0 = builder.CreateInsertValue(poison, function->arg_begin(), 0);
180  auto insertValue1 = builder.CreateInsertValue(insertValue0, function->arg_begin() + 1, 1);
181  builder.CreateRet(insertValue1);
182 
183  // Act
184  auto ipgModule = jlm::llvm::ConvertLlvmModule(llvmModule);
185  print(*ipgModule, stdout);
186 
187  // Assert
188  {
189  using namespace jlm::llvm;
190 
191  const auto jlmInt64Type = jlm::rvsdg::BitType::Create(64);
192  const auto controlFlowGraph =
193  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
194  const auto convertedBasicBlock =
195  dynamic_cast<const BasicBlock *>(controlFlowGraph->entry()->OutEdge(0)->sink());
196 
197  size_t numInsertValueAddressCodes = 0;
198  for (auto tac : *convertedBasicBlock)
199  {
200  if (auto insertValueOperation = dynamic_cast<const InsertValueOperation *>(&tac->operation()))
201  {
202  EXPECT_EQ(tac->noperands(), 2u);
203  EXPECT_EQ(tac->nresults(), 1u);
204  EXPECT_EQ(insertValueOperation->getIndices().size(), 1u);
205  EXPECT_EQ(*insertValueOperation->getValueType(), *jlmInt64Type);
206 
207  numInsertValueAddressCodes++;
208  if (numInsertValueAddressCodes == 1)
209  {
210  EXPECT_EQ(tac->operand(1), controlFlowGraph->entry()->argument(0));
211  EXPECT_EQ(insertValueOperation->getIndices()[0], 0u);
212  }
213  else if (numInsertValueAddressCodes == 2)
214  {
215  EXPECT_EQ(tac->operand(1), controlFlowGraph->entry()->argument(1));
216  EXPECT_EQ(insertValueOperation->getIndices()[0], 1u);
217  }
218  }
219  }
220 
221  EXPECT_EQ(numInsertValueAddressCodes, 2u);
222  }
223 }
224 
225 TEST(LlvmModuleConversionTests, CallingConvConversion)
226 {
249  // Arrange
250  ::llvm::LLVMContext context;
251  ::llvm::Module llvmModule("module", context);
252  auto int64Type = ::llvm::Type::getInt64Ty(context);
253  auto unaryFunctionType =
254  ::llvm::FunctionType::get(int64Type, ::llvm::ArrayRef<::llvm::Type *>({ int64Type }), false);
255 
256  auto importedFunction = ::llvm::Function::Create(
257  unaryFunctionType,
258  ::llvm::GlobalValue::ExternalLinkage,
259  "imported",
260  &llvmModule);
261  importedFunction->setCallingConv(::llvm::CallingConv::Fast);
262 
263  auto callee = ::llvm::Function::Create(
264  unaryFunctionType,
265  ::llvm::GlobalValue::ExternalLinkage,
266  "callee",
267  &llvmModule);
268  callee->setCallingConv(::llvm::CallingConv::Cold);
269 
270  // Create the function body of the callee function
271  {
272  auto basicBlock = ::llvm::BasicBlock::Create(context, "BasicBlock", callee);
273  ::llvm::IRBuilder<> builder(basicBlock);
274  builder.CreateRet(callee->arg_begin());
275  }
276 
277  auto caller = ::llvm::Function::Create(
278  unaryFunctionType,
279  ::llvm::GlobalValue::ExternalLinkage,
280  "caller",
281  &llvmModule);
282  caller->setCallingConv(::llvm::CallingConv::Tail);
283 
284  // Create the function body of the caller function
285  {
286  auto basicBlock = ::llvm::BasicBlock::Create(context, "BasicBlock", caller);
287  ::llvm::IRBuilder<> builder(basicBlock);
288  auto importedCall = builder.CreateCall(importedFunction, { caller->arg_begin() });
289  importedCall->setCallingConv(::llvm::CallingConv::Fast);
290  auto calleeCall = builder.CreateCall(callee, { importedCall });
291  calleeCall->setCallingConv(::llvm::CallingConv::Cold);
292  builder.CreateRet(calleeCall);
293  }
294 
295  // Act
296  auto ipgModule = jlm::llvm::ConvertLlvmModule(llvmModule);
297 
298  // Assert
299  {
300  using namespace jlm::llvm;
301 
302  const auto importedNode =
303  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("imported"));
304  const auto calleeNode = dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("callee"));
305  const auto callerNode = dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("caller"));
306 
307  std::cout << ControlFlowGraph::ToAscii(*callerNode->cfg()) << std::endl;
308 
309  ASSERT_NE(importedNode, nullptr);
310  ASSERT_NE(calleeNode, nullptr);
311  ASSERT_NE(callerNode, nullptr);
312 
313  EXPECT_FALSE(importedNode->hasBody());
314  EXPECT_TRUE(calleeNode->hasBody());
315  EXPECT_TRUE(callerNode->hasBody());
316 
317  EXPECT_EQ(importedNode->callingConvention(), CallingConvention::Fast);
318  EXPECT_EQ(calleeNode->callingConvention(), CallingConvention::Cold);
319  EXPECT_EQ(callerNode->callingConvention(), CallingConvention::Tail);
320 
321  const auto convertedBasicBlock =
322  dynamic_cast<const BasicBlock *>(callerNode->cfg()->entry()->OutEdge(0)->sink());
323  ASSERT_NE(convertedBasicBlock, nullptr);
324  auto it = convertedBasicBlock->begin();
325 
326  // Return the next CallOperation in the basic block
327  const auto nextCallTac = [&]()
328  {
329  while (true)
330  {
331  EXPECT_NE(it, convertedBasicBlock->end());
332  if (is<CallOperation>(*it))
333  return *it++;
334  it++;
335  }
336  };
337 
338  // Check that the call to imported has been converted correctly
339  {
340  auto callImportedTac = nextCallTac();
341  EXPECT_EQ(callImportedTac->operand(0), ipgModule->variable(importedNode));
342  auto op = jlm::util::assertedCast<const CallOperation>(&callImportedTac->operation());
343  EXPECT_EQ(op->getCallingConvention(), CallingConvention::Fast);
344  }
345 
346  // Check that the call to callee has been converted correctly
347  {
348  auto callCalleeTac = nextCallTac();
349  EXPECT_EQ(callCalleeTac->operand(0), ipgModule->variable(calleeNode));
350  auto op = jlm::util::assertedCast<const CallOperation>(&callCalleeTac->operation());
351  EXPECT_EQ(op->getCallingConvention(), CallingConvention::Cold);
352  }
353  }
354 }
355 
356 TEST(LlvmModuleConversionTests, MemCpyConversion)
357 {
358  using namespace llvm;
359 
360  // Arrange
361  LLVMContext context;
362  std::unique_ptr<Module> llvmModule(new Module("module", context));
363 
364  auto int64Type = Type::getInt64Ty(context);
365  auto pointerType = PointerType::getUnqual(context);
366  auto voidType = Type::getVoidTy(context);
367 
368  auto functionType =
369  FunctionType::get(voidType, ArrayRef<Type *>({ pointerType, pointerType, int64Type }), false);
370  auto function =
371  Function::Create(functionType, GlobalValue::ExternalLinkage, "f", llvmModule.get());
372  auto destination = function->getArg(0);
373  auto source = function->getArg(1);
374  auto length = function->getArg(2);
375 
376  auto llvmBasicBlock = BasicBlock::Create(context, "BasicBlock", function);
377 
378  IRBuilder<> builder(llvmBasicBlock);
379  builder.CreateMemCpy(destination, MaybeAlign(), source, MaybeAlign(), length, true);
380  builder.CreateMemCpy(destination, MaybeAlign(), source, MaybeAlign(), length, false);
381  builder.CreateMemCpy(destination, MaybeAlign(), source, MaybeAlign(), length, true);
382  builder.CreateRetVoid();
383 
384  llvmModule->print(errs(), nullptr);
385 
386  // Act
387  auto ipgModule = jlm::llvm::ConvertLlvmModule(*llvmModule);
388  print(*ipgModule, stdout);
389 
390  // Assert
391  {
392  using namespace jlm::llvm;
393 
394  auto controlFlowGraph =
395  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
396  auto jlmBasicBlock =
397  dynamic_cast<const jlm::llvm::BasicBlock *>(controlFlowGraph->entry()->OutEdge(0)->sink());
398 
399  size_t numMemCpyThreeAddressCodes = 0;
400  size_t numMemCpyVolatileThreeAddressCodes = 0;
401  for (auto it = jlmBasicBlock->begin(); it != jlmBasicBlock->end(); ++it)
402  {
403  if (is<MemCpyVolatileOperation>(*it))
404  {
405  numMemCpyVolatileThreeAddressCodes++;
406  auto ioStateAssignment = *std::next(it);
407  auto memoryStateAssignment = *std::next(it, 2);
408 
409  EXPECT_TRUE(is<AssignmentOperation>(ioStateAssignment->operation()));
410  EXPECT_TRUE(is<IOStateType>(ioStateAssignment->operand(0)->type()));
411 
412  EXPECT_TRUE(is<AssignmentOperation>(memoryStateAssignment->operation()));
413  EXPECT_TRUE(is<MemoryStateType>(memoryStateAssignment->operand(0)->type()));
414  }
415  else if (is<MemCpyNonVolatileOperation>(*it))
416  {
417  numMemCpyThreeAddressCodes++;
418  auto memoryStateAssignment = *std::next(it, 1);
419 
420  EXPECT_TRUE(is<AssignmentOperation>(memoryStateAssignment->operation()));
421  EXPECT_TRUE(is<MemoryStateType>(memoryStateAssignment->operand(0)->type()));
422  }
423  }
424 
425  EXPECT_EQ(numMemCpyThreeAddressCodes, 1u);
426  EXPECT_EQ(numMemCpyVolatileThreeAddressCodes, 2u);
427  }
428 }
429 
430 TEST(LlvmModuleConversionTests, MemSetConversion)
431 {
432  using namespace llvm;
433 
434  // Arrange
435  LLVMContext context;
436  std::unique_ptr<Module> llvmModule(new Module("module", context));
437 
438  auto int8Type = Type::getInt8Ty(context);
439  auto int64Type = Type::getInt64Ty(context);
440  auto pointerType = PointerType::getUnqual(context);
441  auto voidType = Type::getVoidTy(context);
442 
443  auto functionType =
444  FunctionType::get(voidType, ArrayRef<Type *>({ pointerType, int8Type, int64Type }), false);
445  auto function =
446  Function::Create(functionType, GlobalValue::ExternalLinkage, "f", llvmModule.get());
447  auto destination = function->getArg(0);
448  auto value = function->getArg(1);
449  auto length = function->getArg(2);
450 
451  auto llvmBasicBlock = BasicBlock::Create(context, "BasicBlock", function);
452 
453  IRBuilder builder(llvmBasicBlock);
454  builder.CreateMemSet(destination, value, length, MaybeAlign());
455  builder.CreateRetVoid();
456 
457  llvmModule->print(errs(), nullptr);
458 
459  // Act
460  auto ipgModule = jlm::llvm::ConvertLlvmModule(*llvmModule);
461  print(*ipgModule, stdout);
462 
463  // Assert
464  {
465  using namespace jlm::llvm;
466 
467  auto controlFlowGraph =
468  dynamic_cast<const FunctionNode *>(ipgModule->ipgraph().find("f"))->cfg();
469  auto jlmBasicBlock =
470  dynamic_cast<const jlm::llvm::BasicBlock *>(controlFlowGraph->entry()->OutEdge(0)->sink());
471 
472  size_t numMemsetThreeAddressCodes = 0;
473  for (auto it = jlmBasicBlock->begin(); it != jlmBasicBlock->end(); ++it)
474  {
475  if (is<MemSetNonVolatileOperation>(*it))
476  {
477  numMemsetThreeAddressCodes++;
478  auto memoryStateAssignment = *std::next(it, 1);
479 
480  EXPECT_TRUE(is<AssignmentOperation>(memoryStateAssignment->operation()));
481  EXPECT_TRUE(is<MemoryStateType>(memoryStateAssignment->operand(0)->type()));
482  }
483  }
484 
485  EXPECT_EQ(numMemsetThreeAddressCodes, 1u);
486  }
487 }
TEST(LlvmModuleConversionTests, SwitchConversion)
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
Global memory state passed between functions.
void print(const AggregationNode &n, const AnnotationMap &dm, FILE *out)
Definition: print.cpp:120
std::unique_ptr< InterProceduralGraphModule > ConvertLlvmModule(::llvm::Module &llvmModule)