Jlm
JlmToMlirConverterTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2024 Louis Maurin <louis7maurin@gmail.com>
3  * Copyright 2024 Magnus Själander <work@sjalander.com>
4  * See COPYING for terms of redistribution.
5  */
6 
7 #include <gtest/gtest.h>
8 
11 #include <jlm/llvm/ir/types.hpp>
15 
16 TEST(JlmToMlirConverterTests, TestLambda)
17 {
18  using namespace jlm::llvm;
19  using namespace mlir::rvsdg;
20 
21  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
22  auto graph = &rvsdgModule->Rvsdg();
23 
24  {
25  // Setup the function
26  std::cout << "Function Setup" << std::endl;
27  auto functionType = jlm::rvsdg::FunctionType::Create(
30 
31  auto lambda = jlm::rvsdg::LambdaNode::Create(
32  graph->GetRootRegion(),
34  auto iOStateArgument = lambda->GetFunctionArguments()[0];
35  auto memoryStateArgument = lambda->GetFunctionArguments()[1];
36 
37  auto constant = &jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 32, 4 });
38 
39  lambda->finalize({ constant, iOStateArgument, memoryStateArgument });
40 
41  // Convert the RVSDG to MLIR
42  std::cout << "Convert to MLIR" << std::endl;
44  auto omega = mlirgen.ConvertModule(*rvsdgModule);
45 
46  // Validate the generated MLIR
47  std::cout << "Validate MLIR" << std::endl;
48  auto & omegaRegion = omega.getRegion();
49  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
50  auto & omegaBlock = omegaRegion.front();
51  // Lamda + terminating operation
52  EXPECT_EQ(omegaBlock.getOperations().size(), 2);
53  auto & mlirLambda = omegaBlock.front();
54  EXPECT_TRUE(
55  mlirLambda.getName().getStringRef().equals(mlir::rvsdg::LambdaNode::getOperationName()));
56 
57  // Verify function name
58  std::cout << "Verify function name" << std::endl;
59  auto functionNameAttribute = mlirLambda.getAttr(::llvm::StringRef("sym_name"));
60  auto * functionName = static_cast<mlir::StringAttr *>(&functionNameAttribute);
61  auto string = functionName->getValue().str();
62  EXPECT_EQ(string, "test");
63 
64  // Verify function signature
65  std::cout << "Verify function signature" << std::endl;
66 
67  auto result = mlirLambda.getResult(0).getType();
68  EXPECT_EQ(result.getTypeID(), mlir::FunctionType::getTypeID());
69 
70  auto lambdaOp = ::mlir::dyn_cast<::mlir::rvsdg::LambdaNode>(&mlirLambda);
71 
72  auto lamdbaTerminator = lambdaOp.getRegion().front().getTerminator();
73  auto lambdaResult = mlir::dyn_cast<mlir::rvsdg::LambdaResult>(lamdbaTerminator);
74  EXPECT_NE(lambdaResult, nullptr);
75  lambdaResult->dump();
76 
77  std::vector<mlir::Type> arguments;
78  for (auto argument : lambdaOp->getRegion(0).getArguments())
79  {
80  arguments.push_back(argument.getType());
81  }
82  EXPECT_EQ(arguments[0].getTypeID(), IOStateEdgeType::getTypeID());
83  EXPECT_EQ(arguments[1].getTypeID(), MemStateEdgeType::getTypeID());
84  std::vector<mlir::Type> results;
85  for (auto returnType : lambdaResult->getOperandTypes())
86  {
87  results.push_back(returnType);
88  }
89  EXPECT_TRUE(results[0].isa<mlir::IntegerType>());
90  EXPECT_TRUE(results[1].isa<mlir::rvsdg::IOStateEdgeType>());
91  EXPECT_TRUE(results[2].isa<mlir::rvsdg::MemStateEdgeType>());
92 
93  auto & lambdaRegion = mlirLambda.getRegion(0);
94  auto & lambdaBlock = lambdaRegion.front();
95  // Bitconstant + terminating operation
96  EXPECT_EQ(lambdaBlock.getOperations().size(), 2);
97  EXPECT_TRUE(lambdaBlock.front().getName().getStringRef().equals(
98  mlir::arith::ConstantIntOp::getOperationName()));
99 
100  omega->destroy();
101  }
102 }
103 
114 static void
115 useChainsUpTraverse(mlir::Operation * operation, std::vector<llvm::StringRef> definingOperations)
116 {
117  if (definingOperations.empty())
118  return;
119  std::cout << "Checking if operation: "
120  << operation->getOperand(0).getDefiningOp()->getName().getStringRef().data()
121  << " is equal to: " << definingOperations.back().data() << std::endl;
122  EXPECT_TRUE(operation->getOperand(0).getDefiningOp()->getName().getStringRef().equals(
123  definingOperations.back()));
124  definingOperations.pop_back();
125  useChainsUpTraverse(operation->getOperand(0).getDefiningOp(), definingOperations);
126 }
127 
140 TEST(JlmToMlirConverterTests, TestAddOperation)
141 {
142  using namespace jlm::llvm;
143  using namespace mlir::rvsdg;
144 
145  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
146  auto graph = &rvsdgModule->Rvsdg();
147 
148  {
149  // Setup the function
150  std::cout << "Function Setup" << std::endl;
151  auto functionType = jlm::rvsdg::FunctionType::Create(
154 
155  auto lambda = jlm::rvsdg::LambdaNode::Create(
156  graph->GetRootRegion(),
158  auto iOStateArgument = lambda->GetFunctionArguments()[0];
159  auto memoryStateArgument = lambda->GetFunctionArguments()[1];
160 
161  // Create add operation
162  std::cout << "Add Operation" << std::endl;
163  auto constant1 = &jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 32, 4 });
164  auto constant2 = &jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 32, 5 });
165  auto add = jlm::rvsdg::bitadd_op::create(32, constant1, constant2);
166 
167  lambda->finalize({ add, iOStateArgument, memoryStateArgument });
168 
169  // Convert the RVSDG to MLIR
170  std::cout << "Convert to MLIR" << std::endl;
172  auto omega = mlirgen.ConvertModule(*rvsdgModule);
173 
174  // Checking blocks and operations count
175  std::cout << "Checking blocks and operations count" << std::endl;
176  auto & omegaRegion = omega.getRegion();
177  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
178  auto & omegaBlock = omegaRegion.front();
179  // Lamda + terminating operation
180  EXPECT_EQ(omegaBlock.getOperations().size(), 2);
181 
182  // Checking lambda block operations
183  std::cout << "Checking lambda block operations" << std::endl;
184  auto & mlirLambda = omegaBlock.front();
185  auto & lambdaRegion = mlirLambda.getRegion(0);
186  auto & lambdaBlock = lambdaRegion.front();
187  // 2 Bits contants + add + terminating operation
188  EXPECT_EQ(lambdaBlock.getOperations().size(), 4);
189 
190  // Checking lambda block operations types
191  std::cout << "Checking lambda block operations types" << std::endl;
192  std::vector<mlir::Operation *> operations;
193  for (auto & operation : lambdaBlock.getOperations())
194  {
195  operations.push_back(&operation);
196  }
197 
198  int constCount = 0;
199  for (auto & operation : operations)
200  {
201  if (operation->getName().getStringRef().equals(mlir::rvsdg::LambdaResult::getOperationName()))
202  continue;
203  if (operation->getName().getStringRef().equals(
204  mlir::arith::ConstantIntOp::getOperationName()))
205  {
206  constCount++;
207  continue;
208  }
209  // Checking add operation
210  std::cout << "Checking add operation" << std::endl;
211  EXPECT_TRUE(operation->getName().getStringRef().equals(
212  mlir::arith::AddIOp::getOperationName())); // Last remaining operation is the add
213  // operation
214  EXPECT_EQ(operation->getNumOperands(), 2);
215  auto addOperand1 = operation->getOperand(0);
216  auto addOperand2 = operation->getOperand(1);
217  EXPECT_TRUE(addOperand1.getType().isInteger(32));
218  EXPECT_TRUE(addOperand2.getType().isInteger(32));
219  }
220  EXPECT_EQ(constCount, 2);
221 
223  &lambdaBlock.getOperations().back(),
224  { mlir::arith::ConstantIntOp::getOperationName(),
225  mlir::arith::AddIOp::getOperationName() });
226 
227  omega->destroy();
228  }
229 }
230 
237 TEST(JlmToMlirConverterTests, TestComZeroExt)
238 {
239  using namespace jlm::llvm;
240  using namespace mlir::rvsdg;
241 
242  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
243  auto graph = &rvsdgModule->Rvsdg();
244 
245  {
246  // Setup the function
247  std::cout << "Function Setup" << std::endl;
248  auto functionType = jlm::rvsdg::FunctionType::Create(
251 
252  auto lambda = jlm::rvsdg::LambdaNode::Create(
253  graph->GetRootRegion(),
255  auto iOStateArgument = lambda->GetFunctionArguments()[0];
256  auto memoryStateArgument = lambda->GetFunctionArguments()[1];
257 
258  // Create add operation
259  std::cout << "Add Operation" << std::endl;
260  auto constant1 = &jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 8, 4 });
261  jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 16, 5 }); // Unused constant
262  jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 16, 6 }); // Unused constant
263 
264  // zero extension of constant1
265  const auto zeroExt = jlm::rvsdg::CreateOpNode<ZExtOperation>({ constant1 }, 8, 16).output(0);
266 
267  auto mul = jlm::rvsdg::bitmul_op::create(16, zeroExt, zeroExt);
268 
269  auto comp = jlm::rvsdg::bitsgt_op::create(16, mul, mul);
270 
271  lambda->finalize({ comp, iOStateArgument, memoryStateArgument });
272 
273  // Convert the RVSDG to MLIR
274  std::cout << "Convert to MLIR" << std::endl;
276  auto omega = mlirgen.ConvertModule(*rvsdgModule);
277 
278  // Checking blocks and operations count
279  std::cout << "Checking blocks and operations count" << std::endl;
280  auto & omegaRegion = omega.getRegion();
281  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
282  auto & omegaBlock = omegaRegion.front();
283  // Lamda + terminating operation
284  EXPECT_EQ(omegaBlock.getOperations().size(), 2);
285 
286  // Checking lambda block operations
287  std::cout << "Checking lambda block operations" << std::endl;
288  auto & mlirLambda = omegaBlock.front();
289  auto & lambdaRegion = mlirLambda.getRegion(0);
290  auto & lambdaBlock = lambdaRegion.front();
291  // 3 Bits contants + ZeroExt + Mul + Comp + terminating operation
292  EXPECT_EQ(lambdaBlock.getOperations().size(), 7);
293 
294  // Checking lambda block operations types
295  std::cout << "Checking lambda block operations types" << std::endl;
296  std::vector<mlir::Operation *> operations;
297  for (auto & operation : lambdaBlock.getOperations())
298  {
299  std::cout << "Operation: " << operation.getName().getStringRef().data() << std::endl;
300  operations.push_back(&operation);
301  }
302 
303  int constCount = 0;
304  int extCount = 0;
305  int mulCount = 0;
306  int compCount = 0;
307  for (auto & operation : operations)
308  {
309  if (operation->getName().getStringRef().equals(mlir::rvsdg::LambdaResult::getOperationName()))
310  continue;
311  if (operation->getName().getStringRef().equals(
312  mlir::arith::ConstantIntOp::getOperationName()))
313  {
314  EXPECT_TRUE(
315  operation->getResult(0).getType().isInteger(8)
316  || operation->getResult(0).getType().isInteger(16));
317  constCount++;
318  continue;
319  }
320  if (operation->getName().getStringRef().equals(mlir::arith::ExtUIOp::getOperationName()))
321  {
322  EXPECT_EQ(operation->getNumOperands(), 1);
323  EXPECT_TRUE(operation->getOperand(0).getType().isInteger(8));
324  EXPECT_EQ(operation->getNumResults(), 1);
325  EXPECT_TRUE(operation->getResult(0).getType().isInteger(16));
326  extCount++;
327  continue;
328  }
329  if (operation->getName().getStringRef().equals(mlir::arith::MulIOp::getOperationName()))
330  {
331  EXPECT_EQ(operation->getNumOperands(), 2);
332  EXPECT_TRUE(operation->getOperand(0).getType().isInteger(16));
333  EXPECT_TRUE(operation->getOperand(1).getType().isInteger(16));
334  EXPECT_EQ(operation->getNumResults(), 1);
335  EXPECT_TRUE(operation->getResult(0).getType().isInteger(16));
336  mulCount++;
337  continue;
338  }
339  if (operation->getName().getStringRef().equals(mlir::arith::CmpIOp::getOperationName()))
340  {
341  auto comparisonOp = mlir::cast<mlir::arith::CmpIOp>(operation);
342  EXPECT_EQ(comparisonOp.getPredicate(), mlir::arith::CmpIPredicate::sgt);
343  EXPECT_EQ(operation->getNumOperands(), 2);
344  EXPECT_TRUE(operation->getOperand(0).getType().isInteger(16));
345  EXPECT_TRUE(operation->getOperand(1).getType().isInteger(16));
346  EXPECT_EQ(operation->getNumResults(), 1);
347  compCount++;
348  continue;
349  }
350  FAIL();
351  }
352 
353  // Check counts
354  std::cout << "Checking counts" << std::endl;
355  EXPECT_EQ(constCount, 3);
356  EXPECT_EQ(extCount, 1);
357  EXPECT_EQ(mulCount, 1);
358  EXPECT_EQ(compCount, 1);
359 
361  &lambdaBlock.getOperations().back(),
362  { mlir::arith::ConstantIntOp::getOperationName(),
363  mlir::arith::ExtUIOp::getOperationName(),
364  mlir::arith::MulIOp::getOperationName(),
365  mlir::arith::CmpIOp::getOperationName() });
366 
367  omega->destroy();
368  }
369 }
370 
375 TEST(JlmToMlirConverterTests, TestMatch)
376 {
377  using namespace jlm::llvm;
378  using namespace mlir::rvsdg;
379 
380  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
381  auto graph = &rvsdgModule->Rvsdg();
382 
383  {
384  // Setup the function
385  std::cout << "Function Setup" << std::endl;
386  auto functionType = jlm::rvsdg::FunctionType::Create(
389 
390  auto lambda = jlm::rvsdg::LambdaNode::Create(
391  graph->GetRootRegion(),
393  auto iOStateArgument = lambda->GetFunctionArguments()[0];
394  auto memoryStateArgument = lambda->GetFunctionArguments()[1];
395 
396  // Create a match operation
397  std::cout << "Match Operation" << std::endl;
398  auto predicateConst = &jlm::rvsdg::BitConstantOperation::create(*lambda->subregion(), { 8, 4 });
399 
400  auto match =
401  jlm::rvsdg::MatchOperation::Create(*predicateConst, { { 4, 0 }, { 5, 1 }, { 6, 1 } }, 2, 2);
402 
403  lambda->finalize({ match, iOStateArgument, memoryStateArgument });
404 
405  // Convert the RVSDG to MLIR
406  std::cout << "Convert to MLIR" << std::endl;
408  auto omega = mlirgen.ConvertModule(*rvsdgModule);
409 
410  // Checking blocks and operations count
411  std::cout << "Checking blocks and operations count" << std::endl;
412  auto & omegaRegion = omega.getRegion();
413  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
414  auto & omegaBlock = omegaRegion.front();
415  // Lamda + terminating operation
416  EXPECT_EQ(omegaBlock.getOperations().size(), 2);
417 
418  // Checking lambda block operations
419  std::cout << "Checking lambda block operations" << std::endl;
420  auto & mlirLambda = omegaBlock.front();
421  auto & lambdaRegion = mlirLambda.getRegion(0);
422  auto & lambdaBlock = lambdaRegion.front();
423  // 1 Bits contants + Match + terminating operation
424  EXPECT_EQ(lambdaBlock.getOperations().size(), 3);
425 
426  bool matchFound = false;
427  for (auto & operation : lambdaBlock.getOperations())
428  {
429  if (mlir::isa<mlir::rvsdg::Match>(operation))
430  {
431  matchFound = true;
432  std::cout << "Checking match operation" << std::endl;
433  auto matchOp = mlir::cast<mlir::rvsdg::Match>(operation);
434 
435  EXPECT_TRUE(mlir::isa<mlir::arith::ConstantIntOp>(matchOp.getInput().getDefiningOp()));
436  auto constant = mlir::cast<mlir::arith::ConstantIntOp>(matchOp.getInput().getDefiningOp());
437  EXPECT_EQ(constant.value(), 4);
438  EXPECT_TRUE(constant.getType().isInteger(8));
439 
440  auto mapping = matchOp.getMapping();
441  mapping.dump();
442  // 3 alternatives + default
443  EXPECT_EQ(mapping.size(), 4);
444 
445  // ** region check alternatives *$
446  for (auto & attr : mapping)
447  {
448  EXPECT_TRUE(attr.isa<::mlir::rvsdg::MatchRuleAttr>());
449  auto matchRuleAttr = attr.cast<::mlir::rvsdg::MatchRuleAttr>();
450  if (matchRuleAttr.isDefault())
451  {
452  EXPECT_EQ(matchRuleAttr.getIndex(), 2);
453  EXPECT_TRUE(matchRuleAttr.getValues().empty());
454  continue;
455  }
456 
457  EXPECT_EQ(matchRuleAttr.getValues().size(), 1);
458 
459  const int64_t value = matchRuleAttr.getValues().front();
460 
461  EXPECT_TRUE(
462  (matchRuleAttr.getIndex() == 0 && value == 4)
463  || (matchRuleAttr.getIndex() == 1 && (value == 5 || value == 6)));
464  }
465  // ** endregion check alternatives **
466  }
467  }
468  EXPECT_TRUE(matchFound);
469 
470  omega->destroy();
471  }
472 }
473 
478 TEST(JlmToMlirConverterTests, TestGamma)
479 {
480  using namespace jlm::llvm;
481  using namespace mlir::rvsdg;
482 
483  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
484  auto graph = &rvsdgModule->Rvsdg();
485 
486  {
487 
488  // Create a gamma operation
489  std::cout << "Gamma Operation" << std::endl;
490  auto CtrlConstant = &jlm::rvsdg::ControlConstantOperation::create(graph->GetRootRegion(), 3, 1);
491  auto entryvar1 = &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { 32, 5 });
492  auto entryvar2 = &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { 32, 6 });
493  auto rvsdgGammaNode = jlm::rvsdg::GammaNode::create(
494  CtrlConstant, // predicate
495  3 // nalternatives
496  );
497 
498  rvsdgGammaNode->AddEntryVar(entryvar1);
499  rvsdgGammaNode->AddEntryVar(entryvar2);
500 
501  std::vector<jlm::rvsdg::Output *> exitvars1;
502  std::vector<jlm::rvsdg::Output *> exitvars2;
503  for (int i = 0; i < 3; i++)
504  {
505  exitvars1.push_back(
506  &jlm::rvsdg::BitConstantOperation::create(*rvsdgGammaNode->subregion(i), { 32, i + 1 }));
507  exitvars2.push_back(&jlm::rvsdg::BitConstantOperation::create(
508  *rvsdgGammaNode->subregion(i),
509  { 32, 10 * (i + 1) }));
510  }
511 
512  rvsdgGammaNode->AddExitVar(exitvars1);
513  rvsdgGammaNode->AddExitVar(exitvars2);
514 
515  // Convert the RVSDG to MLIR
516  std::cout << "Convert to MLIR" << std::endl;
518  auto omega = mlirgen.ConvertModule(*rvsdgModule);
519 
520  // Checking blocks and operations count
521  std::cout << "Checking blocks and operations count" << std::endl;
522  auto & omegaRegion = omega.getRegion();
523  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
524  auto & omegaBlock = omegaRegion.front();
525  // 1 control + 2 constants + gamma + terminating operation
526  EXPECT_EQ(omegaBlock.getOperations().size(), 5);
527 
528  bool gammaFound = false;
529  for (auto & operation : omegaBlock.getOperations())
530  {
531  if (mlir::isa<mlir::rvsdg::GammaNode>(operation))
532  {
533  gammaFound = true;
534  std::cout << "Checking gamma operation" << std::endl;
535  auto gammaOp = mlir::cast<mlir::rvsdg::GammaNode>(operation);
536  EXPECT_EQ(gammaOp.getNumRegions(), 3);
537  // 1 predicate + 2 entryVars
538  EXPECT_EQ(gammaOp.getNumOperands(), 3);
539  EXPECT_EQ(gammaOp.getNumResults(), 2);
540 
541  std::cout << "Checking gamma predicate" << std::endl;
542  EXPECT_TRUE(mlir::isa<mlir::rvsdg::ConstantCtrl>(gammaOp.getPredicate().getDefiningOp()));
543  auto controlConstant =
544  mlir::cast<mlir::rvsdg::ConstantCtrl>(gammaOp.getPredicate().getDefiningOp());
545  EXPECT_EQ(controlConstant.getValue(), 1);
546  EXPECT_TRUE(mlir::isa<mlir::rvsdg::RVSDG_CTRLType>(controlConstant.getType()));
547  auto ctrlType = mlir::cast<mlir::rvsdg::RVSDG_CTRLType>(controlConstant.getType());
548  EXPECT_EQ(ctrlType.getNumOptions(), 3);
549 
550  std::cout << "Checking gamma entryVars" << std::endl;
552  auto entryVars = gammaOp.getInputs();
553  EXPECT_EQ(entryVars.size(), 2);
554  EXPECT_TRUE(mlir::isa<mlir::arith::ConstantIntOp>(entryVars[0].getDefiningOp()));
555  EXPECT_TRUE(mlir::isa<mlir::arith::ConstantIntOp>(entryVars[1].getDefiningOp()));
556  auto entryVar1 = mlir::cast<mlir::arith::ConstantIntOp>(entryVars[0].getDefiningOp());
557  auto entryVar2 = mlir::cast<mlir::arith::ConstantIntOp>(entryVars[1].getDefiningOp());
558  EXPECT_EQ(entryVar1.value(), 5);
559  EXPECT_EQ(entryVar2.value(), 6);
560 
561  std::cout << "Checking gamma subregions" << std::endl;
562  for (size_t i = 0; i < gammaOp.getNumRegions(); i++)
563  {
564  EXPECT_EQ(gammaOp.getRegion(i).getBlocks().size(), 1);
565  auto & gammaBlock = gammaOp.getRegion(i).front();
566  // 2 bit constants + gamma result
567  EXPECT_EQ(gammaBlock.getOperations().size(), 3);
568 
569  std::cout << "Checking gamma exitVars" << std::endl;
570  auto gammaResult = gammaBlock.getTerminator();
571  EXPECT_TRUE(mlir::isa<mlir::rvsdg::GammaResult>(gammaResult));
572  auto gammaResultOp = mlir::cast<mlir::rvsdg::GammaResult>(gammaResult);
573  EXPECT_EQ(gammaResultOp.getNumOperands(), 2);
574  for (size_t j = 0; j < gammaResultOp.getNumOperands(); j++)
575  {
576  EXPECT_TRUE(
577  mlir::isa<mlir::arith::ConstantIntOp>(gammaResultOp.getOperand(j).getDefiningOp()));
578  auto constant =
579  mlir::cast<mlir::arith::ConstantIntOp>(gammaResultOp.getOperand(j).getDefiningOp());
580  EXPECT_EQ(static_cast<size_t>(constant.value()), (1 - j) * (i + 1) + 10 * (i + 1) * j);
581  }
582  }
583  }
584  }
585  EXPECT_TRUE(gammaFound);
586  omega->destroy();
587  }
588 }
589 
594 TEST(JlmToMlirConverterTests, TestTheta)
595 {
596  using namespace jlm::llvm;
597  using namespace mlir::rvsdg;
598 
599  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
600  auto graph = &rvsdgModule->Rvsdg();
601 
602  {
603  // Create a theta operation
604  std::cout << "Theta Operation" << std::endl;
605  auto entryvar1 = &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { 32, 5 });
606  auto entryvar2 = &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { 32, 6 });
607  jlm::rvsdg::ThetaNode * rvsdgThetaNode = jlm::rvsdg::ThetaNode::create(&graph->GetRootRegion());
608 
609  auto predicate =
610  &jlm::rvsdg::ControlConstantOperation::create(*rvsdgThetaNode->subregion(), 2, 0);
611 
612  rvsdgThetaNode->AddLoopVar(entryvar1);
613  rvsdgThetaNode->AddLoopVar(entryvar2);
614  rvsdgThetaNode->set_predicate(predicate);
615 
616  // Convert the RVSDG to MLIR
617  std::cout << "Convert to MLIR" << std::endl;
619  auto omega = mlirgen.ConvertModule(*rvsdgModule);
620 
621  // Checking blocks and operations count
622  std::cout << "Checking blocks and operations count" << std::endl;
623  auto & omegaRegion = omega.getRegion();
624  EXPECT_EQ(omegaRegion.getBlocks().size(), 1);
625  auto & omegaBlock = omegaRegion.front();
626  // 1 theta + 1 predicate + 2 constants
627  EXPECT_EQ(omegaBlock.getOperations().size(), 4);
628 
629  bool thetaFound = false;
630  for (auto & operation : omegaBlock.getOperations())
631  {
632  if (mlir::isa<mlir::rvsdg::ThetaNode>(operation))
633  {
634  thetaFound = true;
635  std::cout << "Checking theta operation" << std::endl;
636  auto thetaOp = mlir::cast<mlir::rvsdg::ThetaNode>(operation);
637  // 2 loop vars
638  EXPECT_EQ(thetaOp.getNumOperands(), 2);
639  EXPECT_EQ(thetaOp.getNumResults(), 2);
640 
641  auto & thetaBlock = thetaOp.getRegion().front();
642  auto thetaResult = thetaBlock.getTerminator();
643 
644  EXPECT_TRUE(mlir::isa<mlir::rvsdg::ThetaResult>(thetaResult));
645  auto thetaResultOp = mlir::cast<mlir::rvsdg::ThetaResult>(thetaResult);
646 
647  std::cout << "Checking theta predicate" << std::endl;
648 
649  EXPECT_TRUE(
650  mlir::isa<mlir::rvsdg::ConstantCtrl>(thetaResultOp.getPredicate().getDefiningOp()));
651  auto controlConstant =
652  mlir::cast<mlir::rvsdg::ConstantCtrl>(thetaResultOp.getPredicate().getDefiningOp());
653 
654  EXPECT_EQ(controlConstant.getValue(), 0);
655 
656  EXPECT_TRUE(mlir::isa<mlir::rvsdg::RVSDG_CTRLType>(controlConstant.getType()));
657  auto ctrlType = mlir::cast<mlir::rvsdg::RVSDG_CTRLType>(controlConstant.getType());
658  EXPECT_EQ(ctrlType.getNumOptions(), 2);
659 
660  std::cout << "Checking theta loop vars" << std::endl;
662  auto loopVars = thetaOp.getInputs();
663  EXPECT_EQ(loopVars.size(), 2);
664  EXPECT_TRUE(mlir::isa<mlir::arith::ConstantIntOp>(loopVars[0].getDefiningOp()));
665  EXPECT_TRUE(mlir::isa<mlir::arith::ConstantIntOp>(loopVars[1].getDefiningOp()));
666  auto loopVar1 = mlir::cast<mlir::arith::ConstantIntOp>(loopVars[0].getDefiningOp());
667  auto loopVar2 = mlir::cast<mlir::arith::ConstantIntOp>(loopVars[1].getDefiningOp());
668  EXPECT_EQ(loopVar1.value(), 5);
669  EXPECT_EQ(loopVar2.value(), 6);
670 
671  // Theta result, constant control predicate
672  EXPECT_EQ(thetaBlock.getOperations().size(), 2);
673 
674  std::cout << "Checking loop exitVars" << std::endl;
675  std::cout << thetaResultOp.getNumOperands() << std::endl;
676 
677  std::cout << "Checking theta subregion" << std::endl;
678 
679  // Two arguments and predicate
680  EXPECT_EQ(thetaResultOp.getNumOperands(), 3);
681  }
682  }
683  // }
684  EXPECT_TRUE(thetaFound);
685  omega->destroy();
686  }
687 }
TEST(JlmToMlirConverterTests, TestLambda)
static void useChainsUpTraverse(mlir::Operation *operation, std::vector< llvm::StringRef > definingOperations)
useChainsUpTraverse
static std::shared_ptr< const IOStateType > Create()
Definition: types.cpp:343
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
static std::unique_ptr< LlvmRvsdgModule > Create(const util::FilePath &sourceFileName, const std::string &targetTriple, const std::string &dataLayout)
static std::shared_ptr< const MemoryStateType > Create()
Definition: types.cpp:379
::mlir::rvsdg::OmegaNode ConvertModule(const llvm::LlvmRvsdgModule &rvsdgModule)
static Output & create(Region &region, BitValueRepresentation value)
Definition: constant.hpp:44
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static Output & create(Region &region, ControlValueRepresentation value)
Definition: control.hpp:122
static std::shared_ptr< const ControlType > Create(std::size_t nalternatives)
Instantiates control type.
Definition: control.cpp:50
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
static GammaNode * create(jlm::rvsdg::Output *predicate, size_t nalternatives)
Definition: gamma.hpp:161
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::unique_ptr< BitBinaryOperation > create(size_t nbits) const override
std::unique_ptr< BitCompareOperation > create(size_t nbits) const override
static Output * Create(Output &predicate, const std::unordered_map< uint64_t, uint64_t > &mapping, const uint64_t defaultAlternative, const size_t numAlternatives)
Definition: control.hpp:242
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
void set_predicate(jlm::rvsdg::Output *p)
Definition: theta.hpp:93
static ThetaNode * create(rvsdg::Region *parent)
Definition: theta.hpp:73
LoopVar AddLoopVar(rvsdg::Output *origin)
Creates a new loop-carried variable.
Definition: theta.cpp:49
Global memory state passed between functions.