Jlm
instrument-ref.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
11 #include <jlm/rvsdg/gamma.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 
14 #include <cmath>
15 
16 namespace jlm::hls
17 {
18 
/**
 * Re-creates lambda node \p ln under a new function name.
 *
 * A new lambda is created in the same region as \p ln, using the function
 * type, linkage and attributes of ln's LlvmLambdaOperation together with the
 * given \p name. Every context variable of \p ln is re-bound to the same outer
 * origin, ln's function arguments are mapped positionally onto the new
 * lambda's arguments, ln's subregion is copied into the new lambda, and the new
 * lambda is finalized with the copies of ln's result origins.
 *
 * @param ln   Lambda to copy. Its operation must be an llvm::LlvmLambdaOperation;
 *             the dynamic_cast to a reference throws std::bad_cast otherwise.
 * @param name Name given to the new function.
 * @return     The newly created lambda node.
 */
19 rvsdg::LambdaNode *
20 change_function_name(rvsdg::LambdaNode * ln, const std::string & name)
21 {
22  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
23  auto lambda = rvsdg::LambdaNode::Create(
24  *ln->region(),
25  llvm::LlvmLambdaOperation::Create(op.Type(), name, op.linkage(), op.attributes()));
26 
27  /* add context variables: bind each to the same outer origin and map old inner -> new inner */
28  rvsdg::SubstitutionMap subregionmap;
29  for (const auto & cv : ln->GetContextVars())
30  {
31  auto origin = cv.input->origin();
32  auto newcv = lambda->AddContextVar(*origin);
33  subregionmap.insert(cv.inner, newcv.inner);
34  }
35  /* map function arguments positionally; both lambdas were created with the same function type */
36  auto args = ln->GetFunctionArguments();
37  auto newArgs = lambda->GetFunctionArguments();
38  JLM_ASSERT(args.size() == newArgs.size());
39  for (std::size_t n = 0; n < args.size(); ++n)
40  {
41  subregionmap.insert(args[n], newArgs[n]);
42  }
43 
44  /* copy subregion; presumably records the copied outputs in subregionmap, which the result lookup below relies on */
45  ln->subregion()->copy(lambda->subregion(), subregionmap);
46 
47  /* collect function results: the copied counterparts of ln's result origins */
48  std::vector<jlm::rvsdg::Output *> results;
49  for (auto result : ln->GetFunctionResults())
50  results.push_back(&subregionmap.lookup(*result->origin()));
51 
52  /* finalize lambda */
53  lambda->finalize(results);
54 
    // presumably redirects every user of ln's outputs to the corresponding output of the new lambda
55  divert_users(ln, outputs(lambda));
57 
58  return lambda;
59 }
60 
/**
 * Instruments the memory operations of the module's function with calls to
 * reference functions.
 *
 * Steps:
 *  - Take the first node of the root region of \p rm as a lambda.
 *  - Re-create that lambda under the name "instrumented_ref".
 *  - Return early if the function has no arguments or its last argument is
 *    not a memory state.
 *  - Otherwise, import "reference_load", "reference_store" and
 *    "reference_alloca" into the graph and pass them, together with the IO-state
 *    argument of the new lambda, to the region-level instrument_ref overload.
 *
 * NOTE(review): if the first node of the root region is not a LambdaNode, the
 * dynamic_cast yields nullptr, and change_function_name then dereferences it.
 *
 * @param rm Module to instrument; its graph is modified in place.
 */
61 void
63 {
64  auto & graph = rm.Rvsdg();
65  auto root = &graph.GetRootRegion();
66  auto lambda = dynamic_cast<rvsdg::LambdaNode *>(root->Nodes().begin().ptr());
67 
68  auto newLambda = change_function_name(lambda, "instrumented_ref");
69 
70  auto functionType = newLambda->GetOperation().type();
71  auto numArguments = functionType.NumArguments();
72  if (numArguments == 0)
73  {
74  // The lambda has no arguments so it shouldn't have any memory operations
75  return;
76  }
77 
    // The memory state is expected to be the last argument and the IO state the one before it.
78  auto memStateArgumentIndex = numArguments - 1;
79  if (!rvsdg::is<llvm::MemoryStateType>(functionType.ArgumentType(memStateArgumentIndex)))
80  {
81  // The lambda has no memory state so it shouldn't have any memory operations
82  return;
83  }
84  // The function should always have an IO state if it has a memory state
    // NOTE(review): if the memory state is the only argument (numArguments == 1), the index
    // below (numArguments - 2) does not name a valid argument; only the JLM_ASSERT guards this.
85  auto ioStateArgumentIndex = numArguments - 2;
86  JLM_ASSERT(rvsdg::is<llvm::IOStateType>(functionType.ArgumentType(ioStateArgumentIndex)));
87 
88  // addr, width, memstate -- NOTE(review): the call sites pass (addr, width, ioState, memstate); confirm this comment matches the argument list
89  auto loadFunctionType = jlm::rvsdg::FunctionType::Create(
95  auto & reference_load = llvm::LlvmGraphImport::Create(
96  graph,
97  loadFunctionType,
98  loadFunctionType,
99  "reference_load",
    // reference_store is imported with the same function type as reference_load.
101  auto & reference_store = llvm::LlvmGraphImport::Create(
102  graph,
103  loadFunctionType,
104  loadFunctionType,
105  "reference_store",
107  // addr, size, memstate -- NOTE(review): the alloca call site passes (addr, size, ioState, memstate); confirm this comment matches the argument list
108  auto allocaFunctionType = jlm::rvsdg::FunctionType::Create(
114  auto & reference_alloca = llvm::LlvmGraphImport::Create(
115  graph,
116  allocaFunctionType,
117  allocaFunctionType,
118  "reference_alloca",
120 
    // Instrumentation starts at the root region; the store function is passed with loadFunctionType.
    // NOTE(review): the IO state passed here is an argument of newLambda's subregion, so it can only
    // be routed into regions nested inside newLambda; confirm no other structural node in the root
    // region contains memory operations.
122  root,
123  newLambda->subregion()->argument(ioStateArgumentIndex),
124  &reference_load,
125  loadFunctionType,
126  &reference_store,
127  loadFunctionType,
128  &reference_alloca,
129  allocaFunctionType);
130 }
131 
/**
 * Recursively instruments the memory operations in \p region.
 *
 * The three function outputs are first routed into \p region. The nodes of
 * \p region are then visited top-down:
 *  - structural nodes: every subregion is instrumented recursively, with
 *    \p ioState routed into that subregion;
 *  - LoadNonVolatileOperation: a call to \p load_func with
 *    (address, log2 of the loaded size in bytes, ioState, memory state) is placed
 *    on the load's incoming memory state, i.e. ordered before the load;
 *  - AllocaOperation (asserted to have a single constant count of 1 and an array
 *    value type): a call to \p alloca_func with (address, size value, ioState,
 *    memory state) is placed on the alloca's outgoing memory state;
 *  - StoreNonVolatileOperation: a call to \p store_func with
 *    (address, log2 of the stored size in bytes, ioState, memory state) is placed
 *    on the store's outgoing memory state, i.e. ordered after the store.
 *
 * Addresses whose type differs from the pointer type returned by
 * PointerType::Create() are bit-cast to it.
 * Result index 1 of each call is used as its memory-state output.
 * NOTE(review): confirm that index against the function types built by the
 * module-level instrument_ref.
 *
 * @param region             Region to instrument.
 * @param ioState            IO state output usable in \p region.
 * @param load_func          reference_load function (routed into \p region).
 * @param loadFunctionType   Type used for calls to \p load_func.
 * @param store_func         reference_store function (routed into \p region).
 * @param storeFunctionType  Type used for calls to \p store_func.
 * @param alloca_func        reference_alloca function (routed into \p region).
 * @param allocaFunctionType Type used for calls to \p alloca_func.
 */
132 void
134  rvsdg::Region * region,
135  jlm::rvsdg::Output * ioState,
136  jlm::rvsdg::Output * load_func,
137  const std::shared_ptr<const jlm::rvsdg::FunctionType> & loadFunctionType,
138  jlm::rvsdg::Output * store_func,
139  const std::shared_ptr<const jlm::rvsdg::FunctionType> & storeFunctionType,
140  jlm::rvsdg::Output * alloca_func,
141  const std::shared_ptr<const jlm::rvsdg::FunctionType> & allocaFunctionType)
142 {
143  load_func = &rvsdg::RouteToRegion(*load_func, *region);
144  store_func = &rvsdg::RouteToRegion(*store_func, *region);
145  alloca_func = &rvsdg::RouteToRegion(*alloca_func, *region);
146  auto void_ptr = jlm::llvm::PointerType::Create();
147  for (auto & node : rvsdg::TopDownTraverser(region))
148  {
149  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
150  {
151  for (size_t n = 0; n < structnode->nsubregions(); n++)
152  {
153  auto subregion = structnode->subregion(n);
    // The function outputs are routed inside the recursive call; only the IO state is routed here.
154  auto ioStateRouted = &rvsdg::RouteToRegion(*ioState, *subregion);
156  subregion,
157  ioStateRouted,
158  load_func,
159  loadFunctionType,
160  store_func,
161  storeFunctionType,
162  alloca_func,
163  allocaFunctionType);
164  }
165  }
166  else if (
167  auto loadOp =
168  dynamic_cast<const jlm::llvm::LoadNonVolatileOperation *>(&(node->GetOperation())))
169  {
170  auto addr = node->input(0)->origin();
171  JLM_ASSERT(rvsdg::is<jlm::llvm::PointerType>(addr->Type()));
    // The width argument is log2 of the access size in bytes (bitWidth 8 -> 0, bitWidth 64 -> 3).
    // NOTE(review): bitWidth / 8 is integer division, so a width below 8 bits (e.g. i1) yields
    // log2(0) == -inf, and converting -inf to int is undefined behavior. Widths that are not a
    // power-of-two number of bytes are truncated toward zero.
172  size_t bitWidth = BaseHLS::JlmSize(&*loadOp->GetLoadedType());
173  int log2Bytes = log2(bitWidth / 8);
174  auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes);
175 
176  // Does this IF make sense now when the void_ptr doesn't have a type?
    // NOTE(review): addr is asserted to be a PointerType above and PointerType::Create() takes no
    // arguments, so the two types presumably always compare equal and this bitcast is never emitted.
177  if (*addr->Type() != *void_ptr)
178  {
179  addr = llvm::BitCastOperation::create(addr, void_ptr);
180  }
    // The call consumes the load's incoming memory state. The load is then re-wired to the call's
    // memory-state output, which orders the call before the load.
    // NOTE(review): only input 1 is threaded through the call. If the load has further
    // memory-state inputs, they bypass the instrumentation.
181  auto memstate = node->input(1)->origin();
182  auto callOp = jlm::llvm::CallOperation::Create(
183  load_func,
184  loadFunctionType,
185  { addr, widthNode.output(0), ioState, memstate });
186  // Divert the memory state of the load to the new memstate from the call operation
187  node->input(1)->divert_to(callOp[1]);
188  }
189  else if (auto ao = dynamic_cast<const jlm::llvm::AllocaOperation *>(&(node->GetOperation())))
190  {
191  // ensure that the size is one: the alloca's only input must come from an integer constant equal to 1
192  JLM_ASSERT(node->ninputs() == 1);
193  auto constant_output = dynamic_cast<rvsdg::NodeOutput *>(node->input(0)->origin());
194  JLM_ASSERT(constant_output);
195  auto constant_operation = dynamic_cast<const llvm::IntegerConstantOperation *>(
196  &constant_output->node()->GetOperation());
197  JLM_ASSERT(constant_operation);
198  JLM_ASSERT(constant_operation->Representation().to_uint() == 1);
199  jlm::rvsdg::Output * addr = node->output(0);
200  // ensure that the alloca yields a pointer and that the allocated value type is an array type
201  JLM_ASSERT(jlm::rvsdg::is<llvm::PointerType>(addr->Type()));
202  auto at = dynamic_cast<const llvm::ArrayType *>(&ao->value_type());
203  JLM_ASSERT(at);
204  auto & sizeNode =
206 
207  // Does this IF make sense now when the void_ptr doesn't have a type?
208  if (*addr->Type() != *void_ptr)
209  {
210  addr = llvm::BitCastOperation::create(addr, void_ptr);
211  }
    // Collect the users of the alloca's memory-state output before creating the call. The call
    // itself becomes a new user of that output and must not be diverted.
212  std::vector<jlm::rvsdg::Input *> old_users;
213  for (auto & user : node->output(1)->Users())
214  old_users.push_back(&user);
215  auto memstate = node->output(1);
216  auto callOp = jlm::llvm::CallOperation::Create(
217  alloca_func,
218  allocaFunctionType,
219  { addr, sizeNode.output(0), ioState, memstate });
220  for (auto ou : old_users)
221  {
222  // Divert the memory state of the alloca to the new memstate from the call operation
223  ou->divert_to(callOp[1]);
224  }
225  }
226  else if (
227  auto so =
228  dynamic_cast<const jlm::llvm::StoreNonVolatileOperation *>(&(node->GetOperation())))
229  {
230  auto addr = node->input(0)->origin();
231  JLM_ASSERT(rvsdg::is<jlm::llvm::PointerType>(addr->Type()));
    // Same log2-of-bytes width encoding, and the same bitWidth < 8 caveat, as in the load branch.
    // NOTE(review): this calls the unqualified JlmSize, while the load branch calls
    // BaseHLS::JlmSize; confirm both resolve to the same size computation.
232  auto bitWidth = JlmSize(&so->GetStoredType());
233  int log2Bytes = log2(bitWidth / 8);
234  auto & widthNode = llvm::IntegerConstantOperation::Create(*region, 64, log2Bytes);
235 
236  // Does this IF make sense now when the void_ptr doesn't have a type?
237  if (*addr->Type() != *void_ptr)
238  {
239  addr = llvm::BitCastOperation::create(addr, void_ptr);
240  }
    // The call consumes the store's outgoing memory state. Its previous users, collected before
    // the call is created, are then re-wired to the call's memory-state output, which orders the
    // call after the store.
241  auto memstate = node->output(0);
242  std::vector<jlm::rvsdg::Input *> oldUsers;
243  for (auto & user : memstate->Users())
244  oldUsers.push_back(&user);
245  auto callOp = jlm::llvm::CallOperation::Create(
246  store_func,
247  storeFunctionType,
248  { addr, widthNode.output(0), ioState, memstate });
249  // Divert the memory state after the store to the new memstate from the call operation
250  for (auto user : oldUsers)
251  {
252  user->divert_to(callOp[1]);
253  }
254  }
255  }
256 }
257 
258 } // namespace jlm::hls
static int JlmSize(const jlm::rvsdg::Type *type)
Definition: base-hls.cpp:110
static std::unique_ptr< llvm::ThreeAddressCode > create(const Variable *operand, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1502
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:440
static std::shared_ptr< const IOStateType > Create()
Definition: types.cpp:343
static rvsdg::Node & Create(rvsdg::Region &region, IntegerValueRepresentation representation)
static LlvmGraphImport & Create(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::Type > valueType, std::shared_ptr< const rvsdg::Type > importedType, std::string name, Linkage linkage, bool isConstant=false)
Definition: RvsdgModule.hpp:81
Lambda operation.
Definition: lambda.hpp:30
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:77
static std::shared_ptr< const MemoryStateType > Create()
Definition: types.cpp:379
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates a bit type of the specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
#define JLM_ASSERT(x)
Definition: common.hpp:16
int JlmSize(const jlm::rvsdg::Type *type)
Definition: hls.cpp:344
static void divert_users(jlm::rvsdg::Output *output, Context &ctx)
Definition: cne.cpp:504
void instrument_ref(llvm::LlvmRvsdgModule &rm)
rvsdg::LambdaNode * change_function_name(rvsdg::LambdaNode *ln, const std::string &name)
static void remove(Node *node)
Definition: region.hpp:932
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058