Jlm
instrument-ref.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
14 #include <jlm/rvsdg/gamma.hpp>
15 #include <jlm/rvsdg/traverser.hpp>
16 
17 #include <cmath>
18 
19 namespace jlm::hls
20 {
21 
/**
 * Replaces lambda `ln` by a copy that carries a different function name.
 *
 * A new lambda node is created in `ln`'s region from the type, linkage,
 * calling convention and attributes of `ln`'s llvm::LlvmLambdaOperation,
 * with `name` as its function name. Every context variable of `ln` is
 * re-bound to the same outer origin, `ln`'s body is copied into the new
 * lambda, the new lambda is finalized with the copied result origins, and
 * the users of `ln`'s outputs are diverted to the new lambda's outputs.
 *
 * @param ln   Lambda to rename. Its operation must be an
 *             llvm::LlvmLambdaOperation: the reference dynamic_cast below
 *             throws std::bad_cast otherwise.
 * @param name Function name of the new lambda.
 * @return The newly created lambda node.
 *
 * NOTE(review): not every statement between the user diversion and the
 * return is visible here; presumably `ln` is removed there (rvsdg::remove)
 * — confirm before touching `ln` after this call.
 */
22 rvsdg::LambdaNode *
23 change_function_name(rvsdg::LambdaNode * ln, const std::string & name)
24 {
25  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
26  auto lambda = rvsdg::LambdaNode::Create(
27  *ln->region(),
29  op.Type(),
30  name,
31  op.linkage(),
32  op.callingConvention(),
33  op.attributes()));
34 
35  /* add context variables */
36  rvsdg::SubstitutionMap subregionmap;
37  for (const auto & cv : ln->GetContextVars())
38  {
39  auto origin = cv.input->origin();
40  auto newcv = lambda->AddContextVar(*origin);
41  subregionmap.insert(cv.inner, newcv.inner);
42  }
43  /* collect function arguments */
    // Both lambdas share the same function type, so arguments pair up 1:1 by position.
44  auto args = ln->GetFunctionArguments();
45  auto newArgs = lambda->GetFunctionArguments();
46  JLM_ASSERT(args.size() == newArgs.size());
47  for (std::size_t n = 0; n < args.size(); ++n)
48  {
49  subregionmap.insert(args[n], newArgs[n]);
50  }
51 
52  /* copy subregion */
53  ln->subregion()->copy(lambda->subregion(), subregionmap);
54 
55  /* collect function results */
    // Look up the copy of each result origin of `ln` inside the new body.
56  std::vector<jlm::rvsdg::Output *> results;
57  for (auto result : ln->GetFunctionResults())
58  results.push_back(&subregionmap.lookup(*result->origin()));
59 
60  /* finalize lambda */
61  lambda->finalize(results);
62 
    // Redirect all consumers of the old lambda's outputs to the renamed lambda.
63  divert_users(ln, outputs(lambda));
65 
66  return lambda;
67 }
68 
/**
 * Entry point of the reference-instrumentation pass for module `rm`.
 *
 * Takes the first node in the root region's node list as the function to
 * instrument and replaces it with a copy named "instrumented_ref" (via
 * change_function_name). The rename happens unconditionally, before the
 * early returns below.
 *
 * If the function's last argument is a memory state, three function imports
 * ("reference_load", "reference_store", "reference_alloca") are added to the
 * graph. The region-level instrumentation routine is then called with:
 * the root region, the lambda's IO-state argument, and each import paired
 * with its function type.
 *
 * NOTE(review): assumes the root region is non-empty and that its first node
 * is a LambdaNode. The dynamic_cast result is not checked, so otherwise a
 * null pointer is handed to change_function_name.
 */
69 void
71 {
72  auto & graph = rm.Rvsdg();
73  auto root = &graph.GetRootRegion();
74  auto lambda = dynamic_cast<rvsdg::LambdaNode *>(root->Nodes().begin().ptr());
75 
76  auto newLambda = change_function_name(lambda, "instrumented_ref");
77 
78  auto functionType = newLambda->GetOperation().type();
79  auto numArguments = functionType.NumArguments();
80  if (numArguments == 0)
81  {
82  // The lambda has no arguments so it shouldn't have any memory operations
83  return;
84  }
85 
86  auto memStateArgumentIndex = numArguments - 1;
87  if (!rvsdg::is<llvm::MemoryStateType>(functionType.ArgumentType(memStateArgumentIndex)))
88  {
89  // The lambda has no memory state so it shouldn't have any memory operations
90  return;
91  }
92  // The function should always have an IO state if it has a memory state
    // NOTE(review): if NumArguments() is unsigned and the memory state is the
    // only argument, this subtraction wraps around; the JLM_ASSERT below is the
    // only guard against that case.
93  auto ioStateArgumentIndex = numArguments - 2;
94  JLM_ASSERT(rvsdg::is<llvm::IOStateType>(functionType.ArgumentType(ioStateArgumentIndex)));
95 
96  // Hook operands (as passed at the call sites): addr, width, iostate, memstate
97  auto loadFunctionType = jlm::rvsdg::FunctionType::Create(
103  auto & reference_load = llvm::LlvmGraphImport::createFunctionImport(
104  graph,
105  loadFunctionType,
106  "reference_load",
    // The store hook reuses the load signature (loadFunctionType).
109  auto & reference_store = llvm::LlvmGraphImport::createFunctionImport(
110  graph,
111  loadFunctionType,
112  "reference_store",
115  // Hook operands (as passed at the call site): addr, size, iostate, memstate
116  auto allocaFunctionType = jlm::rvsdg::FunctionType::Create(
122  auto & reference_alloca = llvm::LlvmGraphImport::createFunctionImport(
123  graph,
124  allocaFunctionType,
125  "reference_alloca",
128 
    // Traversal starts at the root region; the lambda is reached as a structural
    // node. NOTE(review): indexing subregion()->argument() with a function
    // argument index assumes the subregion's leading arguments are the function
    // arguments in declaration order — TODO confirm.
130  root,
131  newLambda->subregion()->argument(ioStateArgumentIndex),
132  &reference_load,
133  loadFunctionType,
134  &reference_store,
135  loadFunctionType,
136  &reference_alloca,
137  allocaFunctionType);
138 }
139 
/**
 * Instruments the memory operations of `region` and, recursively, of every
 * subregion of its structural nodes.
 *
 * Nodes are visited top-down, and each memory operation gets a call to the
 * matching hook:
 *  - non-volatile load: a call to `load_func` is inserted ahead of the load
 *    in memory-state order. The call consumes the load's incoming memory
 *    state, and the load's memory-state input is re-pointed at the call.
 *  - alloca: a call to `alloca_func` is inserted after it. The former users
 *    of the alloca's memory-state output now consume the call's output.
 *  - non-volatile store: a call to `store_func` is inserted after it, in the
 *    same way as for allocas.
 *
 * Each call receives these operands:
 *  - the address, bitcast to `void_ptr` if its type differs;
 *  - an integer constant: log2 of the access size in bytes for loads and
 *    stores, `sizeNode` for allocas;
 *  - `ioState`;
 *  - the memory state.
 *
 * @param region             Region to instrument.
 * @param ioState            IO-state output passed to every inserted call;
 *                           routed into each subregion with RouteToRegion.
 * @param load_func          Load hook function; routed into `region` on entry.
 * @param loadFunctionType   Function type used for the load calls.
 * @param store_func         Store hook function; routed into `region` on entry.
 * @param storeFunctionType  Function type used for the store calls.
 * @param alloca_func        Alloca hook function; routed into `region` on entry.
 * @param allocaFunctionType Function type used for the alloca calls.
 *
 * NOTE(review): result index 1 of every call is treated as its memory-state
 * output. No other call result is consumed, and every call receives the same
 * `ioState`, so the inserted calls are not chained through IO state.
 * Confirm this is intended.
 */
140 void
142  rvsdg::Region * region,
143  jlm::rvsdg::Output * ioState,
144  jlm::rvsdg::Output * load_func,
145  const std::shared_ptr<const jlm::rvsdg::FunctionType> & loadFunctionType,
146  jlm::rvsdg::Output * store_func,
147  const std::shared_ptr<const jlm::rvsdg::FunctionType> & storeFunctionType,
148  jlm::rvsdg::Output * alloca_func,
149  const std::shared_ptr<const jlm::rvsdg::FunctionType> & allocaFunctionType)
150 {
    // Make the hook functions usable inside `region`.
151  load_func = &rvsdg::RouteToRegion(*load_func, *region);
152  store_func = &rvsdg::RouteToRegion(*store_func, *region);
153  alloca_func = &rvsdg::RouteToRegion(*alloca_func, *region);
154  auto void_ptr = jlm::llvm::PointerType::Create();
155  for (auto & node : rvsdg::TopDownTraverser(region))
156  {
157  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
158  {
159  for (size_t n = 0; n < structnode->nsubregions(); n++)
160  {
161  auto subregion = structnode->subregion(n);
    // NOTE(review): this presumably requires ioState to live in `subregion` or in
    // an enclosing region. At the top level ioState belongs to the instrumented
    // lambda's subregion, so any other structural node in the root region would
    // likely fail here — confirm.
162  auto ioStateRouted = &rvsdg::RouteToRegion(*ioState, *subregion);
164  subregion,
165  ioStateRouted,
166  load_func,
167  loadFunctionType,
168  store_func,
169  storeFunctionType,
170  alloca_func,
171  allocaFunctionType);
172  }
173  }
174  else if (
175  auto loadOp =
176  dynamic_cast<const jlm::llvm::LoadNonVolatileOperation *>(&(node->GetOperation())))
177  {
    // Operand 0 of the load is its address, operand 1 its memory state.
    // NOTE(review): only input(1) is rerouted; a load with several memory-state
    // inputs would have only the first one instrumented — confirm that cannot occur.
178  auto addr = node->input(0)->origin();
179  JLM_ASSERT(rvsdg::is<jlm::llvm::PointerType>(addr->Type()));
180  size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType());
    // NOTE(review): bitWidth / 8 is integer division. For widths below 8 bits it
    // yields 0, log2(0) returns -infinity, and converting that to int is undefined
    // behavior. Non-power-of-two byte counts are silently truncated. An integer
    // log2 with an explicit check would be safer.
181  int log2Bytes = log2(bitWidth / 8);
182  auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes);
183 
184  // Does this IF make sense now when the void_ptr doesn't have a type?
    // NOTE(review): addr is asserted to be a PointerType above and void_ptr comes
    // from PointerType::Create() with no arguments, so this comparison is
    // presumably always false and the bitcast is never emitted — confirm.
185  if (*addr->Type() != *void_ptr)
186  {
187  addr = llvm::BitCastOperation::create(addr, void_ptr);
188  }
189  auto memstate = node->input(1)->origin();
190  auto callOp = jlm::llvm::CallOperation::Create(
191  load_func,
192  loadFunctionType,
193  { addr, widthNode.output(0), ioState, memstate });
194  // Divert the memory-state input of the load to the new memstate from the call operation (the call is thereby ordered before the load)
195  node->input(1)->divert_to(callOp[1]);
196  }
197  else if (auto ao = dynamic_cast<const jlm::llvm::AllocaOperation *>(&(node->GetOperation())))
198  {
199  // ensure that the size is one
    // NOTE(review): if JLM_ASSERT is compiled out, a size operand that is not a
    // node output makes constant_output null, and the node() call below would
    // dereference it.
200  JLM_ASSERT(node->ninputs() == 1);
201  auto constant_output = dynamic_cast<rvsdg::NodeOutput *>(node->input(0)->origin());
202  JLM_ASSERT(constant_output);
203  auto constant_operation = dynamic_cast<const llvm::IntegerConstantOperation *>(
204  &constant_output->node()->GetOperation());
205  JLM_ASSERT(constant_operation);
206  JLM_ASSERT(constant_operation->Representation().to_uint() == 1);
207  jlm::rvsdg::Output * addr = node->output(0);
208  // ensure that the alloca is an array type
209  JLM_ASSERT(jlm::rvsdg::is<llvm::PointerType>(addr->Type()));
210  auto at = dynamic_cast<const llvm::ArrayType *>(ao->allocatedType().get());
211  JLM_ASSERT(at);
212  auto & sizeNode =
214 
215  // Does this IF make sense now when the void_ptr doesn't have a type?
216  if (*addr->Type() != *void_ptr)
217  {
218  addr = llvm::BitCastOperation::create(addr, void_ptr);
219  }
    // Snapshot the current users first: the call created below also consumes
    // output(1), and it must not be diverted onto its own output.
220  std::vector<jlm::rvsdg::Input *> old_users;
221  for (auto & user : node->output(1)->Users())
222  old_users.push_back(&user);
223  auto memstate = node->output(1);
224  auto callOp = jlm::llvm::CallOperation::Create(
225  alloca_func,
226  allocaFunctionType,
227  { addr, sizeNode.output(0), ioState, memstate });
228  for (auto ou : old_users)
229  {
230  // Divert the former users of the alloca's memory state to the new memstate from the call operation
231  ou->divert_to(callOp[1]);
232  }
233  }
234  else if (
235  auto so =
236  dynamic_cast<const jlm::llvm::StoreNonVolatileOperation *>(&(node->GetOperation())))
237  {
238  auto addr = node->input(0)->origin();
239  JLM_ASSERT(rvsdg::is<jlm::llvm::PointerType>(addr->Type()));
    // NOTE(review): the load branch calls BaseHLS::JlmSize, while this branch calls
    // an unqualified JlmSize. Confirm both resolve to the same size computation.
240  auto bitWidth = JlmSize(&so->GetStoredType());
    // NOTE(review): same hazard as in the load branch — for widths below 8 bits,
    // log2(0) is -infinity, and its conversion to int is undefined behavior.
241  int log2Bytes = log2(bitWidth / 8);
242  auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes);
243 
244  // Does this IF make sense now when the void_ptr doesn't have a type?
245  if (*addr->Type() != *void_ptr)
246  {
247  addr = llvm::BitCastOperation::create(addr, void_ptr);
248  }
    // Snapshot the store's memory-state users before the call below becomes one of them.
249  auto memstate = node->output(0);
250  std::vector<jlm::rvsdg::Input *> oldUsers;
251  for (auto & user : memstate->Users())
252  oldUsers.push_back(&user);
253  auto callOp = jlm::llvm::CallOperation::Create(
254  store_func,
255  storeFunctionType,
256  { addr, widthNode.output(0), ioState, memstate });
257  // Divert the memory state after the store to the new memstate from the call operation
258  for (auto user : oldUsers)
259  {
260  user->divert_to(callOp[1]);
261  }
262  }
263  }
264 }
265 
266 } // namespace jlm::hls
static int JlmSize(const jlm::rvsdg::Type *type)
Definition: base-hls.cpp:110
static std::unique_ptr< llvm::ThreeAddressCode > create(const Variable *operand, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1586
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:464
static std::shared_ptr< const IOStateType > Create()
Definition: types.cpp:343
static rvsdg::Node & Create(rvsdg::Region &region, IntegerValueRepresentation representation)
static LlvmGraphImport & createFunctionImport(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::FunctionType > functionType, std::string name, Linkage linkage, CallingConvention callingConvention)
Lambda operation.
Definition: lambda.hpp:30
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
static std::shared_ptr< const MemoryStateType > Create()
Definition: types.cpp:379
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
#define JLM_ASSERT(x)
Definition: common.hpp:16
int JlmSize(const jlm::rvsdg::Type *type)
Definition: hls.cpp:344
static void divert_users(jlm::rvsdg::Output *output, Context &ctx)
Definition: cne.cpp:504
void instrument_ref(llvm::LlvmRvsdgModule &rm)
rvsdg::LambdaNode * change_function_name(rvsdg::LambdaNode *ln, const std::string &name)
static void remove(Node *node)
Definition: region.hpp:978
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058