Jlm
stream-conv.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/hls/ir/hls.hpp>
10 #include <jlm/rvsdg/lambda.hpp>
11 
12 #include <algorithm>
13 
14 namespace jlm::hls
15 {
16 
// Buffer capacity that ConnectStreamBuffer uses when the dequeue call does not
// carry a bit-typed value on input 2.
17 const int DefaultBufferCapacity = 10;
18 
// Replaces one enqueue/dequeue call pair with a buffer.
//
// NOTE(review): the signature line is elided in this listing; the cross-reference
// index gives it as
//   void ConnectStreamBuffer(rvsdg::SimpleNode *enq_call, rvsdg::SimpleNode *deq_call).
//
// A BufferOperation is created on the origin of enq_call's input 2, its first
// result is routed into deq_call's region, and the users of deq_call's output 0
// are diverted to that routed value. For both calls, every state-typed input is
// forwarded to the users of the paired output before the call node is removed.
19 void
21 {
22  int buffer_capacity = DefaultBufferCapacity;
23  // buffer size as second argument
    // Input 2 of the dequeue call is only treated as a capacity if it is bit-typed.
24  if (rvsdg::is<const rvsdg::BitType>(deq_call->input(2)->Type()))
25  {
    // NOTE(review): the result of trace_constant is dereferenced without a check;
    // confirm it cannot be null on this path.
26  auto constant = trace_constant(deq_call->input(2)->origin());
27  buffer_capacity = constant->Representation().to_int();
28  JLM_ASSERT(buffer_capacity >= 0);
29  }
    // Only the first result of create() is used; the third argument (pass_through) is false.
    // buffer_capacity is an int that is passed on as the size_t capacity parameter.
30  auto buf = BufferOperation::create(*enq_call->input(2)->origin(), buffer_capacity, false)[0];
31  auto routed = route_to_region_rhls(deq_call->region(), buf);
32  // remove call nodes
33  for (size_t i = 0; i < deq_call->ninputs(); ++i)
34  {
35  if (deq_call->input(i)->Type()->Kind() == rvsdg::TypeKind::State)
36  {
    // Output paired with input i: the one at the same distance from the end of the
    // output list. The expression is evaluated in size_t and then narrowed to int;
    // if the node has fewer outputs than inputs the subtraction wraps before `+ i`
    // brings the value back into range.
37  int oi = deq_call->noutputs() - deq_call->ninputs() + i;
38  deq_call->output(oi)->divert_users(deq_call->input(i)->origin());
39  }
40  }
    // The dequeued value (output 0) is now supplied by the routed buffer output.
41  deq_call->output(0)->divert_users(routed);
42  remove(deq_call);
    // Same state forwarding for the enqueue call; none of its outputs is redirected to the buffer.
43  for (size_t i = 0; i < enq_call->ninputs(); ++i)
44  {
45  if (enq_call->input(i)->Type()->Kind() == rvsdg::TypeKind::State)
46  {
47  int oi = enq_call->noutputs() - enq_call->ninputs() + i;
48  enq_call->output(oi)->divert_users(enq_call->input(i)->origin());
49  }
50  }
51  remove(enq_call);
52 }
53 
// Pairs up stream enqueue and dequeue calls inside the single lambda of the root
// region and replaces each matched pair via ConnectStreamBuffer.
//
// NOTE(review): the signature line is elided in this listing; the cross-reference
// index gives it as
//   static void stream_conv(rvsdg::RvsdgModule &rm).
//
// Throws std::logic_error if the root region does not hold exactly one node, or
// if that node is not an rvsdg::LambdaNode.
54 static void
56 {
57  const auto & graph = rm.Rvsdg();
58  const auto rootRegion = &graph.GetRootRegion();
59  if (rootRegion->numNodes() != 1)
60  {
61  throw std::logic_error("Root should have only one node now");
62  }
63 
    // The only node in the root region must be the lambda to transform.
64  const auto lambda = dynamic_cast<rvsdg::LambdaNode *>(rootRegion->Nodes().begin().ptr());
65  if (!lambda)
66  {
67  throw std::logic_error("Node needs to be a lambda");
68  }
69 
    // Context variables of the lambda selected by the given name fragments.
70  auto stream_enqs = find_function_arguments(lambda, "hls_stream_enq");
71  auto stream_deqs = find_function_arguments(lambda, "hls_stream_deq");
    // Nothing to convert without enqueue functions; dequeue functions are then
    // expected to be absent as well.
72  if (stream_enqs.empty())
73  {
74  JLM_ASSERT(stream_deqs.empty());
75  return;
76  }
77  std::vector<rvsdg::SimpleNode *> enq_calls, deq_calls;
78  std::unordered_set<rvsdg::Output *> visited;
    // Collect the call nodes reachable from each context variable. The visited set
    // is cleared after every trace so each trace starts from an empty set.
79  for (auto stream_enq : stream_enqs)
80  {
81  JLM_ASSERT(stream_enq.inner);
82  trace_function_calls(stream_enq.inner, enq_calls, visited);
83  visited.clear();
84  }
85  for (auto stream_deq : stream_deqs)
86  {
87  trace_function_calls(stream_deq.inner, deq_calls, visited);
88  visited.clear();
89  }
90  JLM_ASSERT(!enq_calls.empty());
91  JLM_ASSERT(!deq_calls.empty());
    // Match each enqueue call with the first remaining dequeue call whose constant
    // traced from input 1 compares equal.
    // NOTE(review): an enqueue call without a match is left untouched here, and
    // leftover dequeue calls are not reported; presumably the nusers() assertion
    // further down is what flags such cases. Confirm.
92  for (auto & enq_call : enq_calls)
93  {
94  auto enq_constant = trace_constant(enq_call->input(1)->origin());
95  for (auto & deq_call : deq_calls)
96  {
97  auto deq_constant = trace_constant(deq_call->input(1)->origin());
98  if (*enq_constant == *deq_constant)
99  {
100  ConnectStreamBuffer(enq_call, deq_call);
    // ConnectStreamBuffer has already removed both nodes; the deq_call pointer is
    // only compared by value here. The loop is left right after the erase, so the
    // modified vector is not iterated any further.
101  deq_calls.erase(std::find(deq_calls.begin(), deq_calls.end(), deq_call));
102  break;
103  }
104  }
105  }
106  // clean up routed function pointers
    // NOTE(review): the declaration of `dne` is elided in this listing.
108  util::StatisticsCollector statisticsCollector;
109  dne.Run(*lambda->subregion(), statisticsCollector);
110 
    // All enqueue and dequeue context variables, gathered in one list for the check below.
111  std::vector<rvsdg::LambdaNode::ContextVar> remove_vars(stream_enqs);
112  remove_vars.insert(remove_vars.cend(), stream_deqs.begin(), stream_deqs.end());
113  // make sure context vars are actually dead
114  for (auto cv : remove_vars)
115  {
116  JLM_ASSERT(cv.inner->nusers() == 0);
117  }
118  // remove dead cvargs
119  lambda->PruneLambdaInputs();
120 }
121 
// Defaulted out-of-line destructor.
122 StreamConversion::~StreamConversion() noexcept = default;
123 
// NOTE(review): the header lines of the definition below are elided in this
// listing; presumably this is the StreamConversion constructor with an empty
// body. Confirm against the original file.
126 {}
127 
// Forwards the module to stream_conv.
// NOTE(review): the signature line is elided in this listing; the cross-reference
// index lists a Run(rvsdg::RvsdgModule &, util::StatisticsCollector &) override
// next to ~StreamConversion, which is presumably the function defined here.
128 void
130 {
131  stream_conv(rvsdgModule);
132 }
133 
134 }
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &value, size_t capacity, bool pass_through=false)
Definition: hls.hpp:438
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: rhls-dne.cpp:516
~StreamConversion() noexcept override
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:67
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
size_t ninputs() const noexcept
Definition: node.hpp:609
size_t noutputs() const noexcept
Definition: node.hpp:644
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
void trace_function_calls(rvsdg::Output *output, std::vector< rvsdg::SimpleNode * > &calls, std::unordered_set< rvsdg::Output * > &visited)
void ConnectStreamBuffer(rvsdg::SimpleNode *enq_call, rvsdg::SimpleNode *deq_call)
Definition: stream-conv.cpp:20
rvsdg::Output * route_to_region_rhls(rvsdg::Region *target, rvsdg::Output *out)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
static void stream_conv(rvsdg::RvsdgModule &rm)
Definition: stream-conv.cpp:55
const int DefaultBufferCapacity
Definition: stream-conv.cpp:17
std::vector< rvsdg::LambdaNode::ContextVar > find_function_arguments(const rvsdg::LambdaNode *lambda, std::string name_contains)
static void remove(Node *node)
Definition: region.hpp:932
@ State
Designate a state type.