32 for (
auto & func : response_functions)
34 std::unordered_set<rvsdg::Output *> visited;
35 std::vector<rvsdg::SimpleNode *> reponse_calls;
37 for (
auto & rc : reponse_calls)
40 if (*response_constant == *request_constant)
49 static std::pair<rvsdg::Input *, std::vector<rvsdg::Input *>>
52 std::vector<rvsdg::Input *> encountered_muxes;
63 else if (rvsdg::TryGetOwnerNode<LoopNode>(*state_edge))
68 auto sn = &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*si);
69 auto [branchNode, branchOperation] =
70 rvsdg::TryGetSimpleNodeAndOptionalOp<BranchOperation>(*state_edge);
71 auto [muxNode, muxOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<MuxOperation>(*state_edge);
79 else if (muxOperation && !muxOperation->loop)
82 encountered_muxes.push_back(si);
86 rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*state_edge)
87 || rvsdg::IsOwnerNodeOperation<llvm::LambdaExitMemoryStateMergeOperation>(*state_edge))
89 return { state_edge, encountered_muxes };
104 for (
auto si : encountered_muxes)
106 auto & sn = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*si);
107 for (
size_t i = 1; i < sn.ninputs(); ++i)
109 if (i != si->index())
112 sn.input(i)->divert_to(state_dummy);
125 auto & merge_node = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*merge_in);
126 std::vector<rvsdg::Output *> merge_origins;
127 for (
size_t i = 0; i < merge_node.ninputs(); ++i)
129 if (i != merge_in->index())
131 merge_origins.push_back(merge_node.input(i)->origin());
135 merge_node.output(0)->divert_users(new_merge_output);
147 auto channel = decouple_request->
input(1)->
origin();
154 auto req_mem_state = decouple_request->
input(decouple_request->
ninputs() - 1)->
origin();
158 req_mem_state = sg_out[1];
163 int load_capacity = 10;
164 if (rvsdg::is<const rvsdg::BitType>(decouple_response->input(2)->Type()))
166 auto constant =
trace_constant(decouple_response->input(2)->origin());
167 load_capacity = constant->Representation().to_int();
168 assert(load_capacity >= 0);
172 auto dload_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*dload_out[0]);
175 decouple_response->output(0)->divert_users(routed_data);
176 auto response_state_origin = decouple_response->input(decouple_response->ninputs() - 1)->origin();
178 if (decouple_request->
region() != decouple_response->region())
182 *response_state_origin->region(),
183 response_state_origin->Type());
185 decouple_response->output(decouple_response->noutputs() - 1)->divert_users(sg_resp[1]);
187 remove(decouple_response);
198 decouple_response->output(decouple_response->noutputs() - 1)
199 ->divert_users(response_state_origin);
202 *response_state_origin->region(),
203 response_state_origin->Type());
206 dload_node->input(1)->divert_to(sg_resp[0]);
208 state_user->divert_to(sg_resp[1]);
211 remove(decouple_response);
227 std::vector<rvsdg::Node *> & loadNodes,
228 std::vector<rvsdg::Node *> & storeNodes,
229 std::vector<rvsdg::Node *> & decoupleNodes,
230 std::unordered_set<rvsdg::Node *> exclude)
236 for (
size_t n = 0; n < structnode->nsubregions(); n++)
237 gather_mem_nodes(structnode->subregion(n), loadNodes, storeNodes, decoupleNodes, exclude);
241 if (exclude.find(simplenode) != exclude.end())
247 storeNodes.push_back(simplenode);
251 loadNodes.push_back(simplenode);
257 decoupleNodes.push_back(simplenode);
273 std::unordered_set<rvsdg::Output *> & visited,
276 if (!rvsdg::is<llvm::PointerType>(output->
Type()))
281 if (visited.count(output))
286 visited.insert(output);
287 for (
auto & user : output->
Users())
289 if (
auto simplenode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
293 tracedPointerNodes.
storeNodes.push_back(simplenode);
297 tracedPointerNodes.
loadNodes.push_back(simplenode);
307 for (
size_t i = 0; i < simplenode->noutputs(); ++i)
309 TracePointer(simplenode->output(i), visited, tracedPointerNodes);
315 for (
auto & arg : sti->arguments)
324 TracePointer(ber->argument(), visited, tracedPointerNodes);
338 std::vector<TracedPointerNodes>
341 std::vector<TracedPointerNodes> tracedPointerNodes;
344 if (rvsdg::is<llvm::PointerType>(argument->Type()))
346 std::unordered_set<rvsdg::Output *> visited;
347 tracedPointerNodes.emplace_back();
348 TracePointer(argument, visited, tracedPointerNodes.back());
356 std::unordered_set<rvsdg::Output *> visited;
357 tracedPointerNodes.emplace_back();
358 TracePointer(cv.inner, visited, tracedPointerNodes.back());
362 return tracedPointerNodes;
379 for (
auto node : tracedPointerNodes.
loadNodes)
381 auto loadOp = util::assertedCast<const llvm::LoadNonVolatileOperation>(&node->GetOperation());
382 auto sz =
JlmSize(loadOp->GetLoadedType().get());
383 max_width = sz > max_width ? sz : max_width;
385 for (
auto node : tracedPointerNodes.
storeNodes)
387 auto storeOp = util::assertedCast<const llvm::StoreNonVolatileOperation>(&node->GetOperation());
388 auto sz =
JlmSize(&storeOp->GetStoredType());
389 max_width = sz > max_width ? sz : max_width;
391 for (
auto decoupleRequest : tracedPointerNodes.
decoupleNodes)
394 auto channel = decoupleRequest->input(1)->origin();
397 auto sz =
JlmSize(reponse->output(0)->Type().get());
398 max_width = sz > max_width ? sz : max_width;
414 &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.
lookup(*originalLoad->
output(0)));
416 auto loadAddress = replacedLoad->input(0)->origin();
417 std::vector<rvsdg::Output *> states;
418 for (
size_t i = 1; i < replacedLoad->ninputs(); ++i)
420 states.push_back(replacedLoad->input(i)->origin());
426 size_t load_capacity = 10;
437 for (
size_t i = 0; i < replacedLoad->noutputs(); ++i)
440 replacedLoad->output(i)->divert_users(newLoad->
output(i));
456 &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.
lookup(*originalStore->
output(0)));
458 auto addr = replacedStore->input(0)->origin();
459 JLM_ASSERT(rvsdg::is<llvm::PointerType>(addr->Type()));
460 auto data = replacedStore->input(1)->origin();
461 std::vector<rvsdg::Output *> states;
462 for (
size_t i = 2; i < replacedStore->ninputs(); ++i)
464 states.push_back(replacedStore->input(i)->origin());
469 for (
size_t i = 0; i < replacedStore->noutputs(); ++i)
477 replacedStore->output(i)->divert_users(bo);
486 size_t argumentIndex,
488 const std::vector<rvsdg::Node *> & originalLoadNodes,
489 const std::vector<rvsdg::Node *> & originalStoreNodes,
490 const std::vector<rvsdg::Node *> & originalDecoupledNodes)
496 std::vector<rvsdg::SimpleNode *> loadNodes;
497 std::vector<std::shared_ptr<const rvsdg::Type>> responseTypes;
498 for (
auto loadNode : originalLoadNodes)
500 auto oldLoadedValue = loadNode->output(0);
502 auto & newLoadNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.
lookup(*oldLoadedValue));
503 loadNodes.push_back(&newLoadNode);
505 util::assertedCast<const llvm::LoadNonVolatileOperation>(&newLoadNode.GetOperation());
506 responseTypes.push_back(loadOp->GetLoadedType());
508 std::vector<rvsdg::SimpleNode *> decoupledNodes;
509 for (
auto decoupleRequest : originalDecoupledNodes)
511 auto oldOutput = decoupleRequest->output(0);
513 auto & decoupledRequestNode =
514 rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.
lookup(*oldOutput));
515 decoupledNodes.push_back(&decoupledRequestNode);
517 auto channel = decoupleRequest->input(1)->origin();
520 auto vt = reponse->output(0)->Type();
521 responseTypes.push_back(vt);
523 std::vector<rvsdg::SimpleNode *> storeNodes;
524 for (
auto storeNode : originalStoreNodes)
526 auto oldOutput = storeNode->output(0);
528 auto & newStoreNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.
lookup(*oldOutput));
529 storeNodes.push_back(&newStoreNode);
531 auto vt = std::make_shared<llvm::MemoryStateType>();
532 responseTypes.push_back(vt);
537 CalculatePortWidth({ originalLoadNodes, originalStoreNodes, originalDecoupledNodes });
539 *lambdaRegion->argument(argumentIndex),
543 std::vector<std::shared_ptr<const rvsdg::Type>> loadTypes;
544 std::vector<rvsdg::Output *> loadAddresses;
545 for (
size_t i = 0; i < loadNodes.size(); ++i)
550 auto replacement =
ReplaceLoad(smap, originalLoadNodes[i], routed);
553 loadAddresses.push_back(address);
554 std::shared_ptr<const rvsdg::Type>
type;
555 if (
auto loadOperation =
dynamic_cast<const LoadOperation *
>(&replacement->GetOperation()))
557 type = loadOperation->GetLoadedType();
563 type = loadOperation->GetLoadedType();
570 loadTypes.push_back(
type);
572 for (
size_t i = 0; i < decoupledNodes.size(); ++i)
574 auto response = responses[loadNodes.size() + i];
575 auto node = decoupledNodes[i];
581 loadAddresses.push_back(addr);
585 std::vector<rvsdg::Output *> storeOperands;
586 for (
size_t i = 0; i < storeNodes.size(); ++i)
588 auto response = responses[loadNodes.size() + decoupledNodes.size() + i];
592 auto replacement =
ReplaceStore(smap, originalStoreNodes[i], routed);
593 auto addr =
route_request_rhls(lambdaRegion, replacement->output(replacement->noutputs() - 2));
594 auto data =
route_request_rhls(lambdaRegion, replacement->output(replacement->noutputs() - 1));
595 storeOperands.push_back(addr);
596 storeOperands.push_back(data);
614 const auto & graph = rvsdgModule.
Rvsdg();
616 if (rootRegion->numNodes() != 1)
618 throw std::logic_error(
"Root should have only one node now");
621 const auto lambda =
dynamic_cast<rvsdg::LambdaNode *
>(rootRegion->Nodes().begin().ptr());
624 throw std::logic_error(
"Node needs to be a lambda");
632 auto oldFunctionType = op.
type();
633 std::vector<std::shared_ptr<const rvsdg::Type>> newArgumentTypes;
634 for (
size_t i = 0; i < oldFunctionType.NumArguments(); ++i)
636 newArgumentTypes.push_back(oldFunctionType.Arguments()[i]);
638 std::vector<std::shared_ptr<const rvsdg::Type>> newResultTypes;
639 for (
size_t i = 0; i < oldFunctionType.NumResults(); ++i)
641 newResultTypes.push_back(oldFunctionType.Results()[i]);
650 std::unordered_set<rvsdg::Node *> accountedNodes;
651 for (
auto & portNode : tracedPointerNodesVector)
657 newArgumentTypes.push_back(responseTypePtr);
658 if (portNode.storeNodes.empty())
660 newResultTypes.push_back(requestTypePtr);
664 newResultTypes.push_back(requestTypePtrWrite);
666 accountedNodes.insert(portNode.loadNodes.begin(), portNode.loadNodes.end());
667 accountedNodes.insert(portNode.storeNodes.begin(), portNode.storeNodes.end());
668 accountedNodes.insert(portNode.decoupleNodes.begin(), portNode.decoupleNodes.end());
670 std::vector<rvsdg::Node *> unknownLoadNodes;
671 std::vector<rvsdg::Node *> unknownStoreNodes;
672 std::vector<rvsdg::Node *> unknownDecoupledNodes;
677 unknownDecoupledNodes,
679 if (!unknownLoadNodes.empty() || !unknownStoreNodes.empty() || !unknownDecoupledNodes.empty())
687 newArgumentTypes.push_back(responseTypePtr);
688 if (unknownStoreNodes.empty())
690 newResultTypes.push_back(requestTypePtr);
694 newResultTypes.push_back(requestTypePtrWrite);
707 for (
const auto & ctxvar : lambda->GetContextVars())
709 smap.
insert(ctxvar.inner, newLambda->AddContextVar(*ctxvar.input->origin()).inner);
712 auto args = lambda->GetFunctionArguments();
713 auto newArgs = newLambda->GetFunctionArguments();
718 for (
size_t i = 0; i < args.size(); ++i)
720 smap.
insert(args[i], newArgs[i]);
722 lambda->subregion()->copy(newLambda->subregion(), smap);
730 std::vector<rvsdg::Output *> newResults;
733 auto newArgumentsIndex = args.size();
734 for (
auto & portNode : tracedPointerNodesVector)
742 portNode.decoupleNodes));
744 if (!unknownLoadNodes.empty() || !unknownStoreNodes.empty() || !unknownDecoupledNodes.empty())
752 unknownDecoupledNodes));
755 std::vector<rvsdg::Output *> originalResults;
756 for (
auto result : lambda->GetFunctionResults())
758 originalResults.push_back(&smap.
lookup(*result->origin()));
760 originalResults.insert(originalResults.end(), newResults.begin(), newResults.end());
761 auto newOut = newLambda->finalize(originalResults);
766 lambda->region()->RemoveResults({ (*lambda->output()->Users().begin()).index() });
772 dne.
Run(*newLambda->subregion(), statisticsCollector);
784 newLambda = util::assertedCast<rvsdg::LambdaNode>(rootRegion->Nodes().begin().ptr());
787 for (
auto cv : decouple_funcs)
792 newLambda->PruneLambdaInputs();
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &value, size_t capacity, bool pass_through=false)
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &addr, jlm::rvsdg::Output &load_result, size_t capacity)
std::shared_ptr< const rvsdg::Type > GetLoadedType() const noexcept
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &addr, const std::vector< jlm::rvsdg::Output * > &states, jlm::rvsdg::Output &load_result)
~MemoryConverter() noexcept override
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::vector< jlm::rvsdg::Output * > create(const std::vector< jlm::rvsdg::Output * > &load_operands, const std::vector< std::shared_ptr< const rvsdg::Type >> &loadTypes, const std::vector< jlm::rvsdg::Output * > &store_operands, rvsdg::Region *)
static std::vector< jlm::rvsdg::Output * > create(rvsdg::Output &result, const std::vector< std::shared_ptr< const rvsdg::Type >> &output_types, int in_width)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &addr, const std::vector< jlm::rvsdg::Output * > &states)
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &addr, jlm::rvsdg::Output &value, const std::vector< jlm::rvsdg::Output * > &states, jlm::rvsdg::Output &resp)
static void CreateAndRun(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector)
rvsdg::GraphExport * GetRvsdgExport() const noexcept
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::AttributeSet attributes)
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
static jlm::rvsdg::Output * Create(rvsdg::Region &region, std::shared_ptr< const jlm::rvsdg::Type > type)
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
static GraphExport & Create(Output &origin, std::string name)
Region & GetRootRegion() const noexcept
std::vector< rvsdg::Output * > GetFunctionArguments() const
rvsdg::Region * subregion() const noexcept
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
const FunctionType & type() const noexcept
bool IsDead() const noexcept
Determines whether the node is dead.
rvsdg::Region * region() const noexcept
NodeOutput * output(size_t index) const noexcept
size_t ninputs() const noexcept
size_t noutputs() const noexcept
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
void divert_users(jlm::rvsdg::Output *new_origin)
Represents the result of a region.
Represent acyclic RVSDG subgraphs.
rvsdg::StructuralNode * node() const noexcept
const SimpleOperation & GetOperation() const noexcept override
NodeInput * input(size_t index) const noexcept
NodeOutput * output(size_t index) const noexcept
void insert(const Output *original, Output *substitute)
bool contains(const Output &original) const noexcept
Output & lookup(const Output &original) const
#define JLM_UNREACHABLE(msg)
std::shared_ptr< const BundleType > get_mem_res_type(std::shared_ptr< const jlm::rvsdg::Type > dataType)
rvsdg::Output * route_response_rhls(rvsdg::Region *target, rvsdg::Output *response)
void gather_mem_nodes(rvsdg::Region *region, std::vector< rvsdg::Node * > &loadNodes, std::vector< rvsdg::Node * > &storeNodes, std::vector< rvsdg::Node * > &decoupleNodes, std::unordered_set< rvsdg::Node * > exclude)
std::shared_ptr< const BundleType > get_mem_req_type(std::shared_ptr< const rvsdg::Type > elementType, bool write)
void trace_function_calls(rvsdg::Output *output, std::vector< rvsdg::SimpleNode * > &calls, std::unordered_set< rvsdg::Output * > &visited)
static rvsdg::Output * ConnectRequestResponseMemPorts(const rvsdg::LambdaNode *lambda, size_t argumentIndex, rvsdg::SubstitutionMap &smap, const std::vector< rvsdg::Node * > &originalLoadNodes, const std::vector< rvsdg::Node * > &originalStoreNodes, const std::vector< rvsdg::Node * > &originalDecoupledNodes)
bool is_function_argument(const rvsdg::LambdaNode::ContextVar &cv)
static rvsdg::SimpleNode * ReplaceLoad(rvsdg::SubstitutionMap &smap, const rvsdg::Node *originalLoad, rvsdg::Output *response)
rvsdg::Output * route_request_rhls(rvsdg::Region *target, rvsdg::Output *request)
static void TracePointer(rvsdg::Output *output, std::unordered_set< rvsdg::Output * > &visited, TracedPointerNodes &tracedPointerNodes)
int JlmSize(const jlm::rvsdg::Type *type)
rvsdg::SimpleNode * find_decouple_response(const rvsdg::LambdaNode *lambda, const llvm::IntegerConstantOperation *request_constant)
static rvsdg::SimpleNode * ReplaceStore(rvsdg::SubstitutionMap &smap, const rvsdg::Node *originalStore, rvsdg::Output *response)
rvsdg::Output * route_to_region_rhls(rvsdg::Region *target, rvsdg::Output *out)
void OptimizeReqMemState(rvsdg::Output *req_mem_state)
rvsdg::LambdaNode * find_containing_lambda(rvsdg::Region *region)
static void ConvertMemory(rvsdg::RvsdgModule &rvsdgModule)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
std::vector< TracedPointerNodes > TracePointerArguments(const rvsdg::LambdaNode *lambda)
rvsdg::Input * get_mem_state_user(rvsdg::Output *state_edge)
rvsdg::SimpleNode * ReplaceDecouple(const rvsdg::LambdaNode *lambda, rvsdg::SimpleNode *decouple_request, rvsdg::Output *resp)
static size_t CalculatePortWidth(const TracedPointerNodes &tracedPointerNodes)
std::vector< rvsdg::LambdaNode::ContextVar > find_function_arguments(const rvsdg::LambdaNode *lambda, std::string name_contains)
bool is_dec_req(rvsdg::SimpleNode *node)
void OptimizeResMemState(rvsdg::Output *res_mem_state)
static std::pair< rvsdg::Input *, std::vector< rvsdg::Input * > > TraceEdgeToMerge(rvsdg::Input *state_edge)
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
static void remove(Node *node)
static std::string type(const Node *n)
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
std::vector< rvsdg::Node * > loadNodes
std::vector< rvsdg::Node * > decoupleNodes
std::vector< rvsdg::Node * > storeNodes