Jlm
binary.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2010 2011 2012 2014 Helge Bahmann <hcb@chaoticmind.net>
3  * Copyright 2011 2012 2013 2014 Nico Reißmann <nico.reissmann@gmail.com>
4  * See COPYING for terms of redistribution.
5  */
6 
7 #include <jlm/rvsdg/binary.hpp>
9 #include <jlm/rvsdg/region.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 #include <jlm/util/strfmt.hpp>
14 
15 #include <deque>
16 
17 namespace jlm::rvsdg
18 {
19 
20 /* binary normal form */
21 
22 namespace
23 {
24 
25 std::vector<jlm::rvsdg::Output *>
26 reduce_operands(const BinaryOperation & op, std::vector<jlm::rvsdg::Output *> args)
27 {
28  /* pair-wise reduce */
29  if (op.is_commutative())
30  {
32  std::move(args),
33  [&op](jlm::rvsdg::Output * arg1, jlm::rvsdg::Output * arg2)
34  {
35  binop_reduction_path_t reduction = op.can_reduce_operand_pair(arg1, arg2);
36  return reduction != binop_reduction_none ? op.reduce_operand_pair(reduction, arg1, arg2)
37  : nullptr;
38  });
39  }
40  else
41  {
43  std::move(args),
44  [&op](jlm::rvsdg::Output * arg1, jlm::rvsdg::Output * arg2)
45  {
46  binop_reduction_path_t reduction = op.can_reduce_operand_pair(arg1, arg2);
47  return reduction != binop_reduction_none ? op.reduce_operand_pair(reduction, arg1, arg2)
48  : nullptr;
49  });
50  }
51 }
52 
53 }
54 
// Out-of-line definition of the defaulted, non-throwing destructor.
55 BinaryOperation::~BinaryOperation() noexcept = default;
56 
// Default implementation: reports flags::none.
// NOTE(review): the return-type line of this definition is missing from this listing;
// presumably it names the `flags` enumeration of BinaryOperation — confirm against binary.hpp.
58 BinaryOperation::flags() const noexcept
59 {
60  return flags::none;
61 }
62 
63 std::optional<std::vector<rvsdg::Output *>>
65  const BinaryOperation & operation,
66  const std::vector<rvsdg::Output *> & operands)
67 {
68  JLM_ASSERT(!operands.empty());
69  auto region = operands[0]->region();
70 
71  if (!operation.is_associative())
72  {
73  return std::nullopt;
74  }
75 
76  auto newOperands = base::detail::associative_flatten(
77  operands,
78  [&operation](rvsdg::Output * operand)
79  {
80  auto node = TryGetOwnerNode<Node>(*operand);
81  if (node == nullptr)
82  return false;
83 
84  auto simpleNode = dynamic_cast<const SimpleNode *>(node);
85  if (simpleNode)
86  {
87  auto flattenedBinaryOperation =
88  dynamic_cast<const FlattenedBinaryOperation *>(&simpleNode->GetOperation());
89  return simpleNode->GetOperation() == operation
90  || (flattenedBinaryOperation
91  && flattenedBinaryOperation->bin_operation() == operation);
92  }
93  else
94  {
95  return false;
96  }
97  });
98 
99  if (operands == newOperands)
100  {
101  JLM_ASSERT(newOperands.size() == 2);
102  return std::nullopt;
103  }
104 
105  JLM_ASSERT(newOperands.size() > 2);
106  auto flattenedBinaryOperation =
107  std::make_unique<FlattenedBinaryOperation>(operation, newOperands.size());
108  return outputs(&SimpleNode::Create(*region, std::move(flattenedBinaryOperation), newOperands));
109 }
110 
111 std::optional<std::vector<rvsdg::Output *>>
113  const BinaryOperation & operation,
114  const std::vector<rvsdg::Output *> & operands)
115 {
116  JLM_ASSERT(!operands.empty());
117  auto region = operands[0]->region();
118 
119  auto newOperands = reduce_operands(operation, operands);
120 
121  if (newOperands.size() == 1)
122  {
123  // The operands could be reduced to a single value by applying constant folding.
124  return newOperands;
125  }
126 
127  if (newOperands == operands)
128  {
129  // The operands did not change, which means that none of the normalizations triggered.
130  return std::nullopt;
131  }
132 
133  JLM_ASSERT(newOperands.size() == 2);
134  return outputs(&SimpleNode::Create(*region, operation.copy(), newOperands));
135 }
136 
138 
139 bool
140 FlattenedBinaryOperation::operator==(const Operation & other) const noexcept
141 {
142  const auto op = dynamic_cast<const FlattenedBinaryOperation *>(&other);
143  return op && op->bin_operation() == bin_operation() && op->narguments() == narguments();
144 }
145 
146 std::string
148 {
149  return jlm::util::strfmt("FLATTENED[", op_->debug_string(), "]");
150 }
151 
152 std::unique_ptr<Operation>
154 {
155  std::unique_ptr<BinaryOperation> copied_op(static_cast<BinaryOperation *>(op_->copy().release()));
156  return std::make_unique<FlattenedBinaryOperation>(std::move(copied_op), narguments());
157 }
158 
159 /*
160  FIXME: The reduce_parallel and reduce_linear functions only differ in where they add
161  the new output to the working list. Unify both functions.
162 */
163 
164 static jlm::rvsdg::Output *
165 reduce_parallel(const BinaryOperation & op, const std::vector<jlm::rvsdg::Output *> & operands)
166 {
167  JLM_ASSERT(operands.size() > 1);
168  auto region = operands.front()->region();
169 
170  std::deque<jlm::rvsdg::Output *> worklist(operands.begin(), operands.end());
171  while (worklist.size() > 1)
172  {
173  auto op1 = worklist.front();
174  worklist.pop_front();
175  auto op2 = worklist.front();
176  worklist.pop_front();
177 
178  auto output = SimpleNode::Create(*region, op.copy(), { op1, op2 }).output(0);
179  worklist.push_back(output);
180  }
181 
182  JLM_ASSERT(worklist.size() == 1);
183  return worklist.front();
184 }
185 
186 static jlm::rvsdg::Output *
187 reduce_linear(const BinaryOperation & op, const std::vector<jlm::rvsdg::Output *> & operands)
188 {
189  JLM_ASSERT(operands.size() > 1);
190  auto region = operands.front()->region();
191 
192  std::deque<jlm::rvsdg::Output *> worklist(operands.begin(), operands.end());
193  while (worklist.size() > 1)
194  {
195  auto op1 = worklist.front();
196  worklist.pop_front();
197  auto op2 = worklist.front();
198  worklist.pop_front();
199 
200  auto output = SimpleNode::Create(*region, op.copy(), { op1, op2 }).output(0);
201  worklist.push_front(output);
202  }
203 
204  JLM_ASSERT(worklist.size() == 1);
205  return worklist.front();
206 }
207 
// Reduces `operands` with the strategy selected by `reduction` and returns the resulting output.
// NOTE(review): this listing lost several lines of the definition (the return type and
// declarator before this line, the key type of the map, and the declaration/initializer of
// `map`); presumably the map associates each reduction kind with reduce_linear or
// reduce_parallel — confirm against the original file.
211  const std::vector<jlm::rvsdg::Output *> & operands) const
212 {
213  JLM_ASSERT(operands.size() > 1);
214 
     // Function-local static table from reduction kind to the function implementing it.
215  static std::unordered_map<
217  std::function<
218  jlm::rvsdg::Output *(const BinaryOperation &, const std::vector<jlm::rvsdg::Output *> &)>>
220 
     // The lookup below uses operator[]; the assertion ensures the key is already present.
221  JLM_ASSERT(map.find(reduction) != map.end());
222  return map[reduction](bin_operation(), operands);
223 }
224 
225 void
227  rvsdg::Region * region,
229 {
230  for (auto & node : TopDownTraverser(region))
231  {
232  if (auto simpleNode = dynamic_cast<const SimpleNode *>(node))
233  {
234  auto op = dynamic_cast<const FlattenedBinaryOperation *>(&simpleNode->GetOperation());
235  if (op)
236  {
237  auto output = op->reduce(reduction, operands(node));
238  node->output(0)->divert_users(output);
239  remove(node);
240  }
241  }
242  else if (auto structnode = dynamic_cast<const StructuralNode *>(node))
243  {
244  for (size_t n = 0; n < structnode->nsubregions(); n++)
245  reduce(structnode->subregion(n), reduction);
246  }
247  }
248 
249  JLM_ASSERT(!Region::ContainsOperation<FlattenedBinaryOperation>(*region, true));
250 }
251 
252 std::optional<std::vector<rvsdg::Output *>>
254  const FlattenedBinaryOperation & operation,
255  const std::vector<rvsdg::Output *> & operands)
256 {
257  return NormalizeBinaryOperation(operation.bin_operation(), operands);
258 }
259 
260 }
~BinaryOperation() noexcept override
bool is_associative() const noexcept
Definition: binary.hpp:192
~FlattenedBinaryOperation() noexcept override
const BinaryOperation & bin_operation() const noexcept
Definition: binary.hpp:135
std::string debug_string() const override
Definition: binary.cpp:147
jlm::rvsdg::Output * reduce(const FlattenedBinaryOperation::reduction &reduction, const std::vector< jlm::rvsdg::Output * > &operands) const
Definition: binary.cpp:209
std::unique_ptr< BinaryOperation > op_
Definition: binary.hpp:155
std::unique_ptr< Operation > copy() const override
Definition: binary.cpp:153
virtual std::unique_ptr< Operation > copy() const =0
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
static SimpleNode & Create(Region &region, std::unique_ptr< Operation > operation, const std::vector< rvsdg::Output * > &operands)
Definition: simple-node.hpp:49
size_t narguments() const noexcept
Definition: operation.cpp:17
#define JLM_ASSERT(x)
Definition: common.hpp:16
Container pairwise_reduce(Container args, const Reductor &reductor)
std::vector< jlm::rvsdg::Output * > associative_flatten(std::vector< jlm::rvsdg::Output * > args, const FlattenTester &flatten_tester)
Container commutative_pairwise_reduce(Container args, const Reductor &reductor)
std::optional< std::vector< rvsdg::Output * > > NormalizeBinaryOperation(const BinaryOperation &operation, const std::vector< rvsdg::Output * > &operands)
Applies the reductions implemented in the binary operations reduction functions.
Definition: binary.cpp:112
static void remove(Node *node)
Definition: region.hpp:932
static jlm::rvsdg::Output * reduce_linear(const BinaryOperation &op, const std::vector< jlm::rvsdg::Output * > &operands)
Definition: binary.cpp:187
size_t binop_reduction_path_t
Definition: binary.hpp:19
static jlm::rvsdg::Output * reduce_parallel(const BinaryOperation &op, const std::vector< jlm::rvsdg::Output * > &operands)
Definition: binary.cpp:165
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
std::optional< std::vector< rvsdg::Output * > > NormalizeFlattenedBinaryOperation(const FlattenedBinaryOperation &operation, const std::vector< rvsdg::Output * > &operands)
Applies the reductions of the binary operation represented by the flattened binary operation.
Definition: binary.cpp:253
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058
std::optional< std::vector< rvsdg::Output * > > FlattenAssociativeBinaryOperation(const BinaryOperation &operation, const std::vector< rvsdg::Output * > &operands)
Flattens a cascade of the same binary operations into a single flattened binary operation.
Definition: binary.cpp:64
static const binop_reduction_path_t binop_reduction_none
Definition: binary.hpp:203
detail::TopDownTraverserGeneric< false > TopDownTraverser
Traverser for visiting every node in a region in a top down order.
Definition: traverser.hpp:308
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35