Jlm
IntegerOperationsJlmToMlirToJlmTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2024 Halvor Linder Henriksen <halvorlinder@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
11 #include <jlm/llvm/ir/types.hpp>
16 
17 namespace
18 {
19 
/// Describes one integer binary-operation test case: the jlm operation type,
/// the MLIR operation type it should lower to, and a human-readable name.
///
/// NOTE(review): nothing else in this file instantiates this struct (the
/// binary-op test is driven purely by template arguments); candidate for removal.
template<typename JlmOperation, typename MlirOperation>
struct IntegerBinaryOpTest
{
  // Operation types carried by this test case.
  using MlirOpType = MlirOperation;
  using JlmOpType = JlmOperation;

  // Human-readable test-case name.
  const char * name;
};
28 
29 // Template function to test an integer binary operation
30 template<typename JlmOperation, typename MlirOperation>
31 static void
32 TestIntegerBinaryOperation()
33 {
34  using namespace jlm::llvm;
35  using namespace mlir::rvsdg;
36 
37  const size_t nbits = 64;
38  const uint64_t val1 = 2;
39  const uint64_t val2 = 3;
40 
41  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
42  auto graph = &rvsdgModule->Rvsdg();
43 
44  {
45  auto constOp1 =
46  &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { nbits, val1 });
47  auto constOp2 =
48  &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { nbits, val2 });
49  auto binaryOp = JlmOperation(nbits);
50  jlm::rvsdg::SimpleNode::Create(graph->GetRootRegion(), binaryOp.copy(), { constOp1, constOp2 });
51 
52  // Convert the RVSDG to MLIR
53  std::cout << "Convert to MLIR" << std::endl;
55  auto omega = mlirgen.ConvertModule(*rvsdgModule);
56 
57  // Validate the generated MLIR
58  std::cout << "Validate MLIR" << std::endl;
59  auto & omegaRegion = omega.getRegion();
60  auto & omegaBlock = omegaRegion.front();
61  bool opFound = false;
62  for (auto & op : omegaBlock.getOperations())
63  {
64  auto mlirBinaryOp = ::mlir::dyn_cast<MlirOperation>(&op);
65  if (mlirBinaryOp)
66  {
67  auto inputBitType1 =
68  mlirBinaryOp.getOperand(0).getType().template dyn_cast<::mlir::IntegerType>();
69  auto inputBitType2 =
70  mlirBinaryOp.getOperand(1).getType().template dyn_cast<::mlir::IntegerType>();
71  EXPECT_NE(inputBitType1, nullptr);
72  EXPECT_EQ(inputBitType1.getWidth(), nbits);
73  EXPECT_NE(inputBitType2, nullptr);
74  EXPECT_EQ(inputBitType2.getWidth(), nbits);
75  auto outputBitType =
76  mlirBinaryOp.getResult().getType().template dyn_cast<::mlir::IntegerType>();
77  EXPECT_NE(outputBitType, nullptr);
78  EXPECT_EQ(outputBitType.getWidth(), nbits);
79  opFound = true;
80  }
81  }
82  EXPECT_TRUE(opFound);
83 
84  // Convert the MLIR to RVSDG and check the result
85  std::cout << "Converting MLIR to RVSDG" << std::endl;
86  std::unique_ptr<mlir::Block> rootBlock = std::make_unique<mlir::Block>();
87  rootBlock->push_back(omega);
88  auto convertedRvsdgModule = jlm::mlir::MlirToJlmConverter::CreateAndConvert(rootBlock);
89  auto region = &convertedRvsdgModule->Rvsdg().GetRootRegion();
90 
91  {
92  using namespace jlm::llvm;
93 
94  EXPECT_EQ(region->numNodes(), 3);
95  bool foundBinaryOp = false;
96  for (auto & node : region->Nodes())
97  {
98  auto convertedBinaryOp = dynamic_cast<const JlmOperation *>(&node.GetOperation());
99  if (convertedBinaryOp)
100  {
101  EXPECT_EQ(convertedBinaryOp->nresults(), 1);
102  EXPECT_EQ(convertedBinaryOp->narguments(), 2);
103  auto inputBitType1 = jlm::util::assertedCast<const jlm::rvsdg::BitType>(
104  convertedBinaryOp->argument(0).get());
105  EXPECT_EQ(inputBitType1->nbits(), nbits);
106  auto inputBitType2 = jlm::util::assertedCast<const jlm::rvsdg::BitType>(
107  convertedBinaryOp->argument(1).get());
108  EXPECT_EQ(inputBitType2->nbits(), nbits);
109  auto outputBitType = jlm::util::assertedCast<const jlm::rvsdg::BitType>(
110  convertedBinaryOp->result(0).get());
111  EXPECT_EQ(outputBitType->nbits(), nbits);
112  foundBinaryOp = true;
113  }
114  }
115  EXPECT_TRUE(foundBinaryOp);
116  }
117  }
118 }
119 
// Defines a gtest case IntegerOperationConversionTests.TEST_NAME. The case
// round-trips jlm::llvm::Integer<JLM_OP>Operation through MLIR and expects it to
// appear there as ::mlir::<MLIR_NS>::<MLIR_OP>.
#define REGISTER_INT_BINARY_OP_TEST(JLM_OP, MLIR_NS, MLIR_OP, TEST_NAME) \
  TEST(IntegerOperationConversionTests, TEST_NAME)                       \
  {                                                                      \
    using JlmOpUnderTest = jlm::llvm::Integer##JLM_OP##Operation;        \
    using MlirOpUnderTest = ::mlir::MLIR_NS::MLIR_OP;                    \
    TestIntegerBinaryOperation<JlmOpUnderTest, MlirOpUnderTest>();       \
  }
128 
129 // Register tests for all the integer binary operations
130 REGISTER_INT_BINARY_OP_TEST(Add, arith, AddIOp, Add)
131 REGISTER_INT_BINARY_OP_TEST(Sub, arith, SubIOp, Sub)
132 REGISTER_INT_BINARY_OP_TEST(Mul, arith, MulIOp, Mul)
133 REGISTER_INT_BINARY_OP_TEST(SDiv, arith, DivSIOp, DivSI)
134 REGISTER_INT_BINARY_OP_TEST(UDiv, arith, DivUIOp, DivUI)
135 REGISTER_INT_BINARY_OP_TEST(SRem, arith, RemSIOp, RemSI)
136 REGISTER_INT_BINARY_OP_TEST(URem, arith, RemUIOp, RemUI)
137 REGISTER_INT_BINARY_OP_TEST(Shl, LLVM, ShlOp, ShLI)
138 REGISTER_INT_BINARY_OP_TEST(AShr, LLVM, AShrOp, ShRSI)
139 REGISTER_INT_BINARY_OP_TEST(LShr, LLVM, LShrOp, ShRUI)
140 REGISTER_INT_BINARY_OP_TEST(And, arith, AndIOp, AndI)
141 REGISTER_INT_BINARY_OP_TEST(Or, arith, OrIOp, OrI)
142 REGISTER_INT_BINARY_OP_TEST(Xor, arith, XOrIOp, XOrI)
143 
144 // Structure to hold all the info needed to test an integer comparison operation
145 template<typename JlmOperation>
146 struct IntegerComparisonOpTest
147 {
148  using JlmOpType = JlmOperation;
149  ::mlir::arith::CmpIPredicate predicate;
150  const char * name;
151 };
152 
153 // Template function to test an integer comparison operation
154 template<typename JlmOperation>
155 static void
156 TestIntegerComparisonOperation(const IntegerComparisonOpTest<JlmOperation> & test)
157 {
158  using namespace jlm::llvm;
159  using namespace mlir::rvsdg;
160 
161  const size_t nbits = 64;
162  const uint64_t val1 = 2;
163  const uint64_t val2 = 3;
164 
165  auto rvsdgModule = LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
166  auto graph = &rvsdgModule->Rvsdg();
167 
168  {
169  auto constOp1 =
170  &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { nbits, val1 });
171  auto constOp2 =
172  &jlm::rvsdg::BitConstantOperation::create(graph->GetRootRegion(), { nbits, val2 });
173  auto compOp = JlmOperation(nbits);
174  jlm::rvsdg::SimpleNode::Create(graph->GetRootRegion(), compOp.copy(), { constOp1, constOp2 });
175 
176  // Convert the RVSDG to MLIR
177  std::cout << "Convert to MLIR" << std::endl;
179  auto omega = mlirgen.ConvertModule(*rvsdgModule);
180 
181  // Validate the generated MLIR
182  std::cout << "Validate MLIR" << std::endl;
183  auto & omegaRegion = omega.getRegion();
184  auto & omegaBlock = omegaRegion.front();
185  bool opFound = false;
186  for (auto & op : omegaBlock.getOperations())
187  {
188  auto mlirCompOp = ::mlir::dyn_cast<::mlir::arith::CmpIOp>(&op);
189  if (mlirCompOp)
190  {
191  auto inputBitType1 = mlirCompOp.getOperand(0).getType().dyn_cast<::mlir::IntegerType>();
192  auto inputBitType2 = mlirCompOp.getOperand(1).getType().dyn_cast<::mlir::IntegerType>();
193  EXPECT_NE(inputBitType1, nullptr);
194  EXPECT_EQ(inputBitType1.getWidth(), nbits);
195  EXPECT_NE(inputBitType2, nullptr);
196  EXPECT_EQ(inputBitType2.getWidth(), nbits);
197 
198  // Check the output type is i1 (boolean)
199  auto outputType = mlirCompOp.getResult().getType().dyn_cast<::mlir::IntegerType>();
200  EXPECT_NE(outputType, nullptr);
201  EXPECT_EQ(outputType.getWidth(), 1);
202 
203  // Verify the predicate is correct
204  EXPECT_EQ(mlirCompOp.getPredicate(), test.predicate);
205  opFound = true;
206  }
207  }
208  EXPECT_TRUE(opFound);
209 
210  // Convert the MLIR to RVSDG and check the result
211  std::cout << "Converting MLIR to RVSDG" << std::endl;
212  std::unique_ptr<mlir::Block> rootBlock = std::make_unique<mlir::Block>();
213  rootBlock->push_back(omega);
214  auto convertedRvsdgModule = jlm::mlir::MlirToJlmConverter::CreateAndConvert(rootBlock);
215  auto region = &convertedRvsdgModule->Rvsdg().GetRootRegion();
216 
217  {
218  using namespace jlm::llvm;
219 
220  EXPECT_EQ(region->numNodes(), 3);
221  bool foundCompOp = false;
222  for (auto & node : region->Nodes())
223  {
224  auto convertedCompOp = dynamic_cast<const JlmOperation *>(&node.GetOperation());
225  if (convertedCompOp)
226  {
227  EXPECT_EQ(convertedCompOp->nresults(), 1);
228  EXPECT_EQ(convertedCompOp->narguments(), 2);
229  auto inputBitType1 = jlm::util::assertedCast<const jlm::rvsdg::BitType>(
230  convertedCompOp->argument(0).get());
231  EXPECT_EQ(inputBitType1->nbits(), nbits);
232  auto inputBitType2 = jlm::util::assertedCast<const jlm::rvsdg::BitType>(
233  convertedCompOp->argument(1).get());
234  EXPECT_EQ(inputBitType2->nbits(), nbits);
235 
236  // Check the output type is bit1 (boolean)
237  auto outputBitType =
238  jlm::util::assertedCast<const jlm::rvsdg::BitType>(convertedCompOp->result(0).get());
239  EXPECT_EQ(outputBitType->nbits(), 1);
240 
241  foundCompOp = true;
242  }
243  }
244  EXPECT_TRUE(foundCompOp);
245  }
246  }
247 }
248 
// Defines a gtest case IntegerOperationConversionTests.TEST_NAME. The case
// round-trips jlm::llvm::Integer<JLM_OP>Operation through MLIR and expects an
// arith.cmpi carrying ::mlir::arith::CmpIPredicate::PREDICATE. The stringized
// TEST_NAME is passed along as the test-case name.
#define REGISTER_INT_COMP_OP_TEST(JLM_OP, PREDICATE, TEST_NAME)                      \
  TEST(IntegerOperationConversionTests, TEST_NAME)                                    \
  {                                                                                   \
    const IntegerComparisonOpTest<jlm::llvm::Integer##JLM_OP##Operation> testCase{    \
      ::mlir::arith::CmpIPredicate::PREDICATE,                                        \
      #TEST_NAME,                                                                     \
    };                                                                                \
    TestIntegerComparisonOperation(testCase);                                         \
  }
259 
260 // Register tests for all the integer comparison operations
261 REGISTER_INT_COMP_OP_TEST(Eq, eq, Eq)
262 REGISTER_INT_COMP_OP_TEST(Ne, ne, Ne)
263 REGISTER_INT_COMP_OP_TEST(Slt, slt, Slt)
264 REGISTER_INT_COMP_OP_TEST(Sle, sle, Sle)
265 REGISTER_INT_COMP_OP_TEST(Sgt, sgt, Sgt)
266 REGISTER_INT_COMP_OP_TEST(Sge, sge, Sge)
267 REGISTER_INT_COMP_OP_TEST(Ult, ult, Ult)
268 REGISTER_INT_COMP_OP_TEST(Ule, ule, Ule)
269 REGISTER_INT_COMP_OP_TEST(Ugt, ugt, Ugt)
270 REGISTER_INT_COMP_OP_TEST(Uge, uge, Uge)
271 
272 }
#define REGISTER_INT_COMP_OP_TEST(JLM_OP, PREDICATE, TEST_NAME)
#define REGISTER_INT_BINARY_OP_TEST(JLM_OP, MLIR_NS, MLIR_OP, TEST_NAME)
static std::unique_ptr< LlvmRvsdgModule > Create(const util::FilePath &sourceFileName, const std::string &targetTriple, const std::string &dataLayout)
::mlir::rvsdg::OmegaNode ConvertModule(const llvm::LlvmRvsdgModule &rvsdgModule)
static std::unique_ptr< llvm::LlvmRvsdgModule > CreateAndConvert(std::unique_ptr<::mlir::Block > &block)
static Output & create(Region &region, BitValueRepresentation value)
Definition: constant.hpp:44
static SimpleNode & Create(Region &region, std::unique_ptr< Operation > operation, const std::vector< rvsdg::Output * > &operands)
Definition: simple-node.hpp:49
Global memory state passed between functions.