Jlm
merge-gamma.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
8 #include <jlm/hls/ir/hls.hpp>
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/region.hpp>
13 #include <jlm/rvsdg/theta.hpp>
14 #include <jlm/rvsdg/traverser.hpp>
15 
16 namespace jlm::hls
17 {
18 
// is_output_of: reports whether `output` is an output of `node`.
// The signature line (listing line 20) is absent from this rendering; the symbol
// index at the end of the page gives it as
//   bool is_output_of(jlm::rvsdg::Output *output, rvsdg::Node *node)
19 bool
21 {
// If `output` is not a rvsdg::NodeOutput the cast yields null and the result is false.
22  auto no = dynamic_cast<rvsdg::NodeOutput *>(output);
23  return no && no->node() == node;
24 }
25 
// depends_on: reports whether `output` is produced by `node`, directly or through a
// chain of node inputs, without looking past region arguments.
// The signature line (listing line 27) is absent from this rendering; the symbol
// index at the end of the page gives it as
//   bool depends_on(jlm::rvsdg::Output *output, rvsdg::Node *node)
26 bool
28 {
// A region argument ends the search: it is never reported as depending on `node`.
29  auto arg = dynamic_cast<rvsdg::RegionArgument *>(output);
30  if (arg)
31  {
32  return false;
33  }
// Anything that is not a region argument is expected to be a node output.
34  auto no = dynamic_cast<rvsdg::NodeOutput *>(output);
35  JLM_ASSERT(no);
36  if (no->node() == node)
37  {
38  return true;
39  }
// Otherwise recurse into the origin of every input of the producing node.
// No visited set is kept, so a producer reachable along several paths is
// examined once per path.
40  for (size_t i = 0; i < no->node()->ninputs(); ++i)
41  {
42  if (depends_on(no->node()->input(i)->origin(), node))
43  {
44  return true;
45  }
46  }
47  return false;
48 }
49 
// get_entryvar: returns an entry variable of `gamma` that is fed by `origin`, creating
// one if none exists yet.
// The return-type and signature lines (listing lines 50-51) are absent from this
// rendering; the symbol index at the end of the page gives them as
//   rvsdg::GammaNode::EntryVar get_entryvar(jlm::rvsdg::Output *origin, rvsdg::GammaNode *gamma)
52 {
// Look for an existing input of `gamma` among the users of `origin`.
53  for (auto & user : origin->Users())
54  {
55  if (rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user) == gamma)
56  {
// MapInput yields either a MatchVar or an EntryVar; only an EntryVar is reused.
// A user that is the match variable is skipped and the search continues.
57  auto rolevar = gamma->MapInput(user);
58  if (auto entryvar = std::get_if<rvsdg::GammaNode::EntryVar>(&rolevar))
59  {
60  return *entryvar;
61  }
62  }
63  }
// No reusable entry variable found: route `origin` into the gamma now.
64  return gamma->AddEntryVar(origin);
65 }
66 
// merge_gamma(gamma): tries to fold `gamma` into another gamma node that is driven by
// the same predicate origin. On success the contents of `gamma` are copied into that
// other gamma, the users of `gamma`'s exit variables are redirected to new exit
// variables of the other gamma, `gamma` is removed and true is returned. Returns false
// if no suitable partner was found.
// The signature line (listing line 68) is absent from this rendering; the symbol
// index at the end of the page gives it as
//   bool merge_gamma(rvsdg::GammaNode *gamma)
67 bool
69 {
// Candidates are the other users of the predicate's origin that are owned by a gamma node.
70  for (auto & user : gamma->predicate()->origin()->Users())
71  {
72  auto other_gamma = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user);
73  if (other_gamma && gamma != other_gamma)
74  {
75  // other gamma depending on same predicate
76  JLM_ASSERT(other_gamma->nsubregions() == gamma->nsubregions());
// can_merge is accumulated over all entry variables of both gammas; a single
// violating entry variable clears it.
77  bool can_merge = true;
78  for (const auto & ev : gamma->GetEntryVars())
79  {
80  // we only merge gammas whose inputs directly, or not at all, depend on the gamma being
81  // merged into
82  can_merge &= is_output_of(ev.input->origin(), other_gamma)
83  || !depends_on(ev.input->origin(), other_gamma);
84  }
85  for (const auto & oev : other_gamma->GetEntryVars())
86  {
87  // prevent cycles
88  can_merge &= !depends_on(oev.input->origin(), gamma);
89  }
90  if (can_merge)
91  {
// One substitution map per subregion: it maps each branch argument of `gamma` to the
// value that stands in for it inside the matching subregion of other_gamma.
92  std::vector<rvsdg::SubstitutionMap> rmap(gamma->nsubregions());
93  // populate argument mappings
94  for (const auto & ev : gamma->GetEntryVars())
95  {
96  if (rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*ev.input->origin()) == other_gamma)
97  {
// The entry variable is fed by an output of other_gamma: inside branch j use the
// origin of the corresponding exit-variable result of other_gamma.
98  auto oex = other_gamma->MapOutputExitVar(*ev.input->origin());
99  for (size_t j = 0; j < gamma->nsubregions(); ++j)
100  {
101  rmap[j].insert(ev.branchArgument[j], oex.branchResult[j]->origin());
102  }
103  }
104  else
105  {
// The entry variable comes from elsewhere: route the same origin into other_gamma
// (reusing an existing entry variable when there is one) and map argument to argument.
106  auto oev = get_entryvar(ev.input->origin(), other_gamma);
107  for (size_t j = 0; j < gamma->nsubregions(); ++j)
108  {
109  rmap[j].insert(ev.branchArgument[j], oev.branchArgument[j]);
110  }
111  }
112  }
113  // copy subregions
// Region::copy(target, smap): subregion j of `gamma` is copied into subregion j of
// other_gamma. The lookups below rely on rmap[j] containing entries for the copied
// outputs afterwards - presumably copy() records them; confirm against Region::copy.
114  for (size_t j = 0; j < gamma->nsubregions(); ++j)
115  {
116  gamma->subregion(j)->copy(other_gamma->subregion(j), rmap[j]);
117  }
118  // handle exitvars
119  for (const auto & ex : gamma->GetExitVars())
120  {
// Gather, per branch, the substitute for the origin of this exit variable's result,
// expose it as a new exit variable of other_gamma and move the users over.
121  std::vector<jlm::rvsdg::Output *> operands;
122  for (size_t j = 0; j < ex.branchResult.size(); j++)
123  {
124  operands.push_back(&rmap[j].lookup(*ex.branchResult[j]->origin()));
125  }
126  auto oex = other_gamma->AddExitVar(operands).output;
127  ex.output->divert_users(oex);
128  }
// Only the first mergeable partner is used; the function returns right after removing `gamma`.
129  remove(gamma);
130  return true;
131  }
132  }
133  }
134  return false;
135 }
136 
// Region-level merge_gamma: repeatedly walks `region` top-down and merges gamma nodes
// until a full pass performs no merge.
// The signature line (listing line 138) is absent from this rendering and is not listed
// in the symbol index; judging from the call sites it takes a pointer to a region
// (named `region` in the body) - exact parameter type to be confirmed.
137 void
139 {
140  bool changed = true;
141  while (changed)
142  {
143  changed = false;
144  for (auto & node : rvsdg::TopDownTraverser(region))
145  {
146  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
147  {
// The brace-less for below covers only the recursive call: every subregion of a
// structural node is processed first, then the node itself is checked.
148  for (size_t n = 0; n < structnode->nsubregions(); n++)
149  merge_gamma(structnode->subregion(n));
150  if (auto gamma = dynamic_cast<rvsdg::GammaNode *>(node))
151  {
152  if (merge_gamma(gamma))
153  {
// A successful merge removed `gamma`; the current traversal is abandoned and a new
// pass is started by the enclosing while loop.
154  changed = true;
155  break;
156  }
157  }
158  }
159  }
160  }
161 }
162 
// Out-of-line defaulted destructor of GammaMerge.
163 GammaMerge::~GammaMerge() noexcept = default;
164 
// Empty function body whose header lines (listing lines 165-166) are absent from this
// rendering - presumably the GammaMerge constructor; confirm against the original file.
167 {}
168 
// GammaMerge::Run: applies the region-level merge_gamma to the root region of the
// module's RVSDG. The statistics collector parameter is not used in the visible body.
// The signature line (listing line 170) is absent from this rendering; the symbol
// index at the end of the page gives it as
//   void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
169 void
171 {
172  merge_gamma(&rvsdgModule.Rvsdg().GetRootRegion());
173 }
174 
175 } // namespace jlm::hls
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
~GammaMerge() noexcept override
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::variant< MatchVar, EntryVar > MapInput(const rvsdg::Input &input) const
Maps gamma input to its role (match variable or entry variable).
Definition: gamma.cpp:314
EntryVar AddEntryVar(rvsdg::Output *origin)
Routes a variable into the gamma branches.
Definition: gamma.cpp:260
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:361
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:303
rvsdg::Input * predicate() const noexcept
Definition: gamma.hpp:398
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Node * node() const noexcept
Definition: node.hpp:572
UsersRange Users()
Definition: node.hpp:354
Represents the argument of a region.
Definition: region.hpp:41
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
bool merge_gamma(rvsdg::GammaNode *gamma)
Definition: merge-gamma.cpp:68
bool is_output_of(jlm::rvsdg::Output *output, rvsdg::Node *node)
Definition: merge-gamma.cpp:20
rvsdg::GammaNode::EntryVar get_entryvar(jlm::rvsdg::Output *origin, rvsdg::GammaNode *gamma)
Definition: merge-gamma.cpp:51
bool depends_on(jlm::rvsdg::Output *output, rvsdg::Node *node)
Definition: merge-gamma.cpp:27
static void remove(Node *node)
Definition: region.hpp:932
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:682
A variable routed into all gamma regions.
Definition: gamma.hpp:131