Jlm
RhlsToFirrtlConverter.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 Magnus Sjalander <work@sjalander.com> and
3  * David Metz <david.c.metz@ntnu.no>
4  * See COPYING for terms of redistribution.
5  */
6 
10 #include <jlm/util/strfmt.hpp>
11 
12 #include <llvm/ADT/SmallPtrSet.h>
13 
14 namespace jlm::hls
15 {
16 
17 // Generates a FIRRTL module for a simple node that has exactly one output.
// The output "data" is computed from the "data" fields of the input bundles according to the
// node's operation (unary, binary, nullary and N-ary operations are all handled below); this
// function creates no registers. The output "valid" is the AND of all input "valid" signals
// (constant 1 for nodes without inputs), and every input "ready" is driven with
// out.ready AND that combined valid.
// Throws std::logic_error if the node does not have exactly one output, or if its operation is
// not handled by one of the branches below.
18 circt::firrtl::FModuleOp
20 {
21  // Only handles nodes with a single output
// NOTE(review): the message below says "more than 1 output", but this check also fires when the
// node has zero outputs.
22  if (node->noutputs() != 1)
23  {
24  throw std::logic_error(node->DebugString() + " has more than 1 output");
25  }
26 
27  // Create the module and its input/output ports
28  auto module = nodeToModule(node);
29  // Get the body of the module such that we can add contents to the module
30  auto body = module.getBodyBlock();
31 
32  ::llvm::SmallVector<mlir::Value> inBundles;
33 
34  // Get input signals
35  for (size_t i = 0; i < node->ninputs(); i++)
36  {
37  // Get the input bundle
38  auto bundle = GetInPort(module, i);
39  // Get the data signal from the bundle
// NOTE(review): the result of this GetSubfield is discarded; the branches below fetch "data"
// again from inBundles.
40  GetSubfield(body, bundle, "data");
41  inBundles.push_back(bundle);
42  }
43 
44  // Get the output bundle
45  auto outBundle = GetOutPort(module, 0);
46  // Get the data signal from the bundle
47  auto outData = GetSubfield(body, outBundle, "data");
48 
// Dispatch on the node's operation; each branch drives outData.
49  if (rvsdg::is<llvm::IntegerAddOperation>(node))
50  {
51  auto input0 = GetSubfield(body, inBundles[0], "data");
52  auto input1 = GetSubfield(body, inBundles[1], "data");
53  auto op = AddAddOp(body, input0, input1);
54  // Connect the op to the output data
55  // We drop the carry bit
56  Connect(body, outData, DropMSBs(body, op, 1));
57  }
58  else if (rvsdg::is<llvm::IntegerSubOperation>(node))
59  {
60  auto input0 = GetSubfield(body, inBundles[0], "data");
61  auto input1 = GetSubfield(body, inBundles[1], "data");
62  auto op = AddSubOp(body, input0, input1);
63  // Connect the op to the output data
64  // We drop the extra MSB (borrow bit) of the subtraction result
65  Connect(body, outData, DropMSBs(body, op, 1));
66  }
67  else if (rvsdg::is<llvm::IntegerAndOperation>(node))
68  {
69  auto input0 = GetSubfield(body, inBundles[0], "data");
70  auto input1 = GetSubfield(body, inBundles[1], "data");
71  auto op = AddAndOp(body, input0, input1);
72  // Connect the op to the output data
73  Connect(body, outData, op);
74  }
75  else if (rvsdg::is<llvm::IntegerXorOperation>(node))
76  {
77  auto input0 = GetSubfield(body, inBundles[0], "data");
78  auto input1 = GetSubfield(body, inBundles[1], "data");
79  auto op = AddXorOp(body, input0, input1);
80  // Connect the op to the output data
81  Connect(body, outData, op);
82  }
83  else if (rvsdg::is<llvm::IntegerOrOperation>(node))
84  {
85  auto input0 = GetSubfield(body, inBundles[0], "data");
86  auto input1 = GetSubfield(body, inBundles[1], "data");
87  auto op = AddOrOp(body, input0, input1);
88  // Connect the op to the output data
89  Connect(body, outData, op);
90  }
91  else if (auto bitmulOp = dynamic_cast<const llvm::IntegerMulOperation *>(&(node->GetOperation())))
92  {
93  auto input0 = GetSubfield(body, inBundles[0], "data");
94  auto input1 = GetSubfield(body, inBundles[1], "data");
95  auto op = AddMulOp(body, input0, input1);
96  // Connect the op to the output data
97  // Multiplication results are double the input width, so we drop the upper half of the result
98  Connect(body, outData, DropMSBs(body, op, bitmulOp->Type().nbits()));
99  }
100  else if (rvsdg::is<llvm::IntegerSDivOperation>(node))
101  {
// Signed division: both operands are reinterpreted as SInt, the quotient is converted back to
// UInt and its MSB is dropped (presumably the extra bit of a signed division result — confirm).
102  auto input0 = GetSubfield(body, inBundles[0], "data");
103  auto input1 = GetSubfield(body, inBundles[1], "data");
104  auto sIntOp0 = AddAsSIntOp(body, input0);
105  auto sIntOp1 = AddAsSIntOp(body, input1);
106  auto divOp = AddDivOp(body, sIntOp0, sIntOp1);
107  auto uIntOp = AddAsUIntOp(body, divOp);
108  // Connect the op to the output data
109  Connect(body, outData, DropMSBs(body, uIntOp, 1));
110  }
111  else if (rvsdg::is<llvm::IntegerLShrOperation>(node))
112  {
// Logical shift right: dynamic shift of the unsigned value
113  auto input0 = GetSubfield(body, inBundles[0], "data");
114  auto input1 = GetSubfield(body, inBundles[1], "data");
115  auto op = AddDShrOp(body, input0, input1);
116  // Connect the op to the output data
117  Connect(body, outData, op);
118  }
119  else if (rvsdg::is<llvm::IntegerAShrOperation>(node))
120  {
// Arithmetic shift right: input 0 is shifted as an SInt and converted back to UInt
121  auto input0 = GetSubfield(body, inBundles[0], "data");
122  auto input1 = GetSubfield(body, inBundles[1], "data");
123  auto sIntOp0 = AddAsSIntOp(body, input0);
124  auto shrOp = AddDShrOp(body, sIntOp0, input1);
125  auto uIntOp = AddAsUIntOp(body, shrOp);
126  // Connect the op to the output data
127  Connect(body, outData, uIntOp);
128  }
129  else if (rvsdg::is<llvm::IntegerShlOperation>(node))
130  {
131  auto input0 = GetSubfield(body, inBundles[0], "data");
132  auto input1 = GetSubfield(body, inBundles[1], "data");
// Only the low 8 bits (7..0) of the shift amount are used, bounding the width growth of dshl.
// NOTE(review): this presumably requires the shift operand to be at least 8 bits wide — confirm
// the behaviour for narrower integer types.
133  auto bitsOp = AddBitsOp(body, input1, 7, 0);
134  auto op = AddDShlOp(body, input0, bitsOp);
// Keep only the low bits that fit the output width
135  int outSize = JlmSize(node->output(0)->Type().get());
136  auto slice = AddBitsOp(body, op, outSize - 1, 0);
137  // Connect the op to the output data
138  Connect(body, outData, slice);
139  }
140  else if (rvsdg::is<llvm::IntegerSRemOperation>(node))
141  {
// Signed remainder: operands reinterpreted as SInt, result converted back to UInt
142  auto input0 = GetSubfield(body, inBundles[0], "data");
143  auto input1 = GetSubfield(body, inBundles[1], "data");
144  auto sIntOp0 = AddAsSIntOp(body, input0);
145  auto sIntOp1 = AddAsSIntOp(body, input1);
146  auto remOp = AddRemOp(body, sIntOp0, sIntOp1);
147  auto uIntOp = AddAsUIntOp(body, remOp);
148  Connect(body, outData, uIntOp);
149  }
// Comparisons: signed variants wrap both operands in AsSInt; unsigned variants compare the UInts
150  else if (rvsdg::is<llvm::IntegerEqOperation>(node))
151  {
152  auto input0 = GetSubfield(body, inBundles[0], "data");
153  auto input1 = GetSubfield(body, inBundles[1], "data");
154  auto op = AddEqOp(body, input0, input1);
155  // Connect the op to the output data
156  Connect(body, outData, op);
157  }
158  else if (rvsdg::is<llvm::IntegerNeOperation>(node))
159  {
160  auto input0 = GetSubfield(body, inBundles[0], "data");
161  auto input1 = GetSubfield(body, inBundles[1], "data");
162  auto op = AddNeqOp(body, input0, input1);
163  // Connect the op to the output data
164  Connect(body, outData, op);
165  }
166  else if (rvsdg::is<llvm::IntegerSgtOperation>(node))
167  {
168  auto input0 = GetSubfield(body, inBundles[0], "data");
169  auto input1 = GetSubfield(body, inBundles[1], "data");
170  auto sIntOp0 = AddAsSIntOp(body, input0);
171  auto sIntOp1 = AddAsSIntOp(body, input1);
172  auto op = AddGtOp(body, sIntOp0, sIntOp1);
173  // Connect the op to the output data
174  Connect(body, outData, op);
175  }
176  else if (rvsdg::is<llvm::IntegerUltOperation>(node))
177  {
178  auto input0 = GetSubfield(body, inBundles[0], "data");
179  auto input1 = GetSubfield(body, inBundles[1], "data");
180  auto op = AddLtOp(body, input0, input1);
181  // Connect the op to the output data
182  Connect(body, outData, op);
183  }
184  else if (rvsdg::is<llvm::IntegerUleOperation>(node))
185  {
186  auto input0 = GetSubfield(body, inBundles[0], "data");
187  auto input1 = GetSubfield(body, inBundles[1], "data");
188  auto op = AddLeqOp(body, input0, input1);
189  // Connect the op to the output data
190  Connect(body, outData, op);
191  }
192  else if (rvsdg::is<llvm::IntegerUgtOperation>(node))
193  {
194  auto input0 = GetSubfield(body, inBundles[0], "data");
195  auto input1 = GetSubfield(body, inBundles[1], "data");
196  auto op = AddGtOp(body, input0, input1);
197  // Connect the op to the output data
198  Connect(body, outData, op);
199  }
200  else if (rvsdg::is<llvm::IntegerSgeOperation>(node))
201  {
202  auto input0 = GetSubfield(body, inBundles[0], "data");
203  auto input1 = GetSubfield(body, inBundles[1], "data");
204  auto sIntOp0 = AddAsSIntOp(body, input0);
205  auto sIntOp1 = AddAsSIntOp(body, input1);
206  auto op = AddGeqOp(body, sIntOp0, sIntOp1);
207  // Connect the op to the output data
208  Connect(body, outData, op);
209  }
210  else if (rvsdg::is<llvm::IntegerUgeOperation>(node))
211  {
212  auto input0 = GetSubfield(body, inBundles[0], "data");
213  auto input1 = GetSubfield(body, inBundles[1], "data");
214  auto op = AddGeqOp(body, input0, input1);
215  // Connect the op to the output data
216  Connect(body, outData, op);
217  }
218  else if (rvsdg::is<llvm::IntegerSleOperation>(node))
219  {
220  auto input0 = GetSubfield(body, inBundles[0], "data");
221  auto input1 = GetSubfield(body, inBundles[1], "data");
222  auto sIntOp0 = AddAsSIntOp(body, input0);
223  auto sIntOp1 = AddAsSIntOp(body, input1);
224  auto op = AddLeqOp(body, sIntOp0, sIntOp1);
225  // Connect the op to the output data
226  Connect(body, outData, op);
227  }
228  else if (rvsdg::is<llvm::IntegerSltOperation>(node))
229  {
230  auto input0 = GetSubfield(body, inBundles[0], "data");
231  auto input1 = GetSubfield(body, inBundles[1], "data");
232  auto sInt0 = AddAsSIntOp(body, input0);
233  auto sInt1 = AddAsSIntOp(body, input1);
234  auto op = AddLtOp(body, sInt0, sInt1);
235  Connect(body, outData, op);
236  }
237  else if (dynamic_cast<const llvm::ZExtOperation *>(&(node->GetOperation())))
238  {
// Zero extension: the narrower input is connected directly to the wider output
// (presumably relying on the connect to zero-extend — confirm)
239  auto input0 = GetSubfield(body, inBundles[0], "data");
240  Connect(body, outData, input0);
241  }
242  else if (rvsdg::is<const llvm::TruncOperation>(node->GetOperation()))
243  {
// Truncation: keep the low outSize bits of the input
244  auto inData = GetSubfield(body, inBundles[0], "data");
245  int outSize = JlmSize(node->output(0)->Type().get());
246  Connect(body, outData, AddBitsOp(body, inData, outSize - 1, 0));
247  }
// Memory-state merges forward input 0's data; the other inputs only take part in the
// valid/ready handshake generated at the end of this function.
248  else if (dynamic_cast<const llvm::LambdaExitMemoryStateMergeOperation *>(&(node->GetOperation())))
249  {
250  auto inData = GetSubfield(body, inBundles[0], "data");
251  Connect(body, outData, inData);
252  }
253  else if (dynamic_cast<const llvm::MemoryStateMergeOperation *>(&(node->GetOperation())))
254  {
255  auto inData = GetSubfield(body, inBundles[0], "data");
256  Connect(body, outData, inData);
257  }
258  else if (auto op = dynamic_cast<const llvm::SExtOperation *>(&(node->GetOperation())))
259  {
// Sign extension: pad the value as an SInt to the destination width, then convert back to UInt
260  auto input0 = GetSubfield(body, inBundles[0], "data");
261  auto sintOp = AddAsSIntOp(body, input0);
262  auto padOp = AddPadOp(body, sintOp, op->ndstbits());
263  auto uintOp = AddAsUIntOp(body, padOp);
264  Connect(body, outData, uintOp);
265  }
266  else if (auto op = dynamic_cast<const llvm::IntegerConstantOperation *>(&(node->GetOperation())))
267  {
268  auto & value = op->Representation();
269  auto size = value.nbits();
270  // Create a constant of UInt<size>(value) and connect to output data
271  auto constant = GetConstant(body, size, value.to_uint());
272  Connect(body, outData, constant);
273  }
274  else if (
275  auto op = dynamic_cast<const jlm::rvsdg::ControlConstantOperation *>(&(node->GetOperation())))
276  {
// Control constant: the selected alternative, in a width of ceil(log2(#alternatives)) bits.
// NOTE(review): a control type with a single alternative yields size 0 here — confirm that
// GetConstant handles a zero width.
277  auto value = op->value().alternative();
278  auto size = ceil(log2(op->value().nalternatives()));
279  auto constant = GetConstant(body, size, value);
280  Connect(body, outData, constant);
281  }
// Bit casts and integer-to-pointer conversions pass the data through unchanged
282  else if (dynamic_cast<const llvm::BitCastOperation *>(&(node->GetOperation())))
283  {
284  auto input0 = GetSubfield(body, inBundles[0], "data");
285  Connect(body, outData, input0);
286  }
287  else if (dynamic_cast<const llvm::IntegerToPointerOperation *>(&(node->GetOperation())))
288  {
289  auto input0 = GetSubfield(body, inBundles[0], "data");
290  Connect(body, outData, input0);
291  }
292  else if (auto op = dynamic_cast<const jlm::rvsdg::MatchOperation *>(&(node->GetOperation())))
293  {
294  auto inData = GetSubfield(body, inBundles[0], "data");
// NOTE(review): this re-fetches the output data subfield and shadows the outer outData.
295  auto outData = GetSubfield(body, outBundle, "data");
296  int inSize = JlmSize(node->input(0)->Type().get());
297  int outSize = JlmSize(node->output(0)->Type().get());
// Identity mapping: pass the input through, truncated to the output width if the sizes differ
298  if (IsIdentityMapping(*op))
299  {
300  if (inSize == outSize)
301  {
302  Connect(body, outData, inData);
303  }
304  else
305  {
306  Connect(body, outData, AddBitsOp(body, inData, outSize - 1, 0));
307  }
308  }
309  else
310  {
// General mapping: a chain of muxes, one per (input value -> alternative) entry, falling back
// to the default alternative when no entry matches.
311  auto size = op->nbits();
312  mlir::Value result = GetConstant(body, size, op->default_alternative());
313  for (auto it = op->begin(); it != op->end(); it++)
314  {
315  auto comparison = AddEqOp(body, inData, GetConstant(body, size, it->first));
316  auto value = GetConstant(body, size, it->second);
317  result = AddMuxOp(body, comparison, value, result);
318  }
319  if ((unsigned long)outSize != size)
320  {
321  result = AddBitsOp(body, result, outSize - 1, 0);
322  }
323  Connect(body, outData, result);
324  }
325  }
326  else if (auto op = dynamic_cast<const llvm::GetElementPtrOperation *>(&(node->GetOperation())))
327  {
328  // Start off with the base pointer (input 0), converted via AddCvtOp
329  auto input0 = GetSubfield(body, inBundles[0], "data");
330  mlir::Value result = AddCvtOp(body, input0);
331 
332  // TODO: support structs
// Each index i (inputs 1..n) is scaled by the byte size of the current pointee type; the
// pointee type then steps into the element type for the next index. Scalar (bit/float)
// types end the descent (pointeeType becomes nullptr).
// NOTE(review): bits / 8 truncates for types that are not a whole number of bytes.
333  const jlm::rvsdg::Type * pointeeType = &op->GetPointeeType();
334  for (size_t i = 1; i < node->ninputs(); i++)
335  {
336  int bits = JlmSize(pointeeType);
337  if (dynamic_cast<const rvsdg::BitType *>(pointeeType)
338  || dynamic_cast<const llvm::FloatingPointType *>(pointeeType))
339  {
340  pointeeType = nullptr;
341  }
342  else if (auto arrayType = dynamic_cast<const llvm::ArrayType *>(pointeeType))
343  {
344  pointeeType = &arrayType->element_type();
345  }
346  else if (auto vectorType = dynamic_cast<const llvm::VectorType *>(pointeeType))
347  {
348  pointeeType = vectorType->Type().get();
349  }
350  else
351  {
352  throw std::logic_error(pointeeType->debug_string() + " pointer not implemented!");
353  }
354  // GEP inputs are signed
355  auto input = GetSubfield(body, inBundles[i], "data");
356  auto asSInt = AddAsSIntOp(body, input);
357  int bytes = bits / 8;
358  auto constantOp = GetConstant(body, GetPointerSizeInBits(), bytes);
359  auto cvtOp = AddCvtOp(body, constantOp);
360  auto offset = AddMulOp(body, asSInt, cvtOp);
361  result = AddAddOp(body, result, offset);
362  }
// Convert back to UInt and truncate to the pointer width
363  auto asUInt = AddAsUIntOp(body, result);
364  Connect(body, outData, AddBitsOp(body, asUInt, GetPointerSizeInBits() - 1, 0));
365  }
366  else if (auto op = dynamic_cast<const llvm::ExtractElementOperation *>(&(node->GetOperation())))
367  {
368  // Input 0 is the vector packed into one UInt (element i occupies bits [elementBits*(i+1)-1, elementBits*i]); input 1 is the index
// The packed vector is unpacked into a FIRRTL vector wire "vec", which is then indexed
// dynamically with input 1.
// NOTE(review): the dynamic_cast to VectorType below is not null-checked.
369  auto input0 = GetSubfield(body, inBundles[0], "data");
370  auto input1 = GetSubfield(body, inBundles[1], "data");
371  auto vt = dynamic_cast<const llvm::VectorType *>(op->argument(0).get());
372  auto vec = Builder_->create<circt::firrtl::WireOp>(
373  Builder_->getUnknownLoc(),
374  circt::firrtl::FVectorType::get(GetFirrtlType(vt->Type().get()), vt->size()),
375  "vec");
376  auto elementBits = JlmSize(vt->Type().get());
377  body->push_back(vec);
378  for (size_t i = 0; i < vt->size(); ++i)
379  {
380  auto subindexOp = Builder_->create<circt::firrtl::SubindexOp>(
381  Builder_->getUnknownLoc(),
382  vec.getResult(),
383  i);
384  body->push_back(subindexOp);
385  Connect(
386  body,
387  subindexOp,
388  AddBitsOp(body, input0, elementBits * (i + 1) - 1, elementBits * i));
389  }
390  auto subaccessOp = Builder_->create<circt::firrtl::SubaccessOp>(
391  Builder_->getUnknownLoc(),
392  vec.getResult(),
393  input1);
394  body->push_back(subaccessOp);
395  Connect(body, outData, subaccessOp);
396  }
397  else if (dynamic_cast<const llvm::UndefValueOperation *>(&(node->GetOperation())))
398  {
// Undefined value: the output data is marked invalid
399  ConnectInvalid(body, outData);
400  }
401  else if (auto op = dynamic_cast<const MuxOperation *>(&(node->GetOperation())))
402  {
// Only discarding muxes are handled here. Input 0 is the select; the output data defaults to
// invalid and is driven by input i when select == i - 1.
403  JLM_ASSERT(op->discarding);
404  auto select = GetSubfield(body, inBundles[0], "data");
405  ConnectInvalid(body, outData);
406  for (size_t i = 1; i < node->ninputs(); i++)
407  {
408  auto data = GetSubfield(body, inBundles[i], "data");
409  auto constant = GetConstant(body, JlmSize(node->input(0)->Type().get()), i - 1);
410  auto eqOp = AddEqOp(body, select, constant);
411  auto whenOp = AddWhenOp(body, eqOp, false);
412  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
413  Connect(thenBody, outData, data);
414  }
415  }
416  else
417  {
418  throw std::logic_error("Simple node " + node->DebugString() + " not implemented!");
419  }
420 
421  // Generate the output valid signal: AND of all input valid signals (constant 1 without inputs)
422  auto oneBitValue = GetConstant(body, 1, 1);
423  mlir::Value prevAnd = oneBitValue;
424  for (size_t i = 0; i < node->ninputs(); i++)
425  {
426  auto bundle = inBundles[i];
427  prevAnd = AddAndOp(body, prevAnd, GetSubfield(body, bundle, "valid"));
428  }
429  // Connect the valid signal to the output bundle
430  auto outValid = GetSubfield(body, outBundle, "valid");
431  Connect(body, outValid, prevAnd);
432 
433  // Generate the ready signal
434  auto outReady = GetSubfield(body, outBundle, "ready");
435  auto andReady = AddAndOp(body, outReady, prevAnd);
436  // Connect it to the ready signal of all input bundles, so inputs are only consumed together when all are valid and the output is ready
437  for (size_t i = 0; i < node->ninputs(); i++)
438  {
439  auto bundle = inBundles[i];
440  auto ready = GetSubfield(body, bundle, "ready");
441  Connect(body, ready, andReady);
442  }
443 
444  return module;
445 }
446 
// Generates a module that accepts every token offered on input 0: the input's "ready" signal is
// tied to a constant UInt<1>(1). Only that ready signal is driven here.
447 circt::firrtl::FModuleOp
449 {
450  // Create the module and its input/output ports
451  auto module = nodeToModule(node);
452  auto body = module.getBodyBlock();
453 
454  // Create a constant of UInt<1>(1)
455  auto intType = GetIntType(1);
456  auto constant = Builder_->create<circt::firrtl::ConstantOp>(
457  Builder_->getUnknownLoc(),
458  intType,
459  ::llvm::APInt(1, 1));
460  body->push_back(constant);
461 
462  // Get the input bundle
463  auto bundle = GetInPort(module, 0);
464  // Get the ready signal from the bundle
465  auto ready = GetSubfield(body, bundle, "ready");
466  // Connect the constant to the ready signal
467  Connect(body, ready, constant);
468 
469  return module;
470 }
471 
// Generates a loop-constant buffer. Input 0 is a predicate, input 1 carries the data, and
// output 0 carries the (possibly buffered) data. The data is held in register "data_reg"
// (clocked, created without a reset).
// The predicate data is used as a single-bit boolean below (assumed 1 bit wide — confirm).
472 circt::firrtl::FModuleOp
474 {
475  // Create the module and its input/output ports
476  auto module = nodeToModule(node);
477  auto body = module.getBodyBlock();
478 
479  auto clock = GetClockSignal(module);
480 
481  // Input signals
482  auto predBundle = GetInPort(module, 0);
483  auto predReady = GetSubfield(body, predBundle, "ready");
484  auto predValid = GetSubfield(body, predBundle, "valid");
485  auto predData = GetSubfield(body, predBundle, "data");
486 
487  auto inBundle = GetInPort(module, 1);
488  auto inReady = GetSubfield(body, inBundle, "ready");
489  auto inValid = GetSubfield(body, inBundle, "valid");
490  auto inData = GetSubfield(body, inBundle, "data");
491 
492  // Output signals
493  auto outBundle = GetOutPort(module, 0);
494  auto outReady = GetSubfield(body, outBundle, "ready");
495  auto outValid = GetSubfield(body, outBundle, "valid");
496  auto outData = GetSubfield(body, outBundle, "data");
497 
498  auto dataReg = Builder_->create<circt::firrtl::RegOp>(
499  Builder_->getUnknownLoc(),
500  GetIntType(node->input(1)->Type().get()),
501  clock,
502  Builder_->getStringAttr("data_reg"));
503  body->push_back(dataReg);
504  // predicate 0 updates register, passes through and consumes input
505  // we always start with predicate 0 due to pred_buf
506  // predicate 1 uses data in register
// pred.ready = out.ready & out.valid: the predicate is consumed whenever the output fires.
507  Connect(body, predReady, AddAndOp(body, outReady, outValid));
// in.ready = out.ready & !pred.data & pred.valid: new data is only taken under predicate 0.
508  Connect(
509  body,
510  inReady,
511  AddAndOp(body, AddAndOp(body, outReady, AddNotOp(body, predData)), predValid));
512 
// out.valid = (pred.data | in.valid) & pred.valid: under predicate 1 the register is always
// available; under predicate 0 the output waits for new input data.
513  Connect(body, outValid, AddAndOp(body, AddOrOp(body, predData, inValid), predValid));
// out.data defaults to the register contents ...
514  Connect(body, outData, dataReg.getResult());
// ... and is overridden by the incoming data when it is valid under predicate 0.
515  auto dataPassThrough = AddAndOp(body, inValid, AddNotOp(body, predData));
516  auto dataPassThroughBody =
517  AddWhenOp(body, dataPassThrough, false).getThenBodyBuilder().getBlock();
518  Connect(dataPassThroughBody, outData, inData);
519 
// The register captures the incoming data whenever the input handshake fires.
520  auto inFire = AddAndOp(body, inReady, inValid);
521  auto inFireBody = AddWhenOp(body, inFire, false).getThenBodyBuilder().getBlock();
522  Connect(inFireBody, dataReg.getResult(), inData);
523 
524  return module;
525 }
526 
// Generates a fork: the data on input 0 is replicated to every output.
// - Constant fork (op->IsConstant()): the input is always ready, and each output mirrors the
//   input's valid and data; the outputs' ready signals are not consulted.
// - Otherwise: one "out<i>_fired_reg" register per output (reset to 0) records that output i
//   has already accepted the current token. The input is consumed ("all_fired") once every
//   output is either ready now or has fired earlier, after which all fired registers clear.
// NOTE(review): the dynamic_cast to ForkOperation is not null-checked; a node with any other
// operation would dereference a null pointer.
527 circt::firrtl::FModuleOp
529 {
530  auto op = dynamic_cast<const jlm::hls::ForkOperation *>(&node->GetOperation());
531  bool isConstant = op->IsConstant();
532  // Create the module and its input/output ports
533  auto module = nodeToModule(node);
534  auto body = module.getBodyBlock();
535 
536  // Input signals
537  auto inBundle = GetInPort(module, 0);
538  auto inReady = GetSubfield(body, inBundle, "ready");
539  auto inValid = GetSubfield(body, inBundle, "valid");
540  auto inData = GetSubfield(body, inBundle, "data");
541 
542  auto oneBitValue = GetConstant(body, 1, 1);
543  auto zeroBitValue = GetConstant(body, 1, 0);
544 
545  //
546  // Output registers
547  //
548  if (isConstant)
549  {
550  Connect(body, inReady, oneBitValue);
551  for (size_t i = 0; i < node->noutputs(); ++i)
552  {
553  // Get the bundle
554  auto port = GetOutPort(module, i);
555  auto portValid = GetSubfield(body, port, "valid");
556  auto portData = GetSubfield(body, port, "data");
557  Connect(body, portValid, inValid);
558  Connect(body, portData, inData);
559  }
560  }
561  else
562  {
563  auto clock = GetClockSignal(module);
564  auto reset = GetResetSignal(module);
565  ::llvm::SmallVector<circt::firrtl::RegResetOp> firedRegs;
566  ::llvm::SmallVector<circt::firrtl::AndPrimOp> whenConditions;
567  // outputs can only fire if input is valid. This should not be necessary, unless other
568  // components misbehave
569  mlir::Value allFired = inValid;
570  for (size_t i = 0; i < node->noutputs(); ++i)
571  {
572  std::string validName("out");
573  validName.append(std::to_string(i));
574  validName.append("_fired_reg");
575  auto firedReg = Builder_->create<circt::firrtl::RegResetOp>(
576  Builder_->getUnknownLoc(),
577  GetIntType(1),
578  clock,
579  reset,
580  zeroBitValue,
581  Builder_->getStringAttr(validName));
582  body->push_back(firedReg);
583  firedRegs.push_back(firedReg);
584 
585  // Get the bundle
586  auto port = GetOutPort(module, i);
587  auto portReady = GetSubfield(body, port, "ready");
588  auto portValid = GetSubfield(body, port, "valid");
589  auto portData = GetSubfield(body, port, "data");
590 
// out.valid = in.valid & !fired: an output is offered the token only until it accepts it.
591  auto notFiredReg = AddNotOp(body, firedReg.getResult());
592  auto andOp = AddAndOp(body, inValid, notFiredReg.getResult());
593  Connect(body, portValid, andOp);
594  Connect(body, portData, inData);
595 
// This output is done if it is ready now or has already fired.
596  auto orOp = AddOrOp(body, portReady, firedReg.getResult());
597  allFired = AddAndOp(body, allFired, orOp);
598 
599  // Conditions needed for the when statements
600  whenConditions.push_back(AddAndOp(body, portReady, portValid));
601  }
602  allFired = AddNodeOp(body, allFired, "all_fired").getResult();
603  Connect(body, inReady, allFired);
604 
605  // When statement
606  auto condition = AddNotOp(body, allFired);
607  auto whenOp = AddWhenOp(body, condition, true);
608  // getThenBlock() causes an error during compilation
609  // So we first get the builder and then its associated body
610  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
611  // Then region: not all outputs done yet, so remember every output that fires this cycle
612  for (size_t i = 0; i < node->noutputs(); i++)
613  {
614  auto nestedWhen = AddWhenOp(thenBody, whenConditions[i], false);
615  auto nestedBody = nestedWhen.getThenBodyBuilder().getBlock();
616  Connect(nestedBody, firedRegs[i].getResult(), oneBitValue);
617  }
618  // Else region: the input token is consumed, so clear all fired registers for the next token
619  auto elseBody = whenOp.getElseBodyBuilder().getBlock();
620  for (size_t i = 0; i < node->noutputs(); i++)
621  {
622  Connect(elseBody, firedRegs[i].getResult(), zeroBitValue);
623  }
624  }
625 
626  return module;
627 }
628 
// Generates a module that waits for all inputs to be valid, then offers output i the data of
// input i (this indexing presumes ninputs() >= noutputs() — confirm). As in the fork, an
// "out<i>_fired_reg" register per output (reset to 0) remembers outputs that have already
// accepted the token. All inputs are consumed together ("all_fired") once every output has
// fired or fires this cycle, after which the fired registers clear.
629 circt::firrtl::FModuleOp
631 {
632  // Create the module and its input/output ports
633  auto module = nodeToModule(node);
634  auto body = module.getBodyBlock();
635 
636  //
637  // Output registers
638  //
639  auto clock = GetClockSignal(module);
640  auto reset = GetResetSignal(module);
641  ::llvm::SmallVector<circt::firrtl::RegResetOp> firedRegs;
642  ::llvm::SmallVector<circt::firrtl::AndPrimOp> whenConditions;
643  auto oneBitValue = GetConstant(body, 1, 1);
644  auto zeroBitValue = GetConstant(body, 1, 0);
// all_ins_valid: AND of the valid signals of every input
645  mlir::Value allInsValid = oneBitValue;
646  for (size_t i = 0; i < node->ninputs(); ++i)
647  {
648  auto inBundle = GetInPort(module, i);
649  // auto inReady = GetSubfield(body, inBundle, "ready");
650  auto inValid = GetSubfield(body, inBundle, "valid");
651  // auto inData = GetSubfield(body, inBundle, "data");
652  allInsValid = AddAndOp(body, allInsValid, inValid);
653  }
654  allInsValid = AddNodeOp(body, allInsValid, "all_ins_valid").getResult();
655  mlir::Value allFired = oneBitValue;
656  for (size_t i = 0; i < node->noutputs(); ++i)
657  {
658  std::string validName("out");
659  validName.append(std::to_string(i));
660  validName.append("_fired_reg");
661  auto firedReg = Builder_->create<circt::firrtl::RegResetOp>(
662  Builder_->getUnknownLoc(),
663  GetIntType(1),
664  clock,
665  reset,
666  zeroBitValue,
667  Builder_->getStringAttr(validName));
668  body->push_back(firedReg);
669  firedRegs.push_back(firedReg);
670 
671  // Get the bundle
672  auto out = GetOutPort(module, i);
673  auto outReady = GetSubfield(body, out, "ready");
674  auto outValid = GetSubfield(body, out, "valid");
675  auto outData = GetSubfield(body, out, "data");
676  auto in = GetInPort(module, i);
677  auto inData = GetSubfield(body, in, "data");
678 
// out.valid = all_ins_valid & !fired; out.data = in_i.data
679  auto notFiredReg = AddNotOp(body, firedReg.getResult());
680  auto andOp = AddAndOp(body, allInsValid, notFiredReg);
681  Connect(body, outValid, andOp);
682  Connect(body, outData, inData);
683 
// This output is done if it fires now (valid & ready) or has already fired.
684  auto orOp = AddOrOp(body, AddAndOp(body, outValid, outReady), firedReg.getResult());
685  allFired = AddAndOp(body, allFired, orOp);
686 
687  // Conditions needed for the when statements
688  whenConditions.push_back(AddAndOp(body, outReady, outValid));
689  }
690  allFired = AddNodeOp(body, allFired, "all_fired").getResult();
// Every input is consumed at the same time, once all outputs are done
691  for (size_t i = 0; i < node->ninputs(); ++i)
692  {
693  auto in = GetInPort(module, i);
694  auto inReady = GetSubfield(body, in, "ready");
695  Connect(body, inReady, allFired);
696  }
697 
698  // When statement
699  auto condition = AddNotOp(body, allFired);
700  auto whenOp = AddWhenOp(body, condition, true);
701  // getThenBlock() causes an error during compilation
702  // So we first get the builder and then its associated body
703  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
704  // Then region: not all outputs done yet, so remember every output that fires this cycle
705  for (size_t i = 0; i < node->noutputs(); i++)
706  {
707  auto nestedWhen = AddWhenOp(thenBody, whenConditions[i], false);
708  auto nestedBody = nestedWhen.getThenBodyBuilder().getBlock();
709  Connect(nestedBody, firedRegs[i].getResult(), oneBitValue);
710  }
711  // Else region: the inputs are consumed, so clear all fired registers
712  auto elseBody = whenOp.getElseBodyBuilder().getBlock();
713  for (size_t i = 0; i < node->noutputs(); i++)
714  {
715  Connect(elseBody, firedRegs[i].getResult(), zeroBitValue);
716  }
717 
718  return module;
719 }
720 
// Generates a demultiplexer for memory responses. Each input port carries a response bundle
// whose "data" field contains an "id" and a "data" payload. A response with id == i is routed
// to output i, where the id is compared against an 8-bit constant. For outputs of state type
// (stores), only valid is driven, not data. Responses whose id matches no output are accepted
// unconditionally. An assertion checks that no two input ports deliver a valid response with
// the same id in the same cycle.
721 circt::firrtl::FModuleOp
723 {
724  // Create the module and its input/output ports
725  auto module = nodeToModule(node, false);
726  auto body = module.getBodyBlock();
727 
728  auto zeroBitValue = GetConstant(body, 1, 0);
729  auto oneBitValue = GetConstant(body, 1, 1);
730 
// Defaults: every output is not valid and its data is invalid
731  for (size_t i = 0; i < node->noutputs(); ++i)
732  {
733  auto outBundle = GetOutPort(module, i);
734  auto outValid = GetSubfield(body, outBundle, "valid");
735  auto outData = GetSubfield(body, outBundle, "data");
736  Connect(body, outValid, zeroBitValue);
737  ConnectInvalid(body, outData);
738  }
739  for (size_t j = 0; j < node->ninputs(); ++j)
740  {
741  mlir::BlockArgument memRes = GetInPort(module, j);
742  auto memResValid = GetSubfield(body, memRes, "valid");
743  auto memResReady = GetSubfield(body, memRes, "ready");
744  auto memResBundle = GetSubfield(body, memRes, "data");
745  auto memResId = GetSubfield(body, memResBundle, "id");
746  auto memResData = GetSubfield(body, memResBundle, "data");
// Width of the response payload, used to decide whether truncation is needed
747  auto portWidth =
748  memResData->getResult(0).getType().cast<circt::firrtl::IntType>().getWidth().value();
749 
// Build a when/else-when chain over the outputs: each test nests in the previous else block
750  auto elseBody = body;
751  for (size_t i = 0; i < node->noutputs(); ++i)
752  {
753  bool isStore = node->output(i)->Type()->Kind() == rvsdg::TypeKind::State;
754  auto outBundle = GetOutPort(module, i);
755  auto outValid = GetSubfield(elseBody, outBundle, "valid");
756  auto outReady = GetSubfield(elseBody, outBundle, "ready");
757  auto outData = GetSubfield(elseBody, outBundle, "data");
758  auto condition =
759  AddAndOp(elseBody, memResValid, AddEqOp(elseBody, GetConstant(elseBody, 8, i), memResId));
760  auto whenOp = AddWhenOp(elseBody, condition, true);
761  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
762  Connect(thenBody, outValid, oneBitValue);
763  Connect(thenBody, memResReady, outReady);
764  // don't connect data for stores
765  if (!isStore)
766  {
// Truncate the payload to the output width when the two differ
767  int nbits = JlmSize(node->output(i)->Type().get());
768  if (nbits == portWidth)
769  {
770  Connect(thenBody, outData, memResData);
771  }
772  else
773  {
774  Connect(thenBody, outData, AddBitsOp(thenBody, memResData, nbits - 1, 0));
775  }
776  }
777  elseBody = whenOp.getElseBodyBuilder().getBlock();
778  }
779 
780  // Connect to ready for other ids - for example stores
781  Connect(elseBody, memResReady, oneBitValue);
782  // Assert we don't get a response to the same ID on several in ports - if this shows up we need
783  // taken logic for outputs
// (The assertion is enabled only while not in reset. NOTE(review): its message string contains
// the typo "reponse".)
784  for (size_t i = 0; i < j; ++i)
785  {
786  mlir::BlockArgument memRes2 = GetInPort(module, i);
787  auto memResValid2 = GetSubfield(body, memRes2, "valid");
788  auto memResBundle2 = GetSubfield(body, memRes2, "data");
789  auto memResId2 = GetSubfield(body, memResBundle2, "id");
790  auto id_assert = Builder_->create<circt::firrtl::AssertOp>(
791  Builder_->getUnknownLoc(),
792  GetClockSignal(module),
793  AddNotOp(
794  body,
795  AddAndOp(
796  body,
797  AddAndOp(body, memResValid, memResValid2),
798  AddEqOp(body, memResId, memResId2))),
799  AddNotOp(body, GetResetSignal(module)),
800  "overlapping reponse id",
801  mlir::ValueRange(),
802  "response_id_assert_" + std::to_string(j) + "_" + std::to_string(i));
803  body->push_back(id_assert);
804  }
805  }
806 
807  return module;
808 }
809 
810 circt::firrtl::FModuleOp
812 {
813  // Create the module and its input/output ports
814  auto module = nodeToModule(node, false);
815  auto body = module.getBodyBlock();
816  auto op = dynamic_cast<const MemoryRequestOperation *>(&node->GetOperation());
817 
818  auto loadTypes = op->GetLoadTypes();
819  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrReadys;
820  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrValids;
821  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrDatas;
822  ::llvm::SmallVector<mlir::Value> loadIds;
823 
824  auto storeTypes = op->GetStoreTypes();
825  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrReadys;
826  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrValids;
827  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrDatas;
828  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataReadys;
829  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataValids;
830  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataDatas;
831  ::llvm::SmallVector<mlir::Value> storeIds;
832  // The ports for loads come first and consist only of addresses.
833  // Stores have both addresses and data
834  size_t id = 0;
835  for (size_t i = 0; i < op->get_nloads(); ++i)
836  {
837  auto bundle = GetInPort(module, i);
838  loadAddrReadys.push_back(GetSubfield(body, bundle, "ready"));
839  loadAddrValids.push_back(GetSubfield(body, bundle, "valid"));
840  loadAddrDatas.push_back(GetSubfield(body, bundle, "data"));
841  loadIds.push_back(GetConstant(body, 8, id));
842  id++;
843  }
844  for (size_t i = op->get_nloads(); i < node->ninputs(); ++i)
845  {
846  // Store
847  auto addrBundle = GetInPort(module, i);
848  storeAddrReadys.push_back(GetSubfield(body, addrBundle, "ready"));
849  storeAddrValids.push_back(GetSubfield(body, addrBundle, "valid"));
850  storeAddrDatas.push_back(GetSubfield(body, addrBundle, "data"));
851  i++;
852  auto dataBundle = GetInPort(module, i);
853  storeDataReadys.push_back(GetSubfield(body, dataBundle, "ready"));
854  storeDataValids.push_back(GetSubfield(body, dataBundle, "valid"));
855  storeDataDatas.push_back(GetSubfield(body, dataBundle, "data"));
856  storeIds.push_back(GetConstant(body, 8, id));
857  id++;
858  }
859 
860  auto zeroBitValue = GetConstant(body, 1, 0);
861  auto oneBitValue = GetConstant(body, 1, 1);
862  ::llvm::SmallVector<mlir::Value> loadGranted(loadTypes->size(), zeroBitValue);
863  ::llvm::SmallVector<mlir::Value> storeGranted(storeTypes->size(), zeroBitValue);
864  for (size_t j = 0; j < node->noutputs(); ++j)
865  {
866  auto reqType = util::assertedCast<const BundleType>(node->output(j)->Type().get());
867  auto hasWrite = reqType->elements_.size() == 5;
868  mlir::BlockArgument memReq = GetOutPort(module, j);
869  mlir::Value memReqData;
870  mlir::Value memReqWrite;
871  auto memReqReady = GetSubfield(body, memReq, "ready");
872  auto memReqValid = GetSubfield(body, memReq, "valid");
873  auto memReqBundle = GetSubfield(body, memReq, "data");
874  auto memReqAddr = GetSubfield(body, memReqBundle, "addr");
875  auto memReqSize = GetSubfield(body, memReqBundle, "size");
876  auto memReqId = GetSubfield(body, memReqBundle, "id");
877  if (hasWrite)
878  {
879  memReqData = GetSubfield(body, memReqBundle, "data");
880  memReqWrite = GetSubfield(body, memReqBundle, "write");
881  }
882  // Default request connection
883  Connect(body, memReqValid, zeroBitValue);
884  ConnectInvalid(body, memReqBundle);
885  mlir::Value previousGranted = zeroBitValue;
886  for (size_t i = 0; i < loadTypes->size(); ++i)
887  {
888  if (j == 0)
889  {
890  Connect(body, loadAddrReadys[i], zeroBitValue);
891  }
892  auto canGrant = AddNotOp(body, AddOrOp(body, previousGranted, loadGranted[i]));
893  auto grant = AddAndOp(body, canGrant, loadAddrValids[i]);
894  auto whenOp = AddWhenOp(body, grant, false);
895  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
896  Connect(thenBody, loadAddrReadys[i], memReqReady);
897  Connect(thenBody, memReqValid, loadAddrValids[i]);
898  Connect(thenBody, memReqAddr, loadAddrDatas[i]);
899  Connect(thenBody, memReqId, loadIds[i]);
900  // No data or write
901  auto loadType = loadTypes->at(i).get();
902  int bitWidth = JlmSize(loadType);
903  int log2Bytes = log2(bitWidth / 8);
904  Connect(thenBody, memReqSize, GetConstant(thenBody, 3, log2Bytes));
905  if (hasWrite)
906  {
907  Connect(thenBody, memReqWrite, zeroBitValue);
908  }
909  // Update for next iteration
910  previousGranted = AddOrOp(body, previousGranted, grant);
911  loadGranted[i] = AddOrOp(body, loadGranted[i], grant);
912  }
913  // Stores
914  for (size_t i = 0; hasWrite && i < storeTypes->size(); ++i)
915  {
916  if (j == 0)
917  {
918  Connect(body, storeAddrReadys[i], zeroBitValue);
919  Connect(body, storeDataReadys[i], zeroBitValue);
920  }
921  auto notOp = AddNotOp(body, AddOrOp(body, previousGranted, storeGranted[i]));
922  auto grant = AddAndOp(body, notOp, storeAddrValids[i]);
923  grant = AddAndOp(body, grant, storeDataValids[i]);
924  auto whenOp = AddWhenOp(body, grant, false);
925  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
926  Connect(thenBody, storeAddrReadys[i], memReqReady);
927  Connect(thenBody, storeDataReadys[i], memReqReady);
928  Connect(thenBody, memReqValid, storeAddrValids[i]);
929  Connect(thenBody, memReqAddr, storeAddrDatas[i]);
930  Connect(thenBody, memReqData, storeDataDatas[i]);
931  // TODO: pad
932  // auto portWidth =
933  // memReqData.getType().cast<circt::firrtl::IntType>().getWidth().value();
934  Connect(thenBody, memReqId, storeIds[i]);
935  // No data or write
936  auto storeType = storeTypes->at(i).get();
937  int bitWidth = JlmSize(storeType);
938  int log2Bytes = log2(bitWidth / 8);
939  Connect(thenBody, memReqSize, GetConstant(thenBody, 3, log2Bytes));
940  Connect(thenBody, memReqWrite, oneBitValue);
941  // Update for next iteration
942  previousGranted = AddOrOp(body, previousGranted, grant);
943  storeGranted[i] = AddOrOp(body, storeGranted[i], grant);
944  }
945  }
946 
947  return module;
948 }
949 
// Generates the FIRRTL module for a LoadOperation or LocalLoadOperation node.
// Port layout as used below (n inputs, m outputs):
//   input 0          address
//   inputs 1..n-2    pass-through ("state") inputs, latched into output registers o1..
//   input n-1        memory response data
//   output 0         loaded data
//   outputs 1..m-2   pass-through ("state") outputs
//   output m-1       memory request (address)
// A request is issued only while no request is outstanding (sent_reg), all inputs are
// valid and every output register is empty. The response is latched into o0 and is also
// bypassed to output 0 in the cycle it arrives.
// NOTE(review): oValidRegs[0]/oDataRegs[0] are used unconditionally, so the node is
// assumed to have at least two outputs.
950 circt::firrtl::FModuleOp
952 {
953  JLM_ASSERT(rvsdg::is<LoadOperation>(node) || rvsdg::is<LocalLoadOperation>(node));
954 
955  // Create the module and its input/output ports
956  auto module = nodeToModule(node, false);
957  auto body = module.getBodyBlock();
958 
959  // Input signals
960  auto inBundleAddr = GetInPort(module, 0);
961  auto inReadyAddr = GetSubfield(body, inBundleAddr, "ready");
962  auto inValidAddr = GetSubfield(body, inBundleAddr, "valid");
963  auto inDataAddr = GetSubfield(body, inBundleAddr, "data");
964 
965  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inReadyStates;
966  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inValidStates;
967  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inDataStates;
968  for (size_t i = 1; i < node->ninputs() - 1; ++i)
969  {
970  auto bundle = GetInPort(module, i);
971  inReadyStates.push_back(GetSubfield(body, bundle, "ready"));
972  inValidStates.push_back(GetSubfield(body, bundle, "valid"));
973  inDataStates.push_back(GetSubfield(body, bundle, "data"));
974  }
975 
976  auto inBundleMemData = GetInPort(module, node->ninputs() - 1);
977  auto inReadyMemData = GetSubfield(body, inBundleMemData, "ready");
978  auto inValidMemData = GetSubfield(body, inBundleMemData, "valid");
979  auto inDataMemData = GetSubfield(body, inBundleMemData, "data");
980 
981  // Output signals
982  auto outBundleData = GetOutPort(module, 0);
983  auto outReadyData = GetSubfield(body, outBundleData, "ready");
984  auto outValidData = GetSubfield(body, outBundleData, "valid");
985  auto outDataData = GetSubfield(body, outBundleData, "data");
986 
987  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outReadyStates;
988  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outValidStates;
989  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outDataStates;
990  for (size_t i = 1; i < node->noutputs() - 1; ++i)
991  {
992  auto bundle = GetOutPort(module, i);
993  outReadyStates.push_back(GetSubfield(body, bundle, "ready"));
994  outValidStates.push_back(GetSubfield(body, bundle, "valid"));
995  outDataStates.push_back(GetSubfield(body, bundle, "data"));
996  }
997 
998  auto outBundleMemAddr = GetOutPort(module, node->noutputs() - 1);
999  auto outReadyMemAddr = GetSubfield(body, outBundleMemAddr, "ready");
1000  auto outValidMemAddr = GetSubfield(body, outBundleMemAddr, "valid");
1001  auto outDataMemAddr = GetSubfield(body, outBundleMemAddr, "data");
1002 
1003  auto clock = GetClockSignal(module);
1004  auto reset = GetResetSignal(module);
1005  auto zeroBitValue = GetConstant(body, 1, 0);
1006  auto oneBitValue = GetConstant(body, 1, 1);
1007 
1008  // Registers
 // One valid/data register pair per output except the last (memory request) output
1009  ::llvm::SmallVector<circt::firrtl::RegResetOp> oValidRegs;
1010  ::llvm::SmallVector<circt::firrtl::RegResetOp> oDataRegs;
1011  for (size_t i = 0; i < node->noutputs() - 1; i++)
1012  {
1013  std::string validName("o");
1014  validName.append(std::to_string(i));
1015  validName.append("_valid_reg");
1016  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1017  Builder_->getUnknownLoc(),
1018  GetIntType(1),
1019  clock,
1020  reset,
1021  zeroBitValue,
1022  Builder_->getStringAttr(validName));
1023  body->push_back(validReg);
1024  oValidRegs.push_back(validReg);
1025 
1026  auto zeroValue = GetConstant(body, JlmSize(node->output(i)->Type().get()), 0);
1027  std::string dataName("o");
1028  dataName.append(std::to_string(i));
1029  dataName.append("_data_reg");
1030  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
1031  Builder_->getUnknownLoc(),
1032  GetIntType(node->output(i)->Type().get()),
1033  clock,
1034  reset,
1035  zeroValue,
1036  Builder_->getStringAttr(dataName));
1037  body->push_back(dataReg);
1038  oDataRegs.push_back(dataReg);
1039  }
 // Set while a memory request is outstanding (between mem_req and mem_res)
1040  auto sentReg = Builder_->create<circt::firrtl::RegResetOp>(
1041  Builder_->getUnknownLoc(),
1042  GetIntType(1),
1043  clock,
1044  reset,
1045  zeroBitValue,
1046  Builder_->getStringAttr("sent_reg"));
1047  body->push_back(sentReg);
1048 
 // canRequest = !sent_reg && all inputs valid && every output register empty
1049  // mlir::Value canRequest = AddOrOp(body, AddNotOp(body, sentReg), AddAndOp(body,
1050  // inValidMemData, outReadyData));
1051  mlir::Value canRequest = AddNotOp(body, sentReg.getResult());
1052  canRequest = AddAndOp(body, canRequest, inValidAddr);
1053  for (auto vld : inValidStates)
1054  {
1055  canRequest = AddAndOp(body, canRequest, vld);
1056  }
1057  // canRequest = AddAndOp(body, canRequest, AddOrOp(body, AddNotOp(body, oValidRegs[0]),
1058  // outReadyData));
1059  canRequest = AddAndOp(body, canRequest, AddNotOp(body, oValidRegs[0].getResult()));
1060  for (size_t i = 1; i < oValidRegs.size(); i++)
1061  {
1062  // canRequest = AddAndOp(body, canRequest, AddOrOp(body, AddNotOp(body, oValidRegs[i]),
1063  // outReadyStates[i-1]));
1064  canRequest = AddAndOp(body, canRequest, AddNotOp(body, oValidRegs[i].getResult()));
1065  }
1066 
1067  // Block until all inputs and no outputs are valid
1068  Connect(body, outValidMemAddr, canRequest);
1069  Connect(body, outDataMemAddr, inDataAddr);
1070 
 // Output 0 defaults to the o0 registers; the mem_res when-block below bypasses the response
1071  Connect(body, outValidData, oValidRegs[0].getResult());
1072  Connect(body, outDataData, oDataRegs[0].getResult());
1073 
 // Pass-through outputs are driven from their registers, which are cleared once consumed
1074  for (size_t i = 1; i < node->noutputs() - 1; ++i)
1075  {
1076  Connect(body, outValidStates[i - 1], oValidRegs[i].getResult());
1077  Connect(body, outDataStates[i - 1], oDataRegs[i].getResult());
1078  auto andOp2 = AddAndOp(body, outReadyStates[i - 1], outValidStates[i - 1]);
1079  Connect(
1080  // When o1 fires
1081  AddWhenOp(body, andOp2, false).getThenBodyBuilder().getBlock(),
1082  oValidRegs[i].getResult(),
1083  zeroBitValue);
1084  }
1085 
1086  // mem_res fire
 // Latch the response into o0 and also forward it to output 0 in the same cycle
1087  auto whenResFireOp = AddWhenOp(body, AddAndOp(body, sentReg.getResult(), inValidMemData), false);
1088  auto whenResFireBody = whenResFireOp.getThenBodyBuilder().getBlock();
1089  Connect(whenResFireBody, sentReg.getResult(), zeroBitValue);
1090  Connect(whenResFireBody, oDataRegs[0].getResult(), inDataMemData);
1091  Connect(whenResFireBody, oValidRegs[0].getResult(), oneBitValue);
1092  Connect(whenResFireBody, outDataData, inDataMemData);
1093  Connect(whenResFireBody, outValidData, oneBitValue);
1094 
 // NOTE(review): this when-block fires on outReadyMemAddr alone (not ready && valid), and the
 // input readys below follow outReadyMemAddr directly. This assumes the memory side raises
 // ready only while the request is valid — confirm against the connected request module.
1095  // mem_req fire
1096  auto whenReqFireOp = AddWhenOp(body, outReadyMemAddr, false);
1097  auto whenReqFireBody = whenReqFireOp.getThenBodyBuilder().getBlock();
1098  Connect(whenReqFireBody, sentReg.getResult(), oneBitValue);
1099  for (size_t i = 1; i < node->noutputs() - 1; ++i)
1100  {
1101  Connect(whenReqFireBody, oValidRegs[i].getResult(), oneBitValue);
1102  Connect(whenReqFireBody, oDataRegs[i].getResult(), inDataStates[i - 1]);
1103  }
1104 
1105  // Handshaking
1106  Connect(body, inReadyAddr, outReadyMemAddr);
1107  for (size_t i = 1; i < node->ninputs() - 1; ++i)
1108  {
1109  Connect(body, inReadyStates[i - 1], outReadyMemAddr);
1110  }
 // A response is only accepted while a request is outstanding
1111  Connect(body, inReadyMemData, sentReg.getResult());
1112 
1113  auto andOp = AddAndOp(body, outReadyData, outValidData);
1114  Connect(
1115  // When o0 fires
1116  AddWhenOp(body, andOp, false).getThenBodyBuilder().getBlock(),
1117  oValidRegs[0].getResult(),
1118  zeroBitValue);
1119 
1120  return module;
1121 }
1122 
1123 circt::firrtl::FModuleOp
1125 {
1126  JLM_ASSERT(rvsdg::is<DecoupledLoadOperation>(node));
1127 
1128  // Create the module and its input/output ports
1129  auto module = nodeToModule(node, false);
1130  auto body = module.getBodyBlock();
1131 
1132  // Input signals
1133  auto inBundleAddr = GetInPort(module, 0);
1134  auto inReadyAddr = GetSubfield(body, inBundleAddr, "ready");
1135  auto inValidAddr = GetSubfield(body, inBundleAddr, "valid");
1136  auto inDataAddr = GetSubfield(body, inBundleAddr, "data");
1137 
1138  auto inBundleMemData = GetInPort(module, node->ninputs() - 1);
1139  auto inReadyMemData = GetSubfield(body, inBundleMemData, "ready");
1140  auto inValidMemData = GetSubfield(body, inBundleMemData, "valid");
1141  auto inDataMemData = GetSubfield(body, inBundleMemData, "data");
1142 
1143  // Output signals
1144  auto outBundleData = GetOutPort(module, 0);
1145  auto outReadyData = GetSubfield(body, outBundleData, "ready");
1146  auto outValidData = GetSubfield(body, outBundleData, "valid");
1147  auto outDataData = GetSubfield(body, outBundleData, "data");
1148 
1149  auto outBundleMemAddr = GetOutPort(module, node->noutputs() - 1);
1150  auto outReadyMemAddr = GetSubfield(body, outBundleMemAddr, "ready");
1151  auto outValidMemAddr = GetSubfield(body, outBundleMemAddr, "valid");
1152  auto outDataMemAddr = GetSubfield(body, outBundleMemAddr, "data");
1153 
1154  // Block until all inputs and no outputs are valid
1155  Connect(body, outValidMemAddr, inValidAddr);
1156  Connect(body, outDataMemAddr, inDataAddr);
1157 
1158  // Handshaking
1159  Connect(body, inReadyAddr, outReadyMemAddr);
1160  Connect(body, inReadyMemData, outReadyData);
1161 
1162  Connect(body, outValidData, inValidMemData);
1163  Connect(body, outDataData, inDataMemData);
1164  AddAndOp(body, outReadyData, outValidData);
1165 
1166  return module;
1167 }
1168 
1169 circt::firrtl::FModuleOp
1171 {
1172  auto lmem_op = util::assertedCast<const LocalMemoryOperation>(&node->GetOperation());
1173  auto res_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(*node->output(0)->Users().begin());
1174  JLM_ASSERT(rvsdg::is<LocalMemoryResponseOperation>(res_node));
1175  auto req_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(*node->output(1)->Users().begin());
1176  JLM_ASSERT(rvsdg::is<LocalMemoryRequestOperation>(req_node));
1177 
1178  // Create the module and its input/output ports - we use a non-standard way here
1179  // Generate a vector with all inputs and outputs of the module
1180  ::llvm::SmallVector<circt::firrtl::PortInfo> ports;
1181  // Clock and reset ports
1182  AddClockPort(&ports);
1183  AddResetPort(&ports);
1184  // Input bundle port
1185  // virtual in/outputs based on request/reponse ports
1186  for (size_t i = 1; i < req_node->ninputs(); ++i)
1187  {
1188  std::string name("i");
1189  name.append(std::to_string(i - 1));
1190  AddBundlePort(
1191  &ports,
1192  circt::firrtl::Direction::In,
1193  name,
1194  GetFirrtlType(req_node->input(i)->Type().get()));
1195  }
1196  for (size_t i = 0; i < res_node->noutputs(); ++i)
1197  {
1198  std::string name("o");
1199  name.append(std::to_string(i));
1200  AddBundlePort(
1201  &ports,
1202  circt::firrtl::Direction::Out,
1203  name,
1204  GetFirrtlType(res_node->output(i)->Type().get()));
1205  }
1206 
1207  // Creat a name for the module
1208  auto nodeName = GetModuleName(node);
1209  mlir::StringAttr name = Builder_->getStringAttr(nodeName);
1210  // Create the module
1211  auto module = Builder_->create<circt::firrtl::FModuleOp>(
1212  Builder_->getUnknownLoc(),
1213  name,
1214  circt::firrtl::ConventionAttr::get(
1215  Builder_->getContext(),
1216  circt::firrtl::Convention::Internal),
1217  ports);
1218 
1219  auto body = module.getBodyBlock();
1220 
1221  size_t loads = rvsdg::TryGetOwnerNode<rvsdg::Node>(*node->output(0)->Users().begin())->noutputs();
1222 
1223  // Input signals
1224  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrReadys;
1225  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrValids;
1226  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadAddrDatas;
1227 
1228  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrReadys;
1229  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrValids;
1230  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeAddrDatas;
1231  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataReadys;
1232  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataValids;
1233  ::llvm::SmallVector<circt::firrtl::SubfieldOp> storeDataDatas;
1234  // the ports for loads come first and consist only of addresses. Stores have both addresses and
1235  // data
1236  for (size_t i = 1; i < req_node->ninputs(); ++i)
1237  {
1238  if (i - 1 < loads)
1239  {
1240  // Load
1241  JLM_ASSERT(storeAddrReadys.empty()); // no stores yet
1242  auto bundle = GetInPort(module, i - 1);
1243  loadAddrReadys.push_back(GetSubfield(body, bundle, "ready"));
1244  loadAddrValids.push_back(GetSubfield(body, bundle, "valid"));
1245  loadAddrDatas.push_back(GetSubfield(body, bundle, "data"));
1246  }
1247  else
1248  {
1249  // Store
1250  auto addrBundle = GetInPort(module, i - 1);
1251  storeAddrReadys.push_back(GetSubfield(body, addrBundle, "ready"));
1252  storeAddrValids.push_back(GetSubfield(body, addrBundle, "valid"));
1253  storeAddrDatas.push_back(GetSubfield(body, addrBundle, "data"));
1254  i++;
1255  auto dataBundle = GetInPort(module, i - 1);
1256  storeDataReadys.push_back(GetSubfield(body, dataBundle, "ready"));
1257  storeDataValids.push_back(GetSubfield(body, dataBundle, "valid"));
1258  storeDataDatas.push_back(GetSubfield(body, dataBundle, "data"));
1259  }
1260  }
1261 
1262  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadDataReadys;
1263  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadDataValids;
1264  ::llvm::SmallVector<circt::firrtl::SubfieldOp> loadDataDatas;
1265  for (size_t i = 0; i < res_node->noutputs(); ++i)
1266  {
1267  auto bundle = GetOutPort(module, i);
1268  loadDataReadys.push_back(GetSubfield(body, bundle, "ready"));
1269  loadDataValids.push_back(GetSubfield(body, bundle, "valid"));
1270  loadDataDatas.push_back(GetSubfield(body, bundle, "data"));
1271  }
1272 
1273  auto clock = GetClockSignal(module);
1274  auto reset = GetResetSignal(module);
1275  auto zeroBitValue = GetConstant(body, 1, 0);
1276  auto oneBitValue = GetConstant(body, 1, 1);
1277 
1278  // memory
1279  auto arraytype = std::dynamic_pointer_cast<const llvm::ArrayType>(lmem_op->result(0));
1280  size_t depth = arraytype->nelements();
1281  auto dataType = GetFirrtlType(&arraytype->element_type());
1282  ::llvm::SmallVector<mlir::Type> memTypes;
1283  ::llvm::SmallVector<mlir::Attribute> memNames;
1284  memTypes.push_back(circt::firrtl::MemOp::getTypeForPort(
1285  depth,
1286  dataType,
1287  circt::firrtl::MemOp::PortKind::ReadWrite));
1288  memNames.push_back(Builder_->getStringAttr("rw0"));
1289  // memTypes.push_back(circt::firrtl::MemOp::getTypeForPort(depth, dataType,
1290  // circt::firrtl::MemOp::PortKind::ReadWrite));
1291  // memNames.push_back(Builder_->getStringAttr("rw1"));
1292  // TODO: figure out why writeLatency is wrong here
1293  auto memory = Builder_->create<circt::firrtl::MemOp>(
1294  Builder_->getUnknownLoc(),
1295  memTypes,
1296  2,
1297  1,
1298  depth,
1299  circt::firrtl::RUWAttr::New,
1300  memNames,
1301  "mem");
1302  body->push_back(memory);
1303  auto rw0 = memory.getPortNamed("rw0");
1304  Connect(body, GetSubfield(body, rw0, "clk"), clock);
1305  auto rw0_wmode = GetSubfield(body, rw0, "wmode");
1306  Connect(body, GetSubfield(body, rw0, "en"), oneBitValue);
1307  Connect(body, GetSubfield(body, rw0, "wmask"), oneBitValue);
1308  auto rw0_addr = GetSubfield(body, rw0, "addr");
1309  auto rw0_rdata = GetSubfield(body, rw0, "rdata");
1310  auto rw0_wdata = GetSubfield(body, rw0, "wdata");
1311  Connect(body, rw0_wdata, GetConstant(body, JlmSize(&arraytype->element_type()), 0));
1312  // auto rw1 = memory.getPortNamed("rw1");
1313  // Connect(body, GetSubfield(body, rw1, "clk"), clock);
1314  int addrwidth = ceil(log2(depth));
1315 
1316  // do stores first, because they pass state edges on directly; having loads first might create a
1317  // combinatorial cycle
1318  for (size_t i = 0; i < storeDataReadys.size(); ++i)
1319  {
1320  Connect(body, storeDataReadys[i], zeroBitValue);
1321  Connect(body, storeAddrReadys[i], zeroBitValue);
1322  }
1323  ::llvm::SmallVector<circt::firrtl::RegResetOp> loadValidRegs;
1324  for (size_t i = 0; i < loadAddrReadys.size(); ++i)
1325  {
1326  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1327  Builder_->getUnknownLoc(),
1328  GetIntType(1),
1329  clock,
1330  reset,
1331  zeroBitValue,
1332  Builder_->getStringAttr("load_valid_" + std::to_string(i)));
1333  body->push_back(validReg);
1334  loadValidRegs.push_back(validReg);
1335  Connect(body, validReg.getResult(), zeroBitValue);
1336  Connect(body, loadDataValids[i], validReg.getResult());
1337  Connect(body, loadDataDatas[i], rw0_rdata);
1338  Connect(body, loadAddrReadys[i], zeroBitValue);
1339  }
1340  // mlir::Value assigned = zeroBitValue;
1341  mlir::Block * elsewhen = body;
1342  for (size_t i = 0; i < storeDataReadys.size(); ++i)
1343  {
1344  auto whenReqFireOp =
1345  AddWhenOp(elsewhen, AddAndOp(elsewhen, storeAddrValids[i], storeDataValids[i]), true);
1346  auto whenReqFireBody = whenReqFireOp.getThenBodyBuilder().getBlock();
1347  Connect(whenReqFireBody, storeDataReadys[i], oneBitValue);
1348  Connect(whenReqFireBody, storeAddrReadys[i], oneBitValue);
1349  Connect(whenReqFireBody, rw0_wmode, oneBitValue);
1350  Connect(whenReqFireBody, rw0_wdata, storeDataDatas[i]);
1351  Connect(
1352  whenReqFireBody,
1353  rw0_addr,
1354  AddBitsOp(whenReqFireBody, storeAddrDatas[i], addrwidth - 1, 0));
1355  elsewhen = whenReqFireOp.getElseBodyBuilder().getBlock();
1356  }
1357  for (size_t i = 0; i < loadAddrReadys.size(); ++i)
1358  {
1359  auto whenReqFireOp = AddWhenOp(elsewhen, loadAddrValids[i], true);
1360  auto whenReqFireBody = whenReqFireOp.getThenBodyBuilder().getBlock();
1361  Connect(whenReqFireBody, loadAddrReadys[i], oneBitValue);
1362  Connect(whenReqFireBody, rw0_wmode, zeroBitValue);
1363  Connect(
1364  whenReqFireBody,
1365  rw0_addr,
1366  AddBitsOp(whenReqFireBody, loadAddrDatas[i], addrwidth - 1, 0));
1367  Connect(whenReqFireBody, loadValidRegs[i].getResult(), oneBitValue);
1368  elsewhen = whenReqFireOp.getElseBodyBuilder().getBlock();
1369  }
1370  Connect(elsewhen, rw0_wmode, zeroBitValue);
1371  Connect(elsewhen, rw0_addr, GetConstant(elsewhen, addrwidth, 0));
1372 
1373  return module;
1374 }
1375 
// Generates the FIRRTL module for a StoreOperation or LocalStoreOperation node.
// Port layout as used below (n inputs, m outputs):
//   input 0          address
//   input 1          data to store
//   inputs 2..n-2    pass-through ("state") inputs; only their valid/ready are used
//   input n-1        memory response (only ready/valid are used)
//   outputs 0..m-3   pass-through ("state") outputs, driven from the response valid
//   output m-2       memory request address
//   output m-1       memory request data
// The address and data requests are both raised once all inputs are valid. The store has no
// internal registers.
// NOTE(review): each state output's valid follows the response valid directly, while the
// response is acknowledged only once all state outputs are ready. With more than one state
// output, a ready output may see repeated transfers while another is stalled — confirm the
// intended fork semantics.
1376 circt::firrtl::FModuleOp
1378 {
1379  JLM_ASSERT(rvsdg::is<StoreOperation>(node) || rvsdg::is<LocalStoreOperation>(node));
1380 
1381  // Create the module and its input/output ports
1382  auto module = nodeToModule(node, false);
1383  auto body = module.getBodyBlock();
1384 
1385  // Input signals
1386  auto inBundleAddr = GetInPort(module, 0);
1387  auto inReadyAddr = GetSubfield(body, inBundleAddr, "ready");
1388  auto inValidAddr = GetSubfield(body, inBundleAddr, "valid");
1389  auto inDataAddr = GetSubfield(body, inBundleAddr, "data");
1390 
1391  auto inBundleData = GetInPort(module, 1);
1392  auto inReadyData = GetSubfield(body, inBundleData, "ready");
1393  auto inValidData = GetSubfield(body, inBundleData, "valid");
1394  auto inDataData = GetSubfield(body, inBundleData, "data");
1395 
 // NOTE(review): inDataStates is collected but never read below
1396  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inReadyStates;
1397  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inValidStates;
1398  ::llvm::SmallVector<circt::firrtl::SubfieldOp> inDataStates;
1399  for (size_t i = 2; i < node->ninputs() - 1; ++i)
1400  {
1401  auto bundle = GetInPort(module, i);
1402  inReadyStates.push_back(GetSubfield(body, bundle, "ready"));
1403  inValidStates.push_back(GetSubfield(body, bundle, "valid"));
1404  inDataStates.push_back(GetSubfield(body, bundle, "data"));
1405  }
1406 
1407  auto inBundleResp = GetInPort(module, node->ninputs() - 1);
1408  auto inReadyResp = GetSubfield(body, inBundleResp, "ready");
1409  auto inValidResp = GetSubfield(body, inBundleResp, "valid");
1410 
1411  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outReadyStates;
1412  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outValidStates;
1413  ::llvm::SmallVector<circt::firrtl::SubfieldOp> outDataStates;
1414  for (size_t i = 0; i < node->noutputs() - 2; ++i)
1415  {
1416  auto bundle = GetOutPort(module, i);
1417  outReadyStates.push_back(GetSubfield(body, bundle, "ready"));
1418  outValidStates.push_back(GetSubfield(body, bundle, "valid"));
1419  outDataStates.push_back(GetSubfield(body, bundle, "data"));
1420  }
1421 
1422  auto outBundleMemAddr = GetOutPort(module, node->noutputs() - 2);
1423  auto outReadyMemAddr = GetSubfield(body, outBundleMemAddr, "ready");
1424  auto outValidMemAddr = GetSubfield(body, outBundleMemAddr, "valid");
1425  auto outDataMemAddr = GetSubfield(body, outBundleMemAddr, "data");
1426 
1427  // Output signals
 // The memory data request's ready is not read; the data input follows the address ready
1428  auto outBundleMemData = GetOutPort(module, node->noutputs() - 1);
1429  auto outValidMemData = GetSubfield(body, outBundleMemData, "valid");
1430  auto outDataMemData = GetSubfield(body, outBundleMemData, "data");
1431 
1432  auto oneBitValue = GetConstant(body, 1, 1);
1433 
 // canRequest = address, data and all state inputs valid
1434  mlir::Value canRequest = inValidAddr;
1435  canRequest = AddAndOp(body, canRequest, inValidData);
1436  for (auto vld : inValidStates)
1437  {
1438  canRequest = AddAndOp(body, canRequest, vld);
1439  }
1440  // TODO: for now just assume that there is always room for state edges
1441  // for (size_t i = 0; i < oValidRegs.size(); ++i)
1442  // {
1443  // // register is empty or being drained
1444  // // canRequest = AddAndOp(body, canRequest, AddOrOp(body, AddNotOp(body,
1445  // oValidRegs[i]),
1446  // // outReadyStates[i]));
1447  // canRequest = AddAndOp(body, canRequest, AddNotOp(body, oValidRegs[i].getResult()));
1448  // }
1449 
1450  // Block until all inputs are valid (output readiness is not checked; see TODO above)
1451  Connect(body, outValidMemAddr, canRequest);
1452  Connect(body, outDataMemAddr, inDataAddr);
1453  Connect(body, outValidMemData, canRequest);
1454  Connect(body, outDataMemData, inDataData);
1455 
 // State outputs carry no data; they signal completion when the memory response arrives.
 // The response is acknowledged once every state output is ready.
1456  mlir::Value outStatesReady = oneBitValue;
1457  for (size_t i = 0; i < node->noutputs() - 2; ++i)
1458  {
1459  Connect(body, outValidStates[i], inValidResp);
1460  ConnectInvalid(body, outDataStates[i]);
1461  outStatesReady = AddAndOp(body, outReadyStates[i], outStatesReady);
1462  }
1463  Connect(body, inReadyResp, outStatesReady);
1464 
1465  // Handshaking
1466  Connect(body, inReadyAddr, outReadyMemAddr);
1467  // TODO: check readyness seperately?
1468  Connect(body, inReadyData, outReadyMemAddr);
1469  for (size_t i = 2; i < node->ninputs() - 1; ++i)
1470  {
1471  Connect(body, inReadyStates[i - 2], outReadyMemAddr);
1472  }
1473  return module;
1474 }
1475 
1476 circt::firrtl::FModuleOp
1478 {
1479  // Create the module and its input/output ports
1480  auto module = nodeToModule(node, true);
1481  auto body = module.getBodyBlock();
1482 
1483  // Check if it's a load or store GetOperation
1484  bool store = dynamic_cast<const llvm::StoreNonVolatileOperation *>(&(node->GetOperation()));
1485 
1486  InitializeMemReq(module);
1487  // Input signals
1488  auto inBundle0 = GetInPort(module, 0);
1489  auto inReady0 = GetSubfield(body, inBundle0, "ready");
1490  auto inValid0 = GetSubfield(body, inBundle0, "valid");
1491  auto inData0 = GetSubfield(body, inBundle0, "data");
1492 
1493  auto inBundle1 = GetInPort(module, 1);
1494  auto inReady1 = GetSubfield(body, inBundle1, "ready");
1495  auto inValid1 = GetSubfield(body, inBundle1, "valid");
1496  auto inData1 = GetSubfield(body, inBundle1, "data");
1497 
1498  // Stores also have a data input that needs to be handled
1499  // The input is not used by loads but code below reference
1500  // these variables so we need to define them
1501  mlir::BlockArgument inBundle2 = NULL;
1502  circt::firrtl::SubfieldOp inReady2 = NULL;
1503  circt::firrtl::SubfieldOp inValid2 = NULL;
1504  circt::firrtl::SubfieldOp inData2 = NULL;
1505  if (store)
1506  {
1507  inBundle2 = GetInPort(module, 2);
1508  inReady2 = GetSubfield(body, inBundle2, "ready");
1509  inValid2 = GetSubfield(body, inBundle2, "valid");
1510  inData2 = GetSubfield(body, inBundle2, "data");
1511  }
1512 
1513  // Output signals
1514  auto outBundle0 = GetOutPort(module, 0);
1515  auto outReady0 = GetSubfield(body, outBundle0, "ready");
1516  auto outValid0 = GetSubfield(body, outBundle0, "valid");
1517  auto outData0 = GetSubfield(body, outBundle0, "data");
1518 
1519  // Mem signals
1520  mlir::BlockArgument memReq = GetPort(module, "mem_req");
1521  mlir::BlockArgument memRes = GetPort(module, "mem_res");
1522 
1523  auto memReqReady = GetSubfield(body, memReq, "ready");
1524  auto memReqValid = GetSubfield(body, memReq, "valid");
1525  auto memReqAddr = GetSubfield(body, memReq, "addr");
1526  auto memReqData = GetSubfield(body, memReq, "data");
1527  auto memReqWrite = GetSubfield(body, memReq, "write");
1528  auto memReqWidth = GetSubfield(body, memReq, "width");
1529 
1530  auto memResValid = GetSubfield(body, memRes, "valid");
1531  auto memResData = GetSubfield(body, memRes, "data");
1532 
1533  auto clock = GetClockSignal(module);
1534  auto reset = GetResetSignal(module);
1535  auto zeroBitValue = GetConstant(body, 1, 0);
1536  auto oneBitValue = GetConstant(body, 1, 1);
1537 
1538  // Registers
1539  ::llvm::SmallVector<circt::firrtl::RegResetOp> oValidRegs;
1540  ::llvm::SmallVector<circt::firrtl::RegResetOp> oDataRegs;
1541  for (size_t i = 0; i < node->noutputs(); i++)
1542  {
1543  std::string validName("o");
1544  validName.append(std::to_string(i));
1545  validName.append("_valid_reg");
1546  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1547  Builder_->getUnknownLoc(),
1548  GetIntType(1),
1549  clock,
1550  reset,
1551  zeroBitValue,
1552  Builder_->getStringAttr(validName));
1553  body->push_back(validReg);
1554  oValidRegs.push_back(validReg);
1555 
1556  auto zeroValue = GetConstant(body, JlmSize(node->output(i)->Type().get()), 0);
1557  std::string dataName("o");
1558  dataName.append(std::to_string(i));
1559  dataName.append("_data_reg");
1560  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
1561  Builder_->getUnknownLoc(),
1562  GetIntType(node->output(i)->Type().get()),
1563  clock,
1564  reset,
1565  zeroValue,
1566  Builder_->getStringAttr(dataName));
1567  body->push_back(dataReg);
1568  oDataRegs.push_back(dataReg);
1569  }
1570  auto sentReg = Builder_->create<circt::firrtl::RegResetOp>(
1571  Builder_->getUnknownLoc(),
1572  GetIntType(1),
1573  clock,
1574  reset,
1575  zeroBitValue,
1576  Builder_->getStringAttr("sent_reg"));
1577  body->push_back(sentReg);
1578 
1579  mlir::Value canRequest = AddNotOp(body, sentReg.getResult());
1580  canRequest = AddAndOp(body, canRequest, inValid0);
1581  canRequest = AddAndOp(body, canRequest, inValid1);
1582  if (store)
1583  {
1584  canRequest = AddAndOp(body, canRequest, inValid2);
1585  }
1586  for (size_t i = 0; i < node->noutputs(); i++)
1587  {
1588  canRequest = AddAndOp(body, canRequest, AddNotOp(body, oValidRegs[i].getResult()));
1589  }
1590 
1591  // Block until all inputs and no outputs are valid
1592  Connect(body, memReqValid, canRequest);
1593  Connect(body, memReqAddr, inData0);
1594 
1595  int bitWidth = 0;
1596  if (store)
1597  {
1598  Connect(body, memReqWrite, oneBitValue);
1599  Connect(body, memReqData, inData1);
1600  bitWidth = std::dynamic_pointer_cast<const rvsdg::BitType>(node->input(1)->Type())->nbits();
1601  }
1602  else
1603  {
1604  Connect(body, memReqWrite, zeroBitValue);
1605  auto invalid = GetInvalid(body, 32);
1606  Connect(body, memReqData, invalid);
1607  if (auto bitType = std::dynamic_pointer_cast<const rvsdg::BitType>(node->output(0)->Type()))
1608  {
1609  bitWidth = bitType->nbits();
1610  }
1611  else if (rvsdg::is<llvm::PointerType>(node->output(0)->Type()))
1612  {
1613  bitWidth = GetPointerSizeInBits();
1614  }
1615  else
1616  {
1617  throw util::Error("unknown width for mem request");
1618  }
1619  }
1620 
1621  int log2Bytes = log2(bitWidth / 8);
1622  Connect(body, memReqWidth, GetConstant(body, 3, log2Bytes));
1623 
1624  // mem_req fire
1625  auto whenReqFireOp = AddWhenOp(body, memReqReady, false);
1626  auto whenReqFireBody = whenReqFireOp.getThenBodyBuilder().getBlock();
1627  Connect(whenReqFireBody, sentReg.getResult(), oneBitValue);
1628  if (store)
1629  {
1630  Connect(whenReqFireBody, oValidRegs[0].getResult(), oneBitValue);
1631  Connect(whenReqFireBody, oDataRegs[0].getResult(), inData2);
1632  }
1633  else
1634  {
1635  Connect(whenReqFireBody, oValidRegs[1].getResult(), oneBitValue);
1636  Connect(whenReqFireBody, oDataRegs[1].getResult(), inData1);
1637  }
1638 
1639  // mem_res fire
1640  auto whenResFireOp = AddWhenOp(body, AddAndOp(body, sentReg.getResult(), memResValid), false);
1641  auto whenResFireBody = whenResFireOp.getThenBodyBuilder().getBlock();
1642  Connect(whenResFireBody, sentReg.getResult(), zeroBitValue);
1643  if (!store)
1644  {
1645  Connect(whenResFireBody, oValidRegs[0].getResult(), oneBitValue);
1646  if (bitWidth != 64)
1647  {
1648  auto bitsOp = AddBitsOp(whenResFireBody, memResData, bitWidth - 1, 0);
1649  Connect(whenResFireBody, oDataRegs[0].getResult(), bitsOp);
1650  }
1651  else
1652  {
1653  Connect(whenResFireBody, oDataRegs[0].getResult(), memResData);
1654  }
1655  }
1656 
1657  // Handshaking
1658  Connect(body, inReady0, memReqReady);
1659  Connect(body, inReady1, memReqReady);
1660  if (store)
1661  {
1662  Connect(body, inReady2, memReqReady);
1663  }
1664 
1665  Connect(body, outValid0, oValidRegs[0].getResult());
1666  Connect(body, outData0, oDataRegs[0].getResult());
1667  auto andOp = AddAndOp(body, outReady0, outValid0);
1668  Connect(
1669  // When o0 fires
1670  AddWhenOp(body, andOp, false).getThenBodyBuilder().getBlock(),
1671  oValidRegs[0].getResult(),
1672  zeroBitValue);
1673  if (!store)
1674  {
1675  auto outBundle1 = GetOutPort(module, 1);
1676  auto outReady1 = GetSubfield(body, outBundle1, "ready");
1677  auto outValid1 = GetSubfield(body, outBundle1, "valid");
1678  auto outData1 = GetSubfield(body, outBundle1, "data");
1679 
1680  Connect(body, outValid1, oValidRegs[1].getResult());
1681  Connect(body, outData1, oDataRegs[1].getResult());
1682  auto andOp = AddAndOp(body, outReady1, outValid1);
1683  Connect(
1684  // When o1 fires
1685  AddWhenOp(body, andOp, false).getThenBodyBuilder().getBlock(),
1686  oValidRegs[1].getResult(),
1687  zeroBitValue);
1688  }
1689 
1690  return module;
1691 }
1692 
1693 circt::firrtl::FModuleOp
1695 {
1696  // Create the module and its input/output ports
1697  auto module = nodeToModule(node);
1698  auto body = module.getBodyBlock();
1699 
1700  // Input signals
1701  auto inBundle0 = GetInPort(module, 0);
1702  auto inReady0 = GetSubfield(body, inBundle0, "ready");
1703  auto inValid0 = GetSubfield(body, inBundle0, "valid");
1704  // auto inData0 = GetSubfield(body, inBundle0, "data");
1705  auto inBundle1 = GetInPort(module, 1);
1706  auto inReady1 = GetSubfield(body, inBundle1, "ready");
1707  auto inValid1 = GetSubfield(body, inBundle1, "valid");
1708  auto inData1 = GetSubfield(body, inBundle1, "data");
1709  // Output signals
1710  auto outBundle = GetOutPort(module, 0);
1711  auto outReady = GetSubfield(body, outBundle, "ready");
1712  auto outValid = GetSubfield(body, outBundle, "valid");
1713  auto outData = GetSubfield(body, outBundle, "data");
1714 
1715  auto andOp0 = AddAndOp(body, outReady, inValid1);
1716  auto andOp1 = AddAndOp(body, outReady, inValid0);
1717  auto andOp2 = AddAndOp(body, inValid0, inValid1);
1718 
1719  Connect(body, inReady0, andOp0);
1720  Connect(body, inReady1, andOp1);
1721  Connect(body, outValid, andOp2);
1722  Connect(body, outData, inData1);
1723 
1724  return module;
1725 }
1726 
1727 circt::firrtl::FModuleOp
1729 {
1730  // Create the module and its input/output ports
1731  auto module = nodeToModule(node);
1732  auto body = module.getBodyBlock();
1733 
1734  auto clock = GetClockSignal(module);
1735  auto reset = GetResetSignal(module);
1736 
1737  // Input signals
1738  auto inBundle = GetInPort(module, 0);
1739  auto inReady = GetSubfield(body, inBundle, "ready");
1740  auto inValid = GetSubfield(body, inBundle, "valid");
1741  auto inData = GetSubfield(body, inBundle, "data");
1742  // Output signals
1743  auto outBundle = GetOutPort(module, 0);
1744  Connect(body, outBundle, inBundle);
1745  auto trigger = AddAndOp(body, AddAndOp(body, inReady, inValid), AddNotOp(body, reset));
1746  auto pn = dynamic_cast<const PrintOperation *>(&node->GetOperation());
1747  auto formatString = "print node " + std::to_string(pn->id()) + ": %x\n";
1748  auto name = "print_node_" + std::to_string(pn->id());
1749  auto printValue = AddPadOp(body, inData, 64);
1750  ::llvm::SmallVector<mlir::Value> operands;
1751  operands.push_back(printValue);
1752  body->push_back(Builder_->create<circt::firrtl::PrintFOp>(
1753  Builder_->getUnknownLoc(),
1754  clock,
1755  trigger,
1756  formatString,
1757  operands,
1758  name));
1759  return module;
1760 }
1761 
1762 circt::firrtl::FModuleOp
1764 {
1765  // Create the module and its input/output ports
1766  auto module = nodeToModule(node);
1767  auto body = module.getBodyBlock();
1768 
1769  auto clock = GetClockSignal(module);
1770  auto reset = GetResetSignal(module);
1771  auto zeroBitValue = GetConstant(body, 1, 0);
1772  auto oneBitValue = GetConstant(body, 1, 1);
1773 
1774  std::string validName("buf_valid_reg");
1775  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1776  Builder_->getUnknownLoc(),
1777  GetIntType(1),
1778  clock,
1779  reset,
1780  oneBitValue,
1781  Builder_->getStringAttr(validName));
1782  body->push_back(validReg);
1783 
1784  std::string dataName("buf_data_reg");
1785  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
1786  Builder_->getUnknownLoc(),
1787  GetIntType(node->input(0)->Type().get()),
1788  clock,
1789  reset,
1790  zeroBitValue,
1791  Builder_->getStringAttr(dataName));
1792  body->push_back(dataReg);
1793 
1794  auto inBundle = GetInPort(module, 0);
1795  auto inReady = GetSubfield(body, inBundle, "ready");
1796  auto inValid = GetSubfield(body, inBundle, "valid");
1797  auto inData = GetSubfield(body, inBundle, "data");
1798 
1799  auto outBundle = GetOutPort(module, 0);
1800  auto outReady = GetSubfield(body, outBundle, "ready");
1801  auto outValid = GetSubfield(body, outBundle, "valid");
1802  auto outData = GetSubfield(body, outBundle, "data");
1803 
1804  auto orOp = AddOrOp(body, validReg.getResult(), inValid);
1805  Connect(body, outValid, orOp);
1806  auto muxOp = AddMuxOp(body, validReg.getResult(), dataReg.getResult(), inData);
1807  Connect(body, outData, muxOp);
1808  auto notOp = AddNotOp(body, validReg.getResult());
1809  Connect(body, inReady, notOp);
1810 
1811  // When
1812  auto condition = AddAndOp(body, inValid, inReady);
1813  auto whenOp = AddWhenOp(body, condition, false);
1814  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
1815  Connect(thenBody, validReg.getResult(), oneBitValue);
1816  Connect(thenBody, dataReg.getResult(), inData);
1817 
1818  // When
1819  condition = AddAndOp(body, outValid, outReady);
1820  whenOp = AddWhenOp(body, condition, false);
1821  thenBody = whenOp.getThenBodyBuilder().getBlock();
1822  Connect(thenBody, validReg.getResult(), zeroBitValue);
1823 
1824  return module;
1825 }
1826 
1827 circt::firrtl::FModuleOp
1829 {
1830  // Create the module and its input/output ports
1831  auto module = nodeToModule(node);
1832  auto body = module.getBodyBlock();
1833 
1834  auto op = dynamic_cast<const BufferOperation *>(&(node->GetOperation()));
1835  auto capacity = op->Capacity();
1836 
1837  auto clock = GetClockSignal(module);
1838  auto reset = GetResetSignal(module);
1839  auto zeroBitValue = GetConstant(body, 1, 0);
1840  auto zeroValue = GetConstant(body, JlmSize(node->input(0)->Type().get()), 0);
1841  auto oneBitValue = GetConstant(body, 1, 1);
1842 
1843  // Registers
1844  ::llvm::SmallVector<circt::firrtl::RegResetOp> validRegs;
1845  ::llvm::SmallVector<circt::firrtl::RegResetOp> dataRegs;
1846  for (size_t i = 0; i <= capacity; i++)
1847  {
1848  std::string validName("buf");
1849  validName.append(std::to_string(i));
1850  validName.append("_valid_reg");
1851  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1852  Builder_->getUnknownLoc(),
1853  GetIntType(1),
1854  clock,
1855  reset,
1856  zeroBitValue,
1857  Builder_->getStringAttr(validName));
1858  body->push_back(validReg);
1859  validRegs.push_back(validReg);
1860 
1861  std::string dataName("buf");
1862  dataName.append(std::to_string(i));
1863  dataName.append("_data_reg");
1864  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
1865  Builder_->getUnknownLoc(),
1866  GetIntType(node->input(0)->Type().get()),
1867  clock,
1868  reset,
1869  zeroValue,
1870  Builder_->getStringAttr(dataName));
1871  body->push_back(dataReg);
1872  dataRegs.push_back(dataReg);
1873  }
1874  // FIXME
1875  // Resource waste as the registers will constantly be set to zero
1876  // This simplifies the code below but might waste resources unless
1877  // the tools are clever anough to replace it with a constant
1878  Connect(body, validRegs[capacity].getResult(), zeroBitValue);
1879  Connect(body, dataRegs[capacity].getResult(), zeroValue);
1880 
1881  // Add wires
1882  ::llvm::SmallVector<circt::firrtl::WireOp> shiftWires;
1883  ::llvm::SmallVector<circt::firrtl::WireOp> consumedWires;
1884  for (size_t i = 0; i <= capacity; i++)
1885  {
1886  std::string shiftName("shift_out");
1887  shiftName.append(std::to_string(i));
1888  shiftWires.push_back(AddWireOp(body, shiftName, 1));
1889  std::string consumedName("in_consumed");
1890  consumedName.append(std::to_string(i));
1891  consumedWires.push_back(AddWireOp(body, consumedName, 1));
1892  }
1893 
1894  auto inBundle = GetInPort(module, 0);
1895  auto inReady = GetSubfield(body, inBundle, "ready");
1896  auto inValid = GetSubfield(body, inBundle, "valid");
1897  auto inData = GetSubfield(body, inBundle, "data");
1898 
1899  auto outBundle = GetOutPort(module, 0);
1900  auto outReady = GetSubfield(body, outBundle, "ready");
1901  auto outValid = GetSubfield(body, outBundle, "valid");
1902  auto outData = GetSubfield(body, outBundle, "data");
1903 
1904  // Connect out to buf0
1905  Connect(body, outValid, validRegs[0].getResult());
1906  Connect(body, outData, dataRegs[0].getResult());
1907  auto andOp = AddAndOp(body, outReady, outValid);
1908  Connect(body, shiftWires[0].getResult(), andOp);
1909  if (op->IsPassThrough())
1910  {
1911  auto notOp = AddNotOp(body, validRegs[0].getResult());
1912  andOp = AddAndOp(body, notOp, outReady);
1913  Connect(body, consumedWires[0].getResult(), andOp);
1914  auto whenOp = AddWhenOp(body, notOp, false);
1915  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
1916  Connect(thenBody, outData, inData);
1917  Connect(thenBody, outValid, inValid);
1918  }
1919  else
1920  {
1921  Connect(body, consumedWires[0].getResult(), zeroBitValue);
1922  }
1923 
1924  // The buffer is ready if the last one is empty
1925  auto notOp = AddNotOp(body, validRegs[capacity - 1].getResult());
1926  Connect(body, inReady, notOp);
1927 
1928  andOp = AddAndOp(body, inReady, inValid);
1929  for (size_t i = 0; i < capacity; ++i)
1930  {
1931  Connect(body, consumedWires[i + 1].getResult(), consumedWires[i].getResult());
1932  Connect(body, shiftWires[i + 1].getResult(), zeroBitValue);
1933 
1934  // When valid reg
1935  auto whenOp = AddWhenOp(body, shiftWires[i].getResult(), false);
1936  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
1937  Connect(thenBody, validRegs[i].getResult(), zeroBitValue);
1938 
1939  // When will be empty
1940  auto notOp = AddNotOp(body, validRegs[i].getResult());
1941  auto condition = AddOrOp(body, shiftWires[i].getResult(), notOp);
1942  whenOp = AddWhenOp(body, condition, false);
1943  thenBody = whenOp.getThenBodyBuilder().getBlock();
1944  // Create the condition needed in nested when
1945  notOp = AddNotOp(thenBody, consumedWires[i].getResult());
1946  auto elseCondition = AddAndOp(thenBody, andOp, notOp);
1947 
1948  // Nested when valid reg
1949  whenOp = AddWhenOp(thenBody, validRegs[i + 1].getResult(), true);
1950  thenBody = whenOp.getThenBodyBuilder().getBlock();
1951  Connect(thenBody, validRegs[i].getResult(), oneBitValue);
1952  Connect(thenBody, dataRegs[i].getResult(), dataRegs[i + 1].getResult());
1953  Connect(thenBody, shiftWires[i + 1].getResult(), oneBitValue);
1954 
1955  // Nested else in available
1956  auto elseBody = whenOp.getElseBodyBuilder().getBlock();
1957  auto nestedWhen = AddWhenOp(elseBody, elseCondition, false);
1958  thenBody = nestedWhen.getThenBodyBuilder().getBlock();
1959  Connect(thenBody, consumedWires[i + 1].getResult(), oneBitValue);
1960  Connect(thenBody, validRegs[i].getResult(), oneBitValue);
1961  Connect(thenBody, dataRegs[i].getResult(), inData);
1962  }
1963 
1964  return module;
1965 }
1966 
1967 circt::firrtl::FModuleOp
1969 {
1970  // Create the module and its input/output ports
1971  auto module = nodeToModule(node);
1972  auto body = module.getBodyBlock();
1973 
1974  auto op = dynamic_cast<const hls::AddressQueueOperation *>(&(node->GetOperation()));
1975  auto capacity = op->capacity;
1976 
1977  auto clock = GetClockSignal(module);
1978  auto reset = GetResetSignal(module);
1979  auto zeroBitValue = GetConstant(body, 1, 0);
1980  auto zeroValue = GetConstant(body, JlmSize(node->input(0)->Type().get()), 0);
1981  auto oneBitValue = GetConstant(body, 1, 1);
1982 
1983  // Registers
1984  ::llvm::SmallVector<circt::firrtl::RegResetOp> validRegs;
1985  ::llvm::SmallVector<circt::firrtl::RegResetOp> dataRegs;
1986  for (size_t i = 0; i <= capacity; i++)
1987  {
1988  std::string validName("buf");
1989  validName.append(std::to_string(i));
1990  validName.append("_valid_reg");
1991  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
1992  Builder_->getUnknownLoc(),
1993  GetIntType(1),
1994  clock,
1995  reset,
1996  zeroBitValue,
1997  Builder_->getStringAttr(validName));
1998  body->push_back(validReg);
1999  validRegs.push_back(validReg);
2000 
2001  std::string dataName("buf");
2002  dataName.append(std::to_string(i));
2003  dataName.append("_data_reg");
2004  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
2005  Builder_->getUnknownLoc(),
2006  GetIntType(node->input(0)->Type().get()),
2007  clock,
2008  reset,
2009  zeroValue,
2010  Builder_->getStringAttr(dataName));
2011  body->push_back(dataReg);
2012  dataRegs.push_back(dataReg);
2013  }
2014  // FIXME
2015  // Resource waste as the registers will constantly be set to zero
2016  // This simplifies the code below but might waste resources unless
2017  // the tools are clever anough to replace it with a constant
2018  Connect(body, validRegs[capacity].getResult(), zeroBitValue);
2019  Connect(body, dataRegs[capacity].getResult(), zeroValue);
2020 
2021  // Add wires
2022  ::llvm::SmallVector<circt::firrtl::WireOp> shiftWires;
2023  ::llvm::SmallVector<circt::firrtl::WireOp> consumedWires;
2024  for (size_t i = 0; i <= capacity; i++)
2025  {
2026  std::string shiftName("shift_out");
2027  shiftName.append(std::to_string(i));
2028  shiftWires.push_back(AddWireOp(body, shiftName, 1));
2029  std::string consumedName("in_consumed");
2030  consumedName.append(std::to_string(i));
2031  consumedWires.push_back(AddWireOp(body, consumedName, 1));
2032  }
2033 
2034  auto checkBundle = GetInPort(module, 0);
2035  auto checkReady = GetSubfield(body, checkBundle, "ready");
2036  auto checkValid = GetSubfield(body, checkBundle, "valid");
2037  auto checkData = GetSubfield(body, checkBundle, "data");
2038 
2039  auto enqBundle = GetInPort(module, 1);
2040  auto enqReady = GetSubfield(body, enqBundle, "ready");
2041  auto enqValid = GetSubfield(body, enqBundle, "valid");
2042  auto enqData = GetSubfield(body, enqBundle, "data");
2043 
2044  auto deqBundle = GetInPort(module, 2);
2045  auto deqReady = GetSubfield(body, deqBundle, "ready");
2046  auto deqValid = GetSubfield(body, deqBundle, "valid");
2047 
2048  auto outBundle = GetOutPort(module, 0);
2049  auto outReady = GetSubfield(body, outBundle, "ready");
2050  auto outValid = GetSubfield(body, outBundle, "valid");
2051  auto outData = GetSubfield(body, outBundle, "data");
2052 
2053  // Connect out to addr
2054  auto addr_in_queue_wire = AddWireOp(body, "addr_in_queue", 1);
2055  auto addr_out_valid = AddAndOp(body, checkValid, AddNotOp(body, addr_in_queue_wire.getResult()));
2056  Connect(body, outValid, addr_out_valid);
2057  Connect(body, outData, checkData);
2058  Connect(body, checkReady, AddAndOp(body, outReady, addr_out_valid));
2059  auto andOp = AddAndOp(body, deqReady, deqValid);
2060 
2061  Connect(body, deqReady, validRegs[0].getResult());
2062  // deq fire
2063  Connect(body, shiftWires[0].getResult(), andOp);
2064  // if (op->pass_through) {
2065  // auto notOp = AddNotOp(body, validRegs[0]);
2066  // andOp = AddAndOp(body, notOp, outReady);
2067  // Connect(body, consumedWires[0], andOp);
2068  // auto whenOp = AddWhenOp(body, notOp, false);
2069  // auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2070  // Connect(thenBody, outData, inData);
2071  // Connect(thenBody, outValid, inValid);
2072  // } else {
2073  Connect(body, consumedWires[0].getResult(), zeroBitValue);
2074  // }
2075 
2076  // The buffer is ready if the last one is empty
2077  auto notOp = AddNotOp(body, validRegs[capacity - 1].getResult());
2078  Connect(body, enqReady, notOp);
2079 
2080  andOp = AddAndOp(body, enqReady, enqValid);
2081  mlir::Value addr_in_queue = zeroBitValue;
2082  for (size_t i = 0; i < capacity; ++i)
2083  {
2084  Connect(body, consumedWires[i + 1].getResult(), consumedWires[i].getResult());
2085  Connect(body, shiftWires[i + 1].getResult(), zeroBitValue);
2086 
2087  // When valid reg
2088  auto whenOp = AddWhenOp(body, shiftWires[i].getResult(), false);
2089  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2090  Connect(thenBody, validRegs[i].getResult(), zeroBitValue);
2091 
2092  // When will be empty
2093  auto notOp = AddNotOp(body, validRegs[i].getResult());
2094  auto condition = AddOrOp(body, shiftWires[i].getResult(), notOp);
2095  whenOp = AddWhenOp(body, condition, false);
2096  thenBody = whenOp.getThenBodyBuilder().getBlock();
2097  // Create the condition needed in nested when
2098  notOp = AddNotOp(thenBody, consumedWires[i].getResult());
2099  auto elseCondition = AddAndOp(thenBody, andOp, notOp);
2100 
2101  // Nested when valid reg
2102  whenOp = AddWhenOp(thenBody, validRegs[i + 1].getResult(), true);
2103  thenBody = whenOp.getThenBodyBuilder().getBlock();
2104  Connect(thenBody, validRegs[i].getResult(), oneBitValue);
2105  Connect(thenBody, dataRegs[i].getResult(), dataRegs[i + 1].getResult());
2106  Connect(thenBody, shiftWires[i + 1].getResult(), oneBitValue);
2107 
2108  // Nested else in available
2109  auto elseBody = whenOp.getElseBodyBuilder().getBlock();
2110  auto nestedWhen = AddWhenOp(elseBody, elseCondition, false);
2111  thenBody = nestedWhen.getThenBodyBuilder().getBlock();
2112  Connect(thenBody, consumedWires[i + 1].getResult(), oneBitValue);
2113  Connect(thenBody, validRegs[i].getResult(), oneBitValue);
2114  Connect(thenBody, dataRegs[i].getResult(), enqData);
2115 
2116  addr_in_queue = AddOrOp(
2117  body,
2118  addr_in_queue,
2119  AddAndOp(
2120  body,
2121  validRegs[i].getResult(),
2122  AddEqOp(body, dataRegs[i].getResult(), checkData)));
2123  }
2124  if (op->combinatorial)
2125  {
2126  // may not be the same as addr enqueued in same cycle
2127  addr_in_queue =
2128  AddOrOp(body, addr_in_queue, AddAndOp(body, enqValid, AddEqOp(body, enqData, checkData)));
2129  }
2130  Connect(body, addr_in_queue_wire.getResult(), addr_in_queue);
2131 
2132  return module;
2133 }
2134 
2135 circt::firrtl::FModuleOp
2137 {
2138  // Create the module and its input/output ports
2139  auto module = nodeToModule(node);
2140  auto body = module.getBodyBlock();
2141 
2142  auto zeroBitValue = GetConstant(body, 1, 0);
2143 
2144  auto inputs = node->ninputs();
2145  auto outBundle = GetOutPort(module, 0);
2146  auto outReady = GetSubfield(body, outBundle, "ready");
2147  // Out valid
2148  auto outValid = GetSubfield(body, outBundle, "valid");
2149  Connect(body, outValid, zeroBitValue);
2150  // Out data
2151  auto invalid = GetInvalid(body, JlmSize(node->output(0)->Type().get()));
2152  auto outData = GetSubfield(body, outBundle, "data");
2153  Connect(body, outData, invalid);
2154  // Input ready 0
2155  auto inBundle0 = GetInPort(module, 0);
2156  auto inReady0 = GetSubfield(body, inBundle0, "ready");
2157  auto inValid0 = GetSubfield(body, inBundle0, "valid");
2158  auto inData0 = GetSubfield(body, inBundle0, "data");
2159 
2160  // Add discard registers
2161  auto clock = GetClockSignal(module);
2162  auto reset = GetResetSignal(module);
2163 
2164  int ctr_bits = 4;
2165  auto ctr_zero = GetConstant(body, ctr_bits, 0);
2166  auto ctr_one = GetConstant(body, ctr_bits, 1);
2167  auto ctr_max = GetConstant(body, ctr_bits, (1 << ctr_bits) - 1);
2168 
2169  ::llvm::SmallVector<mlir::Value> discard_queueds;
2170  ::llvm::SmallVector<circt::firrtl::WireOp> discardWires;
2171  mlir::Value any_discard_full = GetConstant(body, 1, 0);
2172  // each input has a counter that tracks how many tokens to discard
2173  // the discardWires are used to increase these counters
2174  for (size_t i = 1; i < inputs; i++)
2175  {
2176  auto inBundle = GetInPort(module, i);
2177  auto inReady = GetSubfield(body, inBundle, "ready");
2178  auto inValid = GetSubfield(body, inBundle, "valid");
2179 
2180  std::string regName("i");
2181  regName.append(std::to_string(i));
2182  regName.append("_discard_ctr");
2183  auto discard_ctr_reg = Builder_->create<circt::firrtl::RegResetOp>(
2184  Builder_->getUnknownLoc(),
2185  GetIntType(ctr_bits),
2186  clock,
2187  reset,
2188  ctr_zero,
2189  Builder_->getStringAttr(regName));
2190  body->push_back(discard_ctr_reg);
2191 
2192  std::string wireName("i");
2193  wireName.append(std::to_string(i));
2194  wireName.append("_discard");
2195  auto discard_wire = AddWireOp(body, wireName, 1);
2196  discardWires.push_back(discard_wire);
2197  Connect(body, discard_wire.getResult(), zeroBitValue);
2198  auto discard_queued = AddNeqOp(body, discard_ctr_reg.getResult(), ctr_zero);
2199  discard_queueds.push_back(discard_queued);
2200  auto discard_full = AddEqOp(body, discard_ctr_reg.getResult(), ctr_max);
2201  any_discard_full = AddOrOp(body, any_discard_full, discard_full);
2202  auto fire = AddAndOp(body, inReady, inValid);
2203  Connect(body, inReady, AddOrOp(body, discard_queued, discard_wire.getResult()));
2204  auto whenOp = AddWhenOp(
2205  body,
2206  AddAndOp(
2207  body,
2208  AddAndOp(body, discard_queued, fire),
2209  AddNotOp(body, discard_wire.getResult())),
2210  true);
2211  // This connect was a partial connect and is likely to not work
2212  Connect(
2213  &whenOp.getThenBlock(),
2214  discard_ctr_reg.getResult(),
2215  DropMSBs(
2216  &whenOp.getThenBlock(),
2217  AddSubOp(&whenOp.getThenBlock(), discard_ctr_reg.getResult(), ctr_one),
2218  1));
2219  auto elseWhenOp = AddWhenOp(
2220  &whenOp.getElseBlock(),
2221  AddAndOp(
2222  &whenOp.getElseBlock(),
2223  discard_wire.getResult(),
2224  AddNotOp(&whenOp.getElseBlock(), fire)),
2225  false);
2226  // This connect was a partial connect and is likely to not work
2227  Connect(
2228  &elseWhenOp.getThenBlock(),
2229  discard_ctr_reg.getResult(),
2230  DropMSBs(
2231  &elseWhenOp.getThenBlock(),
2232  AddAddOp(&elseWhenOp.getThenBlock(), discard_ctr_reg.getResult(), ctr_one),
2233  1));
2234  }
2235 
2236  auto out_fire = AddAndOp(body, outReady, outValid);
2237  Connect(body, inReady0, out_fire);
2238 
2239  auto matchBlock =
2240  &AddWhenOp(body, AddAndOp(body, inValid0, AddNotOp(body, any_discard_full)), false)
2241  .getThenBlock();
2242  for (size_t i = 1; i < inputs; i++)
2243  {
2244  auto inBundle = GetInPort(module, i);
2245  auto inReady = GetSubfield(matchBlock, inBundle, "ready");
2246  auto inValid = GetSubfield(matchBlock, inBundle, "valid");
2247  auto inData = GetSubfield(matchBlock, inBundle, "data");
2248 
2249  auto whenBlock = &AddWhenOp(
2250  matchBlock,
2251  AddAndOp(
2252  matchBlock,
2253  AddEqOp(matchBlock, inData0, GetConstant(matchBlock, 64, i - 1)),
2254  AddNotOp(matchBlock, discard_queueds[i - 1])),
2255  false)
2256  .getThenBlock();
2257  Connect(whenBlock, outValid, inValid);
2258  Connect(whenBlock, outData, inData);
2259  Connect(whenBlock, inReady, outReady);
2260  for (size_t j = 1; j < inputs; j++)
2261  {
2262  if (i == j)
2263  {
2264  continue;
2265  }
2266  Connect(whenBlock, discardWires[j - 1].getResult(), out_fire);
2267  }
2268  }
2269 
2270  return module;
2271 }
2272 
2273 circt::firrtl::FModuleOp
2275 {
2276  // Create the module and its input/output ports
2277  auto module = nodeToModule(node);
2278  auto body = module.getBodyBlock();
2279 
2280  auto inputs = node->ninputs();
2281  auto outBundle = GetOutPort(module, 0);
2282  auto outReady = GetSubfield(body, outBundle, "ready");
2283  // Out valid
2284  auto outValid = GetSubfield(body, outBundle, "valid");
2285  auto zeroBitValue = GetConstant(body, 1, 0);
2286  Connect(body, outValid, zeroBitValue);
2287  // Out data
2288  auto invalid = GetInvalid(body, JlmSize(node->output(0)->Type().get()));
2289  auto outData = GetSubfield(body, outBundle, "data");
2290  Connect(body, outData, invalid);
2291 
2292  auto inBundle0 = GetInPort(module, 0);
2293  auto inReady0 = GetSubfield(body, inBundle0, "ready");
2294  auto inValid0 = GetSubfield(body, inBundle0, "valid");
2295  Connect(body, inReady0, zeroBitValue);
2296  auto inData0 = GetSubfield(body, inBundle0, "data");
2297 
2298  // We have already handled the first input (i.e., i == 0)
2299  for (size_t i = 1; i < inputs; i++)
2300  {
2301  auto inBundle = GetInPort(module, i);
2302  auto inReady = GetSubfield(body, inBundle, "ready");
2303  auto inValid = GetSubfield(body, inBundle, "valid");
2304  auto inData = GetSubfield(body, inBundle, "data");
2305  Connect(body, inReady, zeroBitValue);
2306  auto constant = GetConstant(body, JlmSize(node->input(0)->Type().get()), i - 1);
2307  auto eqOp = AddEqOp(body, inData0, constant);
2308  auto andOp = AddAndOp(body, inValid0, eqOp);
2309  auto whenOp = AddWhenOp(body, andOp, false);
2310  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2311  Connect(thenBody, outValid, inValid);
2312  Connect(thenBody, outData, inData);
2313  Connect(thenBody, inReady, outReady);
2314  auto whenAnd = AddAndOp(thenBody, outReady, inValid);
2315  Connect(thenBody, inReady0, whenAnd);
2316  }
2317  return module;
2318 }
2319 
2320 circt::firrtl::FModuleOp
2322 {
2323  // Create the module and its input/output ports
2324  auto module = nodeToModule(node);
2325  auto body = module.getBodyBlock();
2326 
2327  auto zeroBitValue = GetConstant(body, 1, 0);
2328 
2329  auto inBundle0 = GetInPort(module, 0);
2330  auto inReady0 = GetSubfield(body, inBundle0, "ready");
2331  auto inValid0 = GetSubfield(body, inBundle0, "valid");
2332  auto inData0 = GetSubfield(body, inBundle0, "data");
2333 
2334  auto inBundle1 = GetInPort(module, 1);
2335  auto inReady1 = GetSubfield(body, inBundle1, "ready");
2336  auto inValid1 = GetSubfield(body, inBundle1, "valid");
2337  auto inData1 = GetSubfield(body, inBundle1, "data");
2338 
2339  Connect(body, inReady0, zeroBitValue);
2340  Connect(body, inReady1, zeroBitValue);
2341 
2342  auto invalid = GetInvalid(body, 1);
2343  for (size_t i = 0; i < node->noutputs(); i++)
2344  {
2345  auto outBundle = GetOutPort(module, i);
2346  auto outReady = GetSubfield(body, outBundle, "ready");
2347  auto outValid = GetSubfield(body, outBundle, "valid");
2348  auto outData = GetSubfield(body, outBundle, "data");
2349  Connect(body, outValid, zeroBitValue);
2350  Connect(body, outData, invalid);
2351 
2352  auto constant = GetConstant(body, JlmSize(node->input(0)->Type().get()), i);
2353  auto eqOp = AddEqOp(body, inData0, constant);
2354  auto condition = AddAndOp(body, inValid0, eqOp);
2355  auto whenOp = AddWhenOp(body, condition, false);
2356  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2357  Connect(thenBody, inReady1, outReady);
2358  auto andOp = AddAndOp(thenBody, outReady, inValid1);
2359  Connect(thenBody, inReady0, andOp);
2360  Connect(thenBody, outValid, inValid1);
2361  Connect(thenBody, outData, inData1);
2362  }
2363 
2364  return module;
2365 }
2366 
2367 circt::firrtl::FModuleLike
2369 {
2370  if (dynamic_cast<const hls::SinkOperation *>(&(node->GetOperation())))
2371  {
2372  return MlirGenSink(node);
2373  }
2374  else if (dynamic_cast<const ForkOperation *>(&(node->GetOperation())))
2375  {
2376  return MlirGenFork(node);
2377  }
2378  else if (rvsdg::is<LoopConstantBufferOperation>(node))
2379  {
2380  return MlirGenLoopConstBuffer(node);
2381  // } else if (dynamic_cast<const jlm::LoadOperation *>(&(node->GetOperation()))) {
2382  // return MlirGenMem(node);
2383  // } else if (dynamic_cast<const jlm::StoreOperation *>(&(node->GetOperati()))) {
2384  // return MlirGenMem(node);
2385  }
2386  else if (dynamic_cast<const LoadOperation *>(&(node->GetOperation())))
2387  {
2388  return MlirGenHlsLoad(node);
2389  }
2390  else if (dynamic_cast<const hls::DecoupledLoadOperation *>(&(node->GetOperation())))
2391  {
2392  return MlirGenExtModule(node);
2393  }
2394  else if (dynamic_cast<const hls::StoreOperation *>(&(node->GetOperation())))
2395  {
2396  return MlirGenHlsStore(node);
2397  }
2398  else if (dynamic_cast<const hls::LocalLoadOperation *>(&(node->GetOperation())))
2399  {
2400  // same as normal load for now, but with index instead of address
2401  return MlirGenHlsLoad(node);
2402  }
2403  if (rvsdg::is<LocalStoreOperation>(node))
2404  {
2405  // same as normal store for now, but with index instead of address
2406  return MlirGenHlsStore(node);
2407  }
2408  else if (dynamic_cast<const hls::LocalMemoryOperation *>(&(node->GetOperation())))
2409  {
2410  return MlirGenHlsLocalMem(node);
2411  }
2412  else if (rvsdg::is<MemoryResponseOperation>(node))
2413  {
2414  return MlirGenHlsMemResp(node);
2415  }
2416  else if (rvsdg::is<MemoryRequestOperation>(node))
2417  {
2418  return MlirGenHlsMemReq(node);
2419  }
2420  else if (jlm::rvsdg::is<const PredicateBufferOperation>(node->GetOperation()))
2421  {
2422  return MlirGenPredicationBuffer(node);
2423  }
2424  else if (auto b = dynamic_cast<const BufferOperation *>(&node->GetOperation()))
2425  {
2426  JLM_ASSERT(b->Capacity());
2427  return MlirGenExtModule(node);
2428  }
2429  else if (dynamic_cast<const hls::BranchOperation *>(&(node->GetOperation())))
2430  {
2431  return MlirGenBranch(node);
2432  }
2433  else if (rvsdg::is<TriggerOperation>(node))
2434  {
2435  return MlirGenTrigger(node);
2436  }
2437  else if (rvsdg::is<StateGateOperation>(node))
2438  {
2439  return MlirGenStateGate(node);
2440  }
2441  else if (dynamic_cast<const PrintOperation *>(&(node->GetOperation())))
2442  {
2443  return MlirGenPrint(node);
2444  }
2445  else if (dynamic_cast<const AddressQueueOperation *>(&(node->GetOperation())))
2446  {
2447  return MlirGenAddrQueue(node);
2448  }
2449  else if (auto o = dynamic_cast<const MuxOperation *>(&(node->GetOperation())))
2450  {
2451  if (o->discarding)
2452  {
2453  return MlirGenSimpleNode(node);
2454  }
2455  else
2456  {
2457  return MlirGenNDMux(node);
2458  }
2459  }
2460  bool is_float = false;
2461  for (size_t i = 0; i < node->ninputs(); ++i)
2462  {
2463  is_float = is_float || rvsdg::is<const llvm::FloatingPointType>(node->input(i)->Type());
2464  }
2465  for (size_t i = 0; i < node->noutputs(); ++i)
2466  {
2467  is_float = is_float || rvsdg::is<const llvm::FloatingPointType>(node->output(i)->Type());
2468  }
2469  if (is_float)
2470  {
2471  return MlirGenExtModule(node);
2472  }
2473  return MlirGenSimpleNode(node);
2474 }
2475 
2476 circt::firrtl::FModuleOp
2477 RhlsToFirrtlConverter::MlirGen(hls::LoopNode * loopNode, mlir::Block * circuitBody)
2478 {
2479  // Create the module and its input/output ports
2480  auto module = nodeToModule(loopNode);
2481  auto body = module.getBodyBlock();
2482 
2483  auto srModule = MlirGen(loopNode->subregion(), circuitBody);
2484  // Instantiate the region
2485  auto instance =
2486  Builder_->create<circt::firrtl::InstanceOp>(Builder_->getUnknownLoc(), srModule, "sr");
2487  body->push_back(instance);
2488  // Connect the Clock
2489  auto clock = GetClockSignal(module);
2490  Connect(body, GetInstancePort(instance, "clk"), clock);
2491  // Connect the Reset
2492  auto reset = GetResetSignal(module);
2493  Connect(body, GetInstancePort(instance, "reset"), reset);
2494  JLM_ASSERT(instance.getNumResults() == module.getNumPorts());
2495 
2496  const size_t clockAndResetOffset = 2;
2497  for (size_t i = 0; i < loopNode->ninputs(); ++i)
2498  {
2499  auto arg = loopNode->input(i)->arguments.begin().ptr();
2500  auto sourcePort = body->getArgument(i + clockAndResetOffset);
2501  Connect(body, GetInstancePort(instance, get_port_name(arg)), sourcePort);
2502  }
2503  for (size_t i = 0; i < loopNode->noutputs(); ++i)
2504  {
2505  auto res = loopNode->output(i)->results.begin().ptr();
2506  auto sinkPort = body->getArgument(i + loopNode->ninputs() + clockAndResetOffset);
2507  Connect(body, sinkPort, GetInstancePort(instance, get_port_name(res)));
2508  }
2509  return module;
2510 }
2511 
2512 circt::firrtl::BitsPrimOp
2513 RhlsToFirrtlConverter::DropMSBs(mlir::Block * body, mlir::Value value, int amount)
2514 {
2515  auto type = value.getType().cast<circt::firrtl::UIntType>();
2516  auto width = type.getWidth();
2517  auto result = AddBitsOp(body, value, width.value() - 1 - amount, 0);
2518  return result;
2519 }
2520 
2521 // Trace the argument back to the "node" generating the value
2522 // Returns the output of a node or the argument of a region that has
2523 // been instantiated as a module
2526 {
2527  // Check if the argument is part of a LoopNode
2528  auto region = arg->region();
2529  auto node = region->node();
2530  if (dynamic_cast<LoopNode *>(node))
2531  {
2532  if (auto ba = dynamic_cast<BackEdgeArgument *>(arg))
2533  {
2534  return ba->result()->origin();
2535  }
2536  else
2537  {
2538  // Check if the argument is connected to an input,
2539  // i.e., if the argument exits the region
2540  JLM_ASSERT(arg->input() != nullptr);
2541  // Check if we are in a nested region and directly
2542  // connected to the outer regions argument
2543  auto origin = arg->input()->origin();
2544  if (auto o = dynamic_cast<rvsdg::RegionArgument *>(origin))
2545  {
2546  // Need to find the source of the outer regions argument
2547  return TraceArgument(o);
2548  }
2549  else if (auto o = dynamic_cast<rvsdg::StructuralOutput *>(origin))
2550  {
2551  // Check if we the input of one LoopNode is connected to the output of another
2552  // StructuralNode, i.e., if the input is connected to the output of another LoopNode
2553  return TraceStructuralOutput(o);
2554  }
2555  // Else we have reached the source
2556  return origin;
2557  }
2558  }
2559  // Reached the argument of a structural node that is not a LoopNode
2560  return arg;
2561 }
2562 
2563 circt::firrtl::FModuleLike
2564 RhlsToFirrtlConverter::MlirGen(rvsdg::Region * subRegion, mlir::Block * circuitBody)
2565 {
2566  // Generate a vector with all inputs and outputs of the module
2567  ::llvm::SmallVector<circt::firrtl::PortInfo> ports;
2568 
2569  // Clock and reset ports
2570  AddClockPort(&ports);
2571  AddResetPort(&ports);
2572  // Argument ports
2573  for (size_t i = 0; i < subRegion->narguments(); ++i)
2574  {
2575  if (!dynamic_cast<BackEdgeArgument *>(subRegion->argument(i)))
2576  {
2577  AddBundlePort(
2578  &ports,
2579  circt::firrtl::Direction::In,
2580  get_port_name(subRegion->argument(i)),
2581  GetFirrtlType(subRegion->argument(i)->Type().get()));
2582  }
2583  }
2584  // Result ports
2585  for (size_t i = 0; i < subRegion->nresults(); ++i)
2586  {
2587  if (!dynamic_cast<BackEdgeResult *>(subRegion->result(i)))
2588  {
2589  AddBundlePort(
2590  &ports,
2591  circt::firrtl::Direction::Out,
2592  get_port_name(subRegion->result(i)),
2593  GetFirrtlType(subRegion->result(i)->Type().get()));
2594  }
2595  }
2596 
2597  // Create a name for the module
2598  auto moduleName = Builder_->getStringAttr("subregion_mod_" + util::strfmt(subRegion));
2599  // Now when we have all the port information we can create the module
2600  auto module = Builder_->create<circt::firrtl::FModuleOp>(
2601  Builder_->getUnknownLoc(),
2602  moduleName,
2603  circt::firrtl::ConventionAttr::get(
2604  Builder_->getContext(),
2605  circt::firrtl::Convention::Internal),
2606  ports);
2607  // Get the body of the module such that we can add contents to the module
2608  auto body = module.getBodyBlock();
2609 
2610  const size_t clockAndResetOffset = 2;
2611 
2612  std::unordered_map<rvsdg::Output *, mlir::Value> output_map;
2613  // Arguments
2614  for (size_t i = 0; i < subRegion->narguments(); ++i)
2615  {
2616  if (dynamic_cast<BackEdgeArgument *>(subRegion->argument(i)))
2617  {
2618  auto bundleType = GetBundleType(GetFirrtlType(subRegion->argument(i)->Type().get()));
2619  auto op = Builder_->create<circt::firrtl::WireOp>(
2620  Builder_->getUnknownLoc(),
2621  bundleType,
2622  get_port_name(subRegion->argument(i)));
2623  body->push_back(op);
2624  output_map[subRegion->argument(i)] = op.getResult();
2625  }
2626  else
2627  {
2628  auto ix = i;
2629  // handle indices of lambdas, that have no inputs and loops, that have backedges
2630  if (!rvsdg::is<rvsdg::LambdaOperation>(subRegion->node()))
2631  {
2632  ix = subRegion->argument(i)->input()->index();
2633  }
2634  auto sourcePort = body->getArgument(ix + clockAndResetOffset);
2635  output_map[subRegion->argument(i)] = sourcePort;
2636  }
2637  }
2638 
2639  auto clock = body->getArgument(0);
2640  auto reset = body->getArgument(1);
2641  // create nod instances and connect their inputs
2642  for (const auto node : rvsdg::TopDownTraverser(subRegion))
2643  {
2644  auto instance = AddInstanceOp(circuitBody, node);
2645  body->push_back(instance);
2646  // Connect clock and reset to the instance
2647  Connect(body, instance->getResult(0), clock);
2648  Connect(body, instance->getResult(1), reset);
2649  // connect inputs
2650  for (size_t i = 0; i < node->ninputs(); ++i)
2651  {
2652  auto sourcePort = output_map[node->input(i)->origin()];
2653  auto sinkPort = instance->getResult(i + clockAndResetOffset);
2654  Connect(body, sinkPort, sourcePort);
2655  }
2656  // map outputs
2657  for (size_t i = 0; i < node->noutputs(); ++i)
2658  {
2659  auto outputPort = instance->getResult(i + node->ninputs() + clockAndResetOffset);
2660  output_map[node->output(i)] = outputPort;
2661  }
2662  }
2663 
2664  for (size_t i = 0; i < subRegion->nresults(); ++i)
2665  {
2666  mlir::Value resultSink;
2667  if (auto ber = dynamic_cast<BackEdgeResult *>(subRegion->result(i)))
2668  {
2669  auto bundleType = GetBundleType(GetFirrtlType(subRegion->result(i)->Type().get()));
2670  auto op = Builder_->create<circt::firrtl::WireOp>(
2671  Builder_->getUnknownLoc(),
2672  bundleType,
2673  get_port_name(subRegion->result(i)));
2674  body->push_back(op);
2675  resultSink = op.getResult();
2676  // connect backedge to its argument
2677  Connect(body, output_map[ber->argument()], resultSink);
2678  }
2679  else
2680  {
2681  auto ix = i;
2682  // handle indices of lambdas, that have no outputs and loops, that have backedges
2683  if (!rvsdg::is<rvsdg::LambdaOperation>(subRegion->node()))
2684  {
2685  ix = subRegion->result(i)->output()->index();
2686  }
2687  resultSink = body->getArgument(ix + module.getNumInputPorts());
2688  }
2689  Connect(body, resultSink, output_map[subRegion->result(i)->origin()]);
2690  }
2691  circuitBody->push_back(module);
2692  return module;
2693 }
2694 
2695 // Trace a structural output back to the "node" generating the value
2696 // Returns the output of the node
2697 rvsdg::Output *
2699 {
2700  auto node = output->node();
2701 
2702  // We are only expecting LoopNode to have a structural output
2703  if (!dynamic_cast<LoopNode *>(node))
2704  {
2705  throw std::logic_error("Expected a hls::LoopNode but found: " + node->DebugString());
2706  }
2707  JLM_ASSERT(output->results.size() == 1);
2708  auto origin = output->results.begin().ptr()->origin();
2709  if (auto o = dynamic_cast<rvsdg::StructuralOutput *>(origin))
2710  {
2711  // Need to trace the output of the nested structural node
2712  return TraceStructuralOutput(o);
2713  }
2714 
2715  if (rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*origin))
2716  {
2717  // Found the source node
2718  return origin;
2719  }
2720  else if (dynamic_cast<rvsdg::RegionArgument *>(origin))
2721  {
2722  throw std::logic_error("Encountered pass through argument - should be eliminated");
2723  }
2724  else
2725  {
2726  throw std::logic_error("Encountered an unexpected output type");
2727  }
2728 }
2729 
2730 // Emit a circuit
2731 circt::firrtl::CircuitOp
2733 {
2734 
2735  // Ensure consistent naming across runs
2736  create_node_names(lambdaNode->subregion());
2737  // The same name is used for the circuit and main module
2738  auto moduleName = Builder_->getStringAttr(
2739  dynamic_cast<llvm::LlvmLambdaOperation &>(lambdaNode->GetOperation()).name() + "_lambda_mod");
2740  // Create the top level FIRRTL circuit
2741  auto circuit = Builder_->create<circt::firrtl::CircuitOp>(Builder_->getUnknownLoc(), moduleName);
2742  // The body will be populated with a list of modules
2743  auto circuitBody = circuit.getBodyBlock();
2744 
2745  // Get the region of the function
2746  auto subRegion = lambdaNode->subregion();
2747 
2748  //
2749  // Add ports
2750  //
2751  // Generate a vector with all inputs and outputs of the module
2752  ::llvm::SmallVector<circt::firrtl::PortInfo> ports;
2753 
2754  // Clock and reset ports
2755  AddClockPort(&ports);
2756  AddResetPort(&ports);
2757 
2758  auto reg_args = get_reg_args(*lambdaNode);
2759  auto reg_results = get_reg_results(*lambdaNode);
2760 
2761  // Input bundle
2762  using BundleElement = circt::firrtl::BundleType::BundleElement;
2763  ::llvm::SmallVector<BundleElement> inputElements;
2764  inputElements.push_back(GetReadyElement());
2765  inputElements.push_back(GetValidElement());
2766 
2767  for (size_t i = 0; i < reg_args.size(); ++i)
2768  {
2769  // don't generate ports for state edges
2770  if (reg_args[i]->Type()->Kind() == rvsdg::TypeKind::State)
2771  continue;
2772  std::string portName("data_");
2773  portName.append(std::to_string(i));
2774  inputElements.push_back(BundleElement(
2775  Builder_->getStringAttr(portName),
2776  false,
2777  GetIntType(reg_args[i]->Type().get())));
2778  }
2779  auto inputType = circt::firrtl::BundleType::get(Builder_->getContext(), inputElements);
2780  struct circt::firrtl::PortInfo iBundle = {
2781  Builder_->getStringAttr("i"), inputType, circt::firrtl::Direction::In, {},
2782  Builder_->getUnknownLoc(),
2783  };
2784  ports.push_back(iBundle);
2785 
2786  // Output bundle
2787  ::llvm::SmallVector<BundleElement> outputElements;
2788  outputElements.push_back(GetReadyElement());
2789  outputElements.push_back(GetValidElement());
2790  for (size_t i = 0; i < reg_results.size(); ++i)
2791  {
2792  // don't generate ports for state edges
2793  if (reg_results[i]->Type()->Kind() == rvsdg::TypeKind::State)
2794  continue;
2795  std::string portName("data_");
2796  portName.append(std::to_string(i));
2797  outputElements.push_back(BundleElement(
2798  Builder_->getStringAttr(portName),
2799  false,
2800  GetIntType(reg_results[i]->Type().get())));
2801  }
2802  auto outputType = circt::firrtl::BundleType::get(Builder_->getContext(), outputElements);
2803  struct circt::firrtl::PortInfo oBundle = {
2804  Builder_->getStringAttr("o"), outputType, circt::firrtl::Direction::Out, {},
2805  Builder_->getUnknownLoc(),
2806  };
2807  ports.push_back(oBundle);
2808 
2809  // Memory ports
2810  auto mem_reqs = get_mem_reqs(*lambdaNode);
2811  auto mem_resps = get_mem_resps(*lambdaNode);
2812  JLM_ASSERT(mem_resps.size() == mem_reqs.size());
2813  for (size_t i = 0; i < mem_reqs.size(); ++i)
2814  {
2815  ::llvm::SmallVector<BundleElement> memElements;
2816 
2817  ::llvm::SmallVector<BundleElement> reqElements;
2818  reqElements.push_back(GetReadyElement());
2819  reqElements.push_back(GetValidElement());
2820  reqElements.push_back(BundleElement(
2821  Builder_->getStringAttr("data"),
2822  false,
2823  GetFirrtlType(mem_reqs[i]->Type().get())));
2824  auto reqType = circt::firrtl::BundleType::get(Builder_->getContext(), reqElements);
2825  memElements.push_back(BundleElement(Builder_->getStringAttr("req"), false, reqType));
2826 
2827  ::llvm::SmallVector<BundleElement> resElements;
2828  resElements.push_back(GetReadyElement());
2829  resElements.push_back(GetValidElement());
2830  resElements.push_back(BundleElement(
2831  Builder_->getStringAttr("data"),
2832  false,
2833  GetFirrtlType(mem_resps[i]->Type().get())));
2834  auto resType = circt::firrtl::BundleType::get(Builder_->getContext(), resElements);
2835  memElements.push_back(BundleElement(Builder_->getStringAttr("res"), true, resType));
2836 
2837  auto memType = circt::firrtl::BundleType::get(Builder_->getContext(), memElements);
2838  struct circt::firrtl::PortInfo memBundle = {
2839  Builder_->getStringAttr("mem_" + std::to_string(i)),
2840  memType,
2841  circt::firrtl::Direction::Out,
2842  {},
2843  Builder_->getUnknownLoc(),
2844  };
2845  ports.push_back(memBundle);
2846  }
2847 
2848  // Now when we have all the port information we can create the module
2849  // The same name is used for the circuit and main module
2850  auto module = Builder_->create<circt::firrtl::FModuleOp>(
2851  Builder_->getUnknownLoc(),
2852  moduleName,
2853  circt::firrtl::ConventionAttr::get(
2854  Builder_->getContext(),
2855  circt::firrtl::Convention::Internal),
2856  ports);
2857  // Get the body of the module such that we can add contents to the module
2858  auto body = module.getBodyBlock();
2859 
2860  // Create a module of the region
2861  auto srModule = MlirGen(subRegion, circuitBody);
2862  // Instantiate the region
2863  auto instance =
2864  Builder_->create<circt::firrtl::InstanceOp>(Builder_->getUnknownLoc(), srModule, "sr");
2865  body->push_back(instance);
2866  // Connect the Clock
2867  auto clock = GetClockSignal(module);
2868  Connect(body, GetInstancePort(instance, "clk"), clock);
2869  // Connect the Reset
2870  auto reset = GetResetSignal(module);
2871  Connect(body, GetInstancePort(instance, "reset"), reset);
2872 
2873  //
2874  // Add registers to the module
2875  //
2876  // Reset when low (0 == false) 1-bit
2877  auto zeroBitValue = GetConstant(body, 1, 0);
2878 
2879  // Input registers
2880  ::llvm::SmallVector<circt::firrtl::RegResetOp> inputValidRegs;
2881  ::llvm::SmallVector<circt::firrtl::RegResetOp> inputDataRegs;
2882  for (size_t i = 0; i < reg_args.size(); ++i)
2883  {
2884  std::string validName("i");
2885  validName.append(std::to_string(i));
2886  validName.append("_valid_reg");
2887  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
2888  Builder_->getUnknownLoc(),
2889  GetIntType(1),
2890  clock,
2891  reset,
2892  zeroBitValue,
2893  Builder_->getStringAttr(validName));
2894  body->push_back(validReg);
2895  inputValidRegs.push_back(validReg);
2896 
2897  std::string dataName("i");
2898  dataName.append(std::to_string(i));
2899  dataName.append("_data_reg");
2900  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
2901  Builder_->getUnknownLoc(),
2902  GetIntType(reg_args[i]->Type().get()),
2903  clock,
2904  reset,
2905  zeroBitValue,
2906  Builder_->getStringAttr(dataName));
2907  body->push_back(dataReg);
2908  inputDataRegs.push_back(dataReg);
2909 
2910  auto port = GetInstancePort(instance, "a" + std::to_string(reg_args[i]->index()));
2911  auto portValid = GetSubfield(body, port, "valid");
2912  Connect(body, portValid, validReg.getResult());
2913  auto portData = GetSubfield(body, port, "data");
2914  Connect(body, portData, dataReg.getResult());
2915 
2916  // When statement
2917  auto portReady = GetSubfield(body, port, "ready");
2918  auto whenCondition = AddAndOp(body, portReady, portValid);
2919  auto whenOp = AddWhenOp(body, whenCondition, false);
2920 
2921  // getThenBlock() cause an error during commpilation
2922  // So we first get the builder and then its associated body
2923  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2924  Connect(thenBody, validReg.getResult(), zeroBitValue);
2925  }
2926 
2927  // Output registers
2928 
2929  // Need to know the number of inputs so we can calculate the
2930  // correct index for outputs
2931  ::llvm::SmallVector<circt::firrtl::RegResetOp> outputValidRegs;
2932  ::llvm::SmallVector<circt::firrtl::RegResetOp> outputDataRegs;
2933 
2934  auto oneBitValue = GetConstant(body, 1, 1);
2935  for (size_t i = 0; i < reg_results.size(); ++i)
2936  {
2937  std::string validName("o");
2938  validName.append(std::to_string(i));
2939  validName.append("_valid_reg");
2940  auto validReg = Builder_->create<circt::firrtl::RegResetOp>(
2941  Builder_->getUnknownLoc(),
2942  GetIntType(1),
2943  clock,
2944  reset,
2945  zeroBitValue,
2946  Builder_->getStringAttr(validName));
2947  body->push_back(validReg);
2948  outputValidRegs.push_back(validReg);
2949 
2950  std::string dataName("o");
2951  dataName.append(std::to_string(i));
2952  dataName.append("_data_reg");
2953  auto dataReg = Builder_->create<circt::firrtl::RegResetOp>(
2954  Builder_->getUnknownLoc(),
2955  GetIntType(reg_results[i]->Type().get()),
2956  clock,
2957  reset,
2958  zeroBitValue,
2959  Builder_->getStringAttr(dataName));
2960  body->push_back(dataReg);
2961  outputDataRegs.push_back(dataReg);
2962 
2963  // Get the bundle
2964  auto port = GetInstancePort(instance, "r" + std::to_string(reg_results[i]->index()));
2965 
2966  auto portReady = GetSubfield(body, port, "ready");
2967  auto notValidReg = Builder_->create<circt::firrtl::NotPrimOp>(
2968  Builder_->getUnknownLoc(),
2969  circt::firrtl::IntType::get(Builder_->getContext(), false, 1),
2970  validReg.getResult());
2971  body->push_back(notValidReg);
2972  Connect(body, portReady, notValidReg);
2973 
2974  // When statement
2975  auto portValid = GetSubfield(body, port, "valid");
2976  auto portData = GetSubfield(body, port, "data");
2977  auto whenCondition = AddAndOp(body, portReady, portValid);
2978  auto whenOp = AddWhenOp(body, whenCondition, false);
2979 
2980  // getThenBlock() cause an error during commpilation
2981  // So we first get the builder and then its associated body
2982  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
2983  Connect(thenBody, validReg.getResult(), oneBitValue);
2984  Connect(thenBody, dataReg.getResult(), portData);
2985  }
2986 
2987  // Create the ready signal for the input bundle
2988  mlir::Value prevAnd = oneBitValue;
2989  for (size_t i = 0; i < inputValidRegs.size(); i++)
2990  {
2991  auto notReg = Builder_->create<circt::firrtl::NotPrimOp>(
2992  Builder_->getUnknownLoc(),
2993  circt::firrtl::IntType::get(Builder_->getContext(), false, 1),
2994  inputValidRegs[i].getResult());
2995  body->push_back(notReg);
2996  auto andOp = AddAndOp(body, notReg, prevAnd);
2997  prevAnd = andOp;
2998  }
2999  auto inBundle = GetPort(module, "i");
3000  auto inReady = GetSubfield(body, inBundle, "ready");
3001  Connect(body, inReady, prevAnd);
3002 
3003  // Create the valid signal for the output bundle
3004  prevAnd = oneBitValue;
3005  for (size_t i = 0; i < outputValidRegs.size(); i++)
3006  {
3007  auto andOp = AddAndOp(body, outputValidRegs[i].getResult(), prevAnd);
3008  prevAnd = andOp;
3009  }
3010  auto outBundle = GetPort(module, "o");
3011  auto outValid = GetSubfield(body, outBundle, "valid");
3012  Connect(body, outValid, prevAnd);
3013 
3014  // Connect output data signals
3015  for (size_t i = 0; i < outputDataRegs.size(); i++)
3016  {
3017  // don't generate ports for state edges
3018  if (reg_results[i]->Type()->Kind() == rvsdg::TypeKind::State)
3019  continue;
3020  auto outData = GetSubfield(body, outBundle, "data_" + std::to_string(i));
3021  Connect(body, outData, outputDataRegs[i].getResult());
3022  }
3023 
3024  if (inputValidRegs.size())
3025  { // avoid generating invalid firrtl for return of just a constant
3026  // Input when statement
3027  auto inValid = GetSubfield(body, inBundle, "valid");
3028  auto whenCondition = AddAndOp(body, inReady, inValid);
3029  auto whenOp = AddWhenOp(body, whenCondition, false);
3030 
3031  // getThenBlock() cause an error during commpilation
3032  // So we first get the builder and then its associated body
3033  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
3034  for (size_t i = 0; i < inputValidRegs.size(); i++)
3035  {
3036  Connect(thenBody, inputValidRegs[i].getResult(), oneBitValue);
3037  // don't generate ports for state edges
3038  if (reg_args[i]->Type()->Kind() == rvsdg::TypeKind::State)
3039  continue;
3040  auto inData = GetSubfield(thenBody, inBundle, "data_" + std::to_string(i));
3041  Connect(thenBody, inputDataRegs[i].getResult(), inData);
3042  }
3043  }
3044 
3045  // Output when statement
3046  auto outReady = GetSubfield(body, outBundle, "ready");
3047  auto whenCondition = AddAndOp(body, outReady, outValid);
3048  auto whenOp = AddWhenOp(body, whenCondition, false);
3049  // getThenBlock() cause an error during commpilation
3050  // So we first get the builder and then its associated body
3051  auto thenBody = whenOp.getThenBodyBuilder().getBlock();
3052  for (size_t i = 0; i < outputValidRegs.size(); i++)
3053  {
3054  Connect(thenBody, outputValidRegs[i].getResult(), zeroBitValue);
3055  }
3056 
3057  // Connect the memory ports
3058  for (size_t i = 0; i < mem_reqs.size(); ++i)
3059  {
3060  auto mem_port = GetPort(module, "mem_" + std::to_string(i));
3061  auto mem_req = GetSubfield(body, mem_port, "req");
3062  auto mem_res = GetSubfield(body, mem_port, "res");
3063  auto inst_req = GetInstancePort(instance, "r" + std::to_string(mem_reqs[i]->index()));
3064  auto inst_res = GetInstancePort(instance, "a" + std::to_string(mem_resps[i]->index()));
3065  Connect(body, mem_req, inst_req);
3066  Connect(body, inst_res, mem_res);
3067  }
3068 
3069  // Add the module to the body of the circuit
3070  circuitBody->push_back(module);
3071 
3072  return circuit;
3073 }
3074 
3075 /*
3076  Helper functions
3077 */
3078 
3079 // Returns a PortInfo of ClockType
3080 void
3081 RhlsToFirrtlConverter::AddClockPort(::llvm::SmallVector<circt::firrtl::PortInfo> * ports)
3082 {
3083  struct circt::firrtl::PortInfo port = {
3084  Builder_->getStringAttr("clk"), circt::firrtl::ClockType::get(Builder_->getContext()),
3085  circt::firrtl::Direction::In, {},
3086  Builder_->getUnknownLoc(),
3087  };
3088  ports->push_back(port);
3089 }
3090 
3091 // Returns a PortInfo of unsigned IntType with width of 1
3092 void
3093 RhlsToFirrtlConverter::AddResetPort(::llvm::SmallVector<circt::firrtl::PortInfo> * ports)
3094 {
3095  struct circt::firrtl::PortInfo port = {
3096  Builder_->getStringAttr("reset"), circt::firrtl::IntType::get(Builder_->getContext(), false, 1),
3097  circt::firrtl::Direction::In, {},
3098  Builder_->getUnknownLoc(),
3099  };
3100  ports->push_back(port);
3101 }
3102 
3103 void
3104 RhlsToFirrtlConverter::AddMemReqPort(::llvm::SmallVector<circt::firrtl::PortInfo> * ports)
3105 {
3106  using BundleElement = circt::firrtl::BundleType::BundleElement;
3107 
3108  ::llvm::SmallVector<BundleElement> memReqElements;
3109  memReqElements.push_back(GetReadyElement());
3110  memReqElements.push_back(GetValidElement());
3111  memReqElements.push_back(BundleElement(
3112  Builder_->getStringAttr("addr"),
3113  false,
3114  circt::firrtl::IntType::get(Builder_->getContext(), false, GetPointerSizeInBits())));
3115  memReqElements.push_back(BundleElement(
3116  Builder_->getStringAttr("data"),
3117  false,
3118  circt::firrtl::IntType::get(Builder_->getContext(), false, 64)));
3119  memReqElements.push_back(BundleElement(
3120  Builder_->getStringAttr("write"),
3121  false,
3122  circt::firrtl::IntType::get(Builder_->getContext(), false, 1)));
3123  memReqElements.push_back(BundleElement(
3124  Builder_->getStringAttr("width"),
3125  false,
3126  circt::firrtl::IntType::get(Builder_->getContext(), false, 3)));
3127 
3128  auto memType = circt::firrtl::BundleType::get(Builder_->getContext(), memReqElements);
3129  struct circt::firrtl::PortInfo memBundle = {
3130  Builder_->getStringAttr("mem_req"), memType, circt::firrtl::Direction::Out, {},
3131  Builder_->getUnknownLoc(),
3132  };
3133  ports->push_back(memBundle);
3134 }
3135 
3136 void
3137 RhlsToFirrtlConverter::AddMemResPort(::llvm::SmallVector<circt::firrtl::PortInfo> * ports)
3138 {
3139  using BundleElement = circt::firrtl::BundleType::BundleElement;
3140 
3141  ::llvm::SmallVector<BundleElement> memResElements;
3142  memResElements.push_back(GetValidElement());
3143  memResElements.push_back(BundleElement(
3144  Builder_->getStringAttr("data"),
3145  false,
3146  circt::firrtl::IntType::get(Builder_->getContext(), false, 64)));
3147 
3148  auto memResType = circt::firrtl::BundleType::get(Builder_->getContext(), memResElements);
3149  struct circt::firrtl::PortInfo memResBundle = {
3150  Builder_->getStringAttr("mem_res"), memResType, circt::firrtl::Direction::In, {},
3151  Builder_->getUnknownLoc(),
3152  };
3153  ports->push_back(memResBundle);
3154 }
3155 
3156 void
3158  ::llvm::SmallVector<circt::firrtl::PortInfo> * ports,
3159  circt::firrtl::Direction direction,
3160  std::string name,
3161  circt::firrtl::FIRRTLBaseType type)
3162 {
3163  auto bundleType = GetBundleType(type);
3164  struct circt::firrtl::PortInfo bundle = {
3165  Builder_->getStringAttr(name), bundleType, direction, {}, Builder_->getUnknownLoc(),
3166  };
3167  ports->push_back(bundle);
3168 }
3169 
3170 circt::firrtl::BundleType
3171 RhlsToFirrtlConverter::GetBundleType(const circt::firrtl::FIRRTLBaseType & type)
3172 {
3173  using BundleElement = circt::firrtl::BundleType::BundleElement;
3174 
3175  ::llvm::SmallVector<BundleElement> elements;
3176  elements.push_back(this->GetReadyElement());
3177  elements.push_back(this->GetValidElement());
3178  elements.push_back(BundleElement(this->Builder_->getStringAttr("data"), false, type));
3179 
3180  return circt::firrtl::BundleType::get(this->Builder_->getContext(), elements);
3181 }
3182 
3183 circt::firrtl::SubfieldOp
3184 RhlsToFirrtlConverter::GetSubfield(mlir::Block * body, mlir::Value value, int index)
3185 {
3186  auto subfield =
3187  Builder_->create<circt::firrtl::SubfieldOp>(Builder_->getUnknownLoc(), value, index);
3188  body->push_back(subfield);
3189  return subfield;
3190 }
3191 
3192 circt::firrtl::SubfieldOp
3194  mlir::Block * body,
3195  mlir::Value value,
3196  ::llvm::StringRef fieldName)
3197 {
3198  auto subfield =
3199  Builder_->create<circt::firrtl::SubfieldOp>(Builder_->getUnknownLoc(), value, fieldName);
3200  body->push_back(subfield);
3201  return subfield;
3202 }
3203 
3204 mlir::BlockArgument
3205 RhlsToFirrtlConverter::GetPort(circt::firrtl::FModuleOp & module, std::string portName)
3206 {
3207  for (size_t i = 0; i < module.getNumPorts(); ++i)
3208  {
3209  if (module.getPortName(i) == portName)
3210  {
3211  return module.getArgument(i);
3212  }
3213  }
3214  llvm_unreachable("port not found");
3215 }
3216 
3217 mlir::OpResult
3218 RhlsToFirrtlConverter::GetInstancePort(circt::firrtl::InstanceOp & instance, std::string portName)
3219 {
3220  for (size_t i = 0; i < instance.getNumResults(); ++i)
3221  {
3222  // std::cout << instance.getPortName(i).str() << std::endl;
3223  if (instance.getPortName(i) == portName)
3224  {
3225  return instance->getResult(i);
3226  }
3227  }
3228  llvm_unreachable("port not found");
3229 }
3230 
3231 mlir::BlockArgument
3232 RhlsToFirrtlConverter::GetInPort(circt::firrtl::FModuleOp & module, size_t portNr)
3233 {
3234  return GetPort(module, "i" + std::to_string(portNr));
3235 }
3236 
3237 mlir::BlockArgument
3238 RhlsToFirrtlConverter::GetOutPort(circt::firrtl::FModuleOp & module, size_t portNr)
3239 {
3240  return GetPort(module, "o" + std::to_string(portNr));
3241 }
3242 
3243 void
3244 RhlsToFirrtlConverter::Connect(mlir::Block * body, mlir::Value sink, mlir::Value source)
3245 {
3246  body->push_back(
3247  Builder_->create<circt::firrtl::ConnectOp>(Builder_->getUnknownLoc(), sink, source));
3248 }
3249 
3250 circt::firrtl::BitsPrimOp
3251 RhlsToFirrtlConverter::AddBitsOp(mlir::Block * body, mlir::Value value, int high, int low)
3252 {
3253  auto intType = Builder_->getIntegerType(32);
3254  auto op = Builder_->create<circt::firrtl::BitsPrimOp>(
3255  Builder_->getUnknownLoc(),
3256  value,
3257  Builder_->getIntegerAttr(intType, high),
3258  Builder_->getIntegerAttr(intType, low));
3259  body->push_back(op);
3260  return op;
3261 }
3262 
3263 circt::firrtl::AndPrimOp
3264 RhlsToFirrtlConverter::AddAndOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3265 {
3266  auto op = Builder_->create<circt::firrtl::AndPrimOp>(Builder_->getUnknownLoc(), first, second);
3267  body->push_back(op);
3268  return op;
3269 }
3270 
3271 circt::firrtl::NodeOp
3272 RhlsToFirrtlConverter::AddNodeOp(mlir::Block * body, mlir::Value value, std::string name)
3273 {
3274  auto op = Builder_->create<circt::firrtl::NodeOp>(Builder_->getUnknownLoc(), value, name);
3275  body->push_back(op);
3276  return op;
3277 }
3278 
3279 circt::firrtl::XorPrimOp
3280 RhlsToFirrtlConverter::AddXorOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3281 {
3282  auto op = Builder_->create<circt::firrtl::XorPrimOp>(Builder_->getUnknownLoc(), first, second);
3283  body->push_back(op);
3284  return op;
3285 }
3286 
3287 circt::firrtl::OrPrimOp
3288 RhlsToFirrtlConverter::AddOrOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3289 {
3290  auto op = Builder_->create<circt::firrtl::OrPrimOp>(Builder_->getUnknownLoc(), first, second);
3291  body->push_back(op);
3292  return op;
3293 }
3294 
3295 circt::firrtl::NotPrimOp
3296 RhlsToFirrtlConverter::AddNotOp(mlir::Block * body, mlir::Value first)
3297 {
3298  auto op = Builder_->create<circt::firrtl::NotPrimOp>(Builder_->getUnknownLoc(), first);
3299  body->push_back(op);
3300  return op;
3301 }
3302 
3303 circt::firrtl::AddPrimOp
3304 RhlsToFirrtlConverter::AddAddOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3305 {
3306  auto op = Builder_->create<circt::firrtl::AddPrimOp>(Builder_->getUnknownLoc(), first, second);
3307  body->push_back(op);
3308  return op;
3309 }
3310 
3311 circt::firrtl::SubPrimOp
3312 RhlsToFirrtlConverter::AddSubOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3313 {
3314  auto op = Builder_->create<circt::firrtl::SubPrimOp>(Builder_->getUnknownLoc(), first, second);
3315  body->push_back(op);
3316  return op;
3317 }
3318 
3319 circt::firrtl::MulPrimOp
3320 RhlsToFirrtlConverter::AddMulOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3321 {
3322  auto op = Builder_->create<circt::firrtl::MulPrimOp>(Builder_->getUnknownLoc(), first, second);
3323  body->push_back(op);
3324  return op;
3325 }
3326 
3327 circt::firrtl::DivPrimOp
3328 RhlsToFirrtlConverter::AddDivOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3329 {
3330  auto op = Builder_->create<circt::firrtl::DivPrimOp>(Builder_->getUnknownLoc(), first, second);
3331  body->push_back(op);
3332  return op;
3333 }
3334 
3335 circt::firrtl::DShrPrimOp
3336 RhlsToFirrtlConverter::AddDShrOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3337 {
3338  auto op = Builder_->create<circt::firrtl::DShrPrimOp>(Builder_->getUnknownLoc(), first, second);
3339  body->push_back(op);
3340  return op;
3341 }
3342 
3343 circt::firrtl::DShlPrimOp
3344 RhlsToFirrtlConverter::AddDShlOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3345 {
3346  auto op = Builder_->create<circt::firrtl::DShlPrimOp>(Builder_->getUnknownLoc(), first, second);
3347  body->push_back(op);
3348  return op;
3349 }
3350 
3351 circt::firrtl::RemPrimOp
3352 RhlsToFirrtlConverter::AddRemOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3353 {
3354  auto op = Builder_->create<circt::firrtl::RemPrimOp>(Builder_->getUnknownLoc(), first, second);
3355  body->push_back(op);
3356  return op;
3357 }
3358 
3359 circt::firrtl::EQPrimOp
3360 RhlsToFirrtlConverter::AddEqOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3361 {
3362  auto op = Builder_->create<circt::firrtl::EQPrimOp>(Builder_->getUnknownLoc(), first, second);
3363  body->push_back(op);
3364  return op;
3365 }
3366 
3367 circt::firrtl::NEQPrimOp
3368 RhlsToFirrtlConverter::AddNeqOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3369 {
3370  auto op = Builder_->create<circt::firrtl::NEQPrimOp>(Builder_->getUnknownLoc(), first, second);
3371  body->push_back(op);
3372  return op;
3373 }
3374 
3375 circt::firrtl::GTPrimOp
3376 RhlsToFirrtlConverter::AddGtOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3377 {
3378  auto op = Builder_->create<circt::firrtl::GTPrimOp>(Builder_->getUnknownLoc(), first, second);
3379  body->push_back(op);
3380  return op;
3381 }
3382 
3383 circt::firrtl::GEQPrimOp
3384 RhlsToFirrtlConverter::AddGeqOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3385 {
3386  auto op = Builder_->create<circt::firrtl::GEQPrimOp>(Builder_->getUnknownLoc(), first, second);
3387  body->push_back(op);
3388  return op;
3389 }
3390 
3391 circt::firrtl::LTPrimOp
3392 RhlsToFirrtlConverter::AddLtOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3393 {
3394  auto op = Builder_->create<circt::firrtl::LTPrimOp>(Builder_->getUnknownLoc(), first, second);
3395  body->push_back(op);
3396  return op;
3397 }
3398 
3399 circt::firrtl::LEQPrimOp
3400 RhlsToFirrtlConverter::AddLeqOp(mlir::Block * body, mlir::Value first, mlir::Value second)
3401 {
3402  auto op = Builder_->create<circt::firrtl::LEQPrimOp>(Builder_->getUnknownLoc(), first, second);
3403  body->push_back(op);
3404  return op;
3405 }
3406 
3407 circt::firrtl::MuxPrimOp
3409  mlir::Block * body,
3410  mlir::Value select,
3411  mlir::Value high,
3412  mlir::Value low)
3413 {
3414  auto op =
3415  Builder_->create<circt::firrtl::MuxPrimOp>(Builder_->getUnknownLoc(), select, high, low);
3416  body->push_back(op);
3417  return op;
3418 }
3419 
3420 circt::firrtl::AsSIntPrimOp
3421 RhlsToFirrtlConverter::AddAsSIntOp(mlir::Block * body, mlir::Value value)
3422 {
3423  auto op = Builder_->create<circt::firrtl::AsSIntPrimOp>(Builder_->getUnknownLoc(), value);
3424  body->push_back(op);
3425  return op;
3426 }
3427 
3428 circt::firrtl::AsUIntPrimOp
3429 RhlsToFirrtlConverter::AddAsUIntOp(mlir::Block * body, mlir::Value value)
3430 {
3431  auto op = Builder_->create<circt::firrtl::AsUIntPrimOp>(Builder_->getUnknownLoc(), value);
3432  body->push_back(op);
3433  return op;
3434 }
3435 
3436 circt::firrtl::PadPrimOp
3437 RhlsToFirrtlConverter::AddPadOp(mlir::Block * body, mlir::Value value, int amount)
3438 {
3439  auto op = Builder_->create<circt::firrtl::PadPrimOp>(Builder_->getUnknownLoc(), value, amount);
3440  body->push_back(op);
3441  return op;
3442 }
3443 
3444 circt::firrtl::CvtPrimOp
3445 RhlsToFirrtlConverter::AddCvtOp(mlir::Block * body, mlir::Value value)
3446 {
3447  auto op = Builder_->create<circt::firrtl::CvtPrimOp>(Builder_->getUnknownLoc(), value);
3448  body->push_back(op);
3449  return op;
3450 }
3451 
3452 circt::firrtl::WireOp
3453 RhlsToFirrtlConverter::AddWireOp(mlir::Block * body, std::string name, int size)
3454 {
3455  auto op =
3456  Builder_->create<circt::firrtl::WireOp>(Builder_->getUnknownLoc(), GetIntType(size), name);
3457  body->push_back(op);
3458  return op;
3459 }
3460 
3461 circt::firrtl::WhenOp
3462 RhlsToFirrtlConverter::AddWhenOp(mlir::Block * body, mlir::Value condition, bool elseStatement)
3463 {
3464  auto op =
3465  Builder_->create<circt::firrtl::WhenOp>(Builder_->getUnknownLoc(), condition, elseStatement);
3466  body->push_back(op);
3467  return op;
3468 }
3469 
3470 void
3472  mlir::Value value,
3473  ::llvm::SmallPtrSet<mlir::Value, 16> & forbiddenDependencies,
3474  ::llvm::SmallPtrSet<mlir::Value, 16> & visited)
3475 {
3476  if (visited.contains(value))
3477  {
3478  return;
3479  }
3480  visited.insert(value);
3481  if (forbiddenDependencies.contains(value))
3482  {
3483  throw util::Error("forbidden dependency detected");
3484  }
3485  auto op = value.getDefiningOp();
3486  // don't check anything for registers - connects don't count since they don't form combinatorial
3487  // circuits
3488  if (mlir::dyn_cast<circt::firrtl::RegResetOp>(op))
3489  {
3490  return;
3491  }
3492  else if (mlir::dyn_cast<circt::firrtl::RegOp>(op))
3493  {
3494  return;
3495  }
3496  // check uses because of connects
3497  for (auto & use : value.getUses())
3498  {
3499  auto * user = use.getOwner();
3500  if (auto connectOp = mlir::dyn_cast<circt::firrtl::ConnectOp>(user))
3501  {
3502  if (connectOp.getDest() == value)
3503  {
3504  check_may_not_depend_on(connectOp.getSrc(), forbiddenDependencies, visited);
3505  }
3506  }
3507  else
3508  {
3509  }
3510  }
3511  // stop at port level
3512  if (mlir::dyn_cast<circt::firrtl::SubfieldOp>(op))
3513  {
3514  return;
3515  }
3516  JLM_ASSERT(op->getNumResults() == 1);
3517  for (size_t i = 0; i < op->getNumOperands(); ++i)
3518  {
3519  check_may_not_depend_on(op->getOperand(i), forbiddenDependencies, visited);
3520  }
3521 }
3522 
3523 void
3525  ::llvm::SmallVector<mlir::Value> & oReadys,
3526  ::llvm::SmallVector<mlir::Value> & oValids)
3527 {
3528  ::llvm::SmallPtrSet<mlir::Value, 16> forbiddenDependencies(oReadys.begin(), oReadys.end());
3529  for (auto oValid : oValids)
3530  {
3531  ::llvm::SmallPtrSet<mlir::Value, 16> visited;
3532  check_may_not_depend_on(oValid, forbiddenDependencies, visited);
3533  }
3534 }
3535 
3536 void
3537 RhlsToFirrtlConverter::check_module(circt::firrtl::FModuleOp & module)
3538 {
3539  // check if module/node obeys ready/valid semantics at the circuit level
3540 
3541  // compile time: ovalid and odata may not depend on oready
3542  ::llvm::SmallVector<mlir::Value> oReadys;
3543  ::llvm::SmallVector<mlir::Value> oValids;
3544  ::llvm::SmallVector<mlir::Value> oDatas;
3545  for (size_t i = 0; i < module.getNumPorts(); ++i)
3546  {
3547  auto portName = module.getPortName(i);
3548  auto port = module.getArgument(i);
3549  if (portName.starts_with("o"))
3550  {
3551  // out port
3552  for (auto & use : port.getUses())
3553  {
3554  auto * user = use.getOwner();
3555  if (auto subfieldOp = mlir::dyn_cast<circt::firrtl::SubfieldOp>(user))
3556  {
3557  auto subfieldName =
3558  subfieldOp.getInput().getType().cast<circt::firrtl::BundleType>().getElementName(
3559  subfieldOp.getFieldIndex());
3560  if (subfieldName == "ready")
3561  {
3562  oReadys.push_back(subfieldOp);
3563  }
3564  else if (subfieldName == "valid")
3565  {
3566  oValids.push_back(subfieldOp);
3567  }
3568  else if (subfieldName == "data")
3569  {
3570  oDatas.push_back(subfieldOp);
3571  }
3572  }
3573  else
3574  {
3575  user->print(::llvm::outs());
3576  llvm_unreachable("unexpected GetOperation");
3577  }
3578  }
3579  }
3580  }
3581  check_oValids(oReadys, oValids);
3582  check_oValids(oReadys, oDatas);
3583 
3584 #ifdef FIRRTL_RUNTIME_ASSERTIONS
3585  // run time: valid/ready may not go down without firing once they are up - insert assertions
3586  auto body = &module.getBody().back();
3587  auto clock = GetClockSignal(module);
3588  auto reset = GetResetSignal(module);
3589  auto zeroBitValue = GetConstant(body, 1, 0);
3590  for (size_t i = 0; i < module.getNumPorts(); ++i)
3591  {
3592  auto portName = module.getPortName(i);
3593  auto port = module.getArgument(i);
3594  if (portName.starts_with("o") || portName.starts_with("i"))
3595  {
3596  auto ready = GetSubfield(body, port, "ready");
3597  auto valid = GetSubfield(body, port, "valid");
3598  auto data = GetSubfield(body, port, "data");
3599  if (data.getResult().getType().dyn_cast<circt::firrtl::BundleType>())
3600  {
3601  // skip memory ports
3602  continue;
3603  }
3604  auto fire = AddAndOp(body, ready, valid);
3605  auto prev_ready_reg = Builder_->create<circt::firrtl::RegResetOp>(
3606  Builder_->getUnknownLoc(),
3607  GetIntType(1),
3608  clock,
3609  reset,
3610  zeroBitValue,
3611  std::string(portName) + "_prev_ready_reg");
3612  body->push_back(prev_ready_reg);
3613  auto prev_valid_reg = Builder_->create<circt::firrtl::RegResetOp>(
3614  Builder_->getUnknownLoc(),
3615  GetIntType(1),
3616  clock,
3617  reset,
3618  zeroBitValue,
3619  std::string(portName) + "_prev_valid_reg");
3620  body->push_back(prev_valid_reg);
3621  auto prev_data_reg = Builder_->create<circt::firrtl::RegOp>(
3622  Builder_->getUnknownLoc(),
3623  data.getResult().getType(),
3624  clock,
3625  std::string(portName) + "_prev_data_reg");
3626  body->push_back(prev_data_reg);
3627  Connect(body, prev_ready_reg.getResult(), ready);
3628  Connect(body, prev_valid_reg.getResult(), valid);
3629  Connect(body, prev_data_reg.getResult(), data);
3630  auto fireBody = &AddWhenOp(body, fire, false).getThenBlock();
3631  Connect(fireBody, prev_ready_reg.getResult(), zeroBitValue);
3632  Connect(fireBody, prev_valid_reg.getResult(), zeroBitValue);
3633 
3634  auto valid_assert = Builder_->create<circt::firrtl::AssertOp>(
3635  Builder_->getUnknownLoc(),
3636  clock,
3637  AddNotOp(body, AddAndOp(body, prev_valid_reg.getResult(), AddNotOp(body, valid))),
3638  AddNotOp(body, reset),
3639  std::string(portName) + "_valid went down without firing",
3640  mlir::ValueRange(),
3641  std::string(portName) + "_valid_assert");
3642  body->push_back(valid_assert);
3643 
3644  auto ready_assert = Builder_->create<circt::firrtl::AssertOp>(
3645  Builder_->getUnknownLoc(),
3646  clock,
3647  AddNotOp(body, AddAndOp(body, prev_ready_reg.getResult(), AddNotOp(body, ready))),
3648  AddNotOp(body, reset),
3649  std::string(portName) + "_ready went down without firing",
3650  mlir::ValueRange(),
3651  std::string(portName) + "_ready_assert");
3652  body->push_back(ready_assert);
3653 
3654  auto data_assert = Builder_->create<circt::firrtl::AssertOp>(
3655  Builder_->getUnknownLoc(),
3656  clock,
3657  AddNotOp(
3658  body,
3659  AddAndOp(
3660  body,
3661  prev_valid_reg.getResult(),
3662  AddNeqOp(body, prev_data_reg.getResult(), data))),
3663  AddNotOp(body, reset),
3664  std::string(portName) + "_data changed without firing",
3665  mlir::ValueRange(),
3666  std::string(portName) + "_data_assert");
3667  body->push_back(data_assert);
3668  }
3669  }
3670 #endif // FIRRTL_RUNTIME_ASSERTIONS
3671 }
3672 
3673 circt::firrtl::InstanceOp
3674 RhlsToFirrtlConverter::AddInstanceOp(mlir::Block * circuitBody, jlm::rvsdg::Node * node)
3675 {
3676  auto name = GetModuleName(node);
3677  // Check if the module has already been instantiated else we need to generate it
3678  if (auto sn = dynamic_cast<rvsdg::SimpleNode *>(node))
3679  {
3680  if (!modules[name])
3681  {
3682  auto module = MlirGen(sn);
3683  if (circt::isa<circt::firrtl::FModuleOp>(module))
3684  check_module(circt::cast<circt::firrtl::FModuleOp>(module));
3685  modules[name] = module;
3686  circuitBody->push_back(module);
3687  }
3688  }
3689  else
3690  {
3691  auto ln = dynamic_cast<LoopNode *>(node);
3692  JLM_ASSERT(ln);
3693  auto module = MlirGen(ln, circuitBody);
3694  modules[name] = module;
3695  circuitBody->push_back(module);
3696  }
3697  // We increment a counter for each node that is instantiated
3698  // to assure the name is unique while still being relatively
3699  // easy to read (which helps when debugging).
3700  auto node_name = get_node_name(node);
3701  return Builder_->create<circt::firrtl::InstanceOp>(
3702  Builder_->getUnknownLoc(),
3703  modules[name],
3704  node_name);
3705 }
3706 
3707 circt::firrtl::ConstantOp
3708 RhlsToFirrtlConverter::GetConstant(mlir::Block * body, int size, int value)
3709 {
3710  auto intType = GetIntType(size);
3711  auto constant = Builder_->create<circt::firrtl::ConstantOp>(
3712  Builder_->getUnknownLoc(),
3713  intType,
3714  ::llvm::APInt(size, value));
3715  body->push_back(constant);
3716  return constant;
3717 }
3718 
3719 circt::firrtl::InvalidValueOp
3720 RhlsToFirrtlConverter::GetInvalid(mlir::Block * body, int size)
3721 {
3722 
3723  auto invalid =
3724  Builder_->create<circt::firrtl::InvalidValueOp>(Builder_->getUnknownLoc(), GetIntType(size));
3725  body->push_back(invalid);
3726  return invalid;
3727 }
3728 
3729 void
3730 RhlsToFirrtlConverter::ConnectInvalid(mlir::Block * body, mlir::Value value)
3731 {
3732 
3733  auto invalid =
3734  Builder_->create<circt::firrtl::InvalidValueOp>(Builder_->getUnknownLoc(), value.getType());
3735  body->push_back(invalid);
3736  return Connect(body, value, invalid);
3737 }
3738 
3739 // Get the clock signal in the module
3740 mlir::BlockArgument
3741 RhlsToFirrtlConverter::GetClockSignal(circt::firrtl::FModuleOp module)
3742 {
3743  auto clock = module.getArgument(0);
3744  auto ctype = clock.getType().cast<circt::firrtl::FIRRTLType>();
3745  if (!ctype.isa<circt::firrtl::ClockType>())
3746  {
3747  JLM_ASSERT("Not a ClockType");
3748  }
3749  return clock;
3750 }
3751 
3752 // Get the reset signal in the module
3753 mlir::BlockArgument
3754 RhlsToFirrtlConverter::GetResetSignal(circt::firrtl::FModuleOp module)
3755 {
3756  auto reset = module.getArgument(1);
3757  auto rtype = reset.getType().cast<circt::firrtl::FIRRTLType>();
3758  if (!rtype.isa<circt::firrtl::ResetType>())
3759  {
3760  JLM_ASSERT("Not a ResetType");
3761  }
3762  return reset;
3763 }
3764 
3765 circt::firrtl::BundleType::BundleElement
3767 {
3768  using BundleElement = circt::firrtl::BundleType::BundleElement;
3769 
3770  return BundleElement(
3771  Builder_->getStringAttr("ready"),
3772  true,
3773  circt::firrtl::IntType::get(Builder_->getContext(), false, 1));
3774 }
3775 
3776 circt::firrtl::BundleType::BundleElement
3778 {
3779  using BundleElement = circt::firrtl::BundleType::BundleElement;
3780 
3781  return BundleElement(
3782  Builder_->getStringAttr("valid"),
3783  false,
3784  circt::firrtl::IntType::get(Builder_->getContext(), false, 1));
3785 }
3786 
3787 void
3788 RhlsToFirrtlConverter::InitializeMemReq(circt::firrtl::FModuleOp module)
3789 {
3790  mlir::BlockArgument mem = GetPort(module, "mem_req");
3791  mlir::Block * body = module.getBodyBlock();
3792 
3793  auto zeroBitValue = GetConstant(body, 1, 0);
3794  auto invalid1 = GetInvalid(body, 1);
3795  auto invalid3 = GetInvalid(body, 3);
3796  auto invalidPtr = GetInvalid(body, GetPointerSizeInBits());
3797  auto invalid64 = GetInvalid(body, 64);
3798 
3799  auto memValid = GetSubfield(body, mem, "valid");
3800  auto memAddr = GetSubfield(body, mem, "addr");
3801  auto memData = GetSubfield(body, mem, "data");
3802  auto memWrite = GetSubfield(body, mem, "write");
3803  auto memWidth = GetSubfield(body, mem, "width");
3804 
3805  Connect(body, memValid, zeroBitValue);
3806  Connect(body, memAddr, invalidPtr);
3807  Connect(body, memData, invalid64);
3808  Connect(body, memWrite, invalid1);
3809  Connect(body, memWidth, invalid3);
3810 }
3811 
3812 // Takes a jlm::rvsdg::Node and creates a firrtl module with an input
3813 // bundle for each node input and output bundle for each node output
3814 // Returns a circt::firrtl::FModuleOp with an empty body
3815 circt::firrtl::FModuleOp
3817 {
3818  // Generate a vector with all inputs and outputs of the module
3819  ::llvm::SmallVector<circt::firrtl::PortInfo> ports;
3820 
3821  // Clock and reset ports
3822  AddClockPort(&ports);
3823  AddResetPort(&ports);
3824  // Input bundle port
3825  for (size_t i = 0; i < node->ninputs(); ++i)
3826  {
3827  std::string name("i");
3828  name.append(std::to_string(i));
3829  AddBundlePort(
3830  &ports,
3831  circt::firrtl::Direction::In,
3832  name,
3833  GetFirrtlType(node->input(i)->Type().get()));
3834  }
3835  for (size_t i = 0; i < node->noutputs(); ++i)
3836  {
3837  std::string name("o");
3838  name.append(std::to_string(i));
3839  AddBundlePort(
3840  &ports,
3841  circt::firrtl::Direction::Out,
3842  name,
3843  GetFirrtlType(node->output(i)->Type().get()));
3844  }
3845 
3846  if (mem)
3847  {
3848  AddMemReqPort(&ports);
3849  AddMemResPort(&ports);
3850  }
3851 
3852  // Creat a name for the module
3853  auto nodeName = GetModuleName(node);
3854  mlir::StringAttr name = Builder_->getStringAttr(nodeName);
3855  // Create the module
3856  return Builder_->create<circt::firrtl::FModuleOp>(
3857  Builder_->getUnknownLoc(),
3858  name,
3859  circt::firrtl::ConventionAttr::get(
3860  Builder_->getContext(),
3861  circt::firrtl::Convention::Internal),
3862  ports);
3863 }
3864 
3865 //
3866 // HLS only works with wires so all types are represented as unsigned integers
3867 //
3868 
3869 // Returns IntType of the specified width
3870 circt::firrtl::IntType
3872 {
3873  return circt::firrtl::IntType::get(Builder_->getContext(), false, size);
3874 }
3875 
3876 // Return unsigned IntType with the bit width specified by the
3877 // jlm::rvsdg::type. The extend argument extends the width of the IntType,
3878 // which is useful for, e.g., additions where the result has to be 1
3879 // larger than the operands to accommodate for the carry.
3880 circt::firrtl::IntType
3882 {
3883  return circt::firrtl::IntType::get(Builder_->getContext(), false, JlmSize(type) + extend);
3884 }
3885 
3886 circt::firrtl::FIRRTLBaseType
3888 {
3889  if (auto bt = dynamic_cast<const BundleType *>(type))
3890  {
3891  using BundleElement = circt::firrtl::BundleType::BundleElement;
3892  ::llvm::SmallVector<BundleElement> elements;
3893  for (size_t i = 0; i < bt->elements_.size(); ++i)
3894  {
3895  auto t = &bt->elements_.at(i);
3896  elements.push_back(
3897  BundleElement(Builder_->getStringAttr(t->first), false, GetFirrtlType(t->second.get())));
3898  }
3899  return circt::firrtl::BundleType::get(Builder_->getContext(), elements);
3900  }
3901  else
3902  {
3903  return GetIntType(type);
3904  }
3905 }
3906 
3907 std::string
3909 {
3910 
3911  std::string append = "";
3912  for (size_t i = 0; i < node->ninputs(); ++i)
3913  {
3914  append.append("_I");
3915  append.append(std::to_string(JlmSize(node->input(i)->Type().get())));
3916  append.append("W");
3917  }
3918  for (size_t i = 0; i < node->noutputs(); ++i)
3919  {
3920  append.append("_O");
3921  append.append(std::to_string(JlmSize(node->output(i)->Type().get())));
3922  append.append("W");
3923  }
3924  if (auto op = dynamic_cast<const llvm::GetElementPtrOperation *>(&node->GetOperation()))
3925  {
3926  const jlm::rvsdg::Type * pointeeType = &op->GetPointeeType();
3927  for (size_t i = 1; i < node->ninputs(); i++)
3928  {
3929  int bits = JlmSize(pointeeType);
3930  if (dynamic_cast<const jlm::rvsdg::BitType *>(pointeeType)
3931  || dynamic_cast<const llvm::FloatingPointType *>(pointeeType))
3932  {
3933  pointeeType = nullptr;
3934  }
3935  else if (auto arrayType = dynamic_cast<const llvm::ArrayType *>(pointeeType))
3936  {
3937  pointeeType = &arrayType->element_type();
3938  }
3939  else if (auto vectorType = dynamic_cast<const llvm::VectorType *>(pointeeType))
3940  {
3941  pointeeType = vectorType->Type().get();
3942  }
3943  else
3944  {
3945  throw std::logic_error(pointeeType->debug_string() + " pointer not implemented!");
3946  }
3947  int bytes = bits / 8;
3948  append.append("_");
3949  append.append(std::to_string(bytes));
3950  }
3951  }
3952  if (auto op = dynamic_cast<const MemoryRequestOperation *>(&node->GetOperation()))
3953  {
3954  auto loadTypes = op->GetLoadTypes();
3955  for (size_t i = 0; i < loadTypes->size(); i++)
3956  {
3957  auto loadType = loadTypes->at(i).get();
3958  int bitWidth = JlmSize(loadType);
3959  append.append("_");
3960  append.append(std::to_string(bitWidth));
3961  }
3962  }
3963  if (auto op = dynamic_cast<const LocalMemoryOperation *>(&node->GetOperation()))
3964  {
3965  append.append("_S");
3966  append.append(std::to_string(
3967  std::dynamic_pointer_cast<const llvm::ArrayType>(op->result(0))->nelements()));
3968  append.append("_L");
3969  size_t loads =
3970  rvsdg::TryGetOwnerNode<rvsdg::Node>(*node->output(0)->Users().begin())->noutputs();
3971  append.append(std::to_string(loads));
3972  append.append("_S");
3973  size_t stores =
3974  (rvsdg::TryGetOwnerNode<rvsdg::Node>(*node->output(1)->Users().begin())->ninputs() - 1
3975  - loads)
3976  / 2;
3977  append.append(std::to_string(stores));
3978  }
3979  if (dynamic_cast<const LoopOperation *>(&node->GetOperation()))
3980  {
3981  append.append("_");
3982  append.append(util::strfmt(node));
3983  }
3984  auto name = jlm::util::strfmt("op_", node->DebugString() + append);
3985  // Remove characters that are not valid in firrtl module names
3986  std::replace_if(name.begin(), name.end(), isForbiddenChar, '_');
3987  return name;
3988 }
3989 
3990 bool
3992 {
3993  for (const auto & pair : op)
3994  {
3995  if (pair.first != pair.second)
3996  return false;
3997  }
3998 
3999  return true;
4000 }
4001 
4002 // Used for debugging a module by wrapping it in a circuit and writing it to a file
4003 // Node is simply a convenience for generating the circuit name
4004 void
4006  const circt::firrtl::FModuleOp fModuleOp,
4007  const rvsdg::Node * node)
4008 {
4009  if (!fModuleOp)
4010  return;
4011 
4012  auto name = GetModuleName(node);
4013  auto moduleName = Builder_->getStringAttr(name);
4014 
4015  // Adde the fModuleOp to a circuit
4016  auto circuit = Builder_->create<circt::firrtl::CircuitOp>(Builder_->getUnknownLoc(), moduleName);
4017  auto body = circuit.getBodyBlock();
4018  body->push_back(fModuleOp);
4019 
4020  WriteCircuitToFile(circuit, name);
4021 }
4022 
4023 // Verifies the circuit and writes the FIRRTL to a file
4024 void
4025 RhlsToFirrtlConverter::WriteCircuitToFile(const circt::firrtl::CircuitOp circuit, std::string name)
4026 {
4027  // Add the circuit to a top module
4028  auto module = mlir::ModuleOp::create(Builder_->getUnknownLoc());
4029  module.push_back(circuit);
4030 
4031  // Verify the module
4032  if (failed(mlir::verify(module)))
4033  {
4034  module.emitError("module verification error");
4035  throw std::logic_error("Verification of firrtl failed");
4036  }
4037  // Print the FIRRTL IR
4038  module.print(::llvm::outs());
4039 
4040  // Write the module to file
4041  std::string fileName = name + extension();
4042  std::error_code EC;
4043  ::llvm::raw_fd_ostream output(fileName, EC);
4044  size_t targetLineLength = 100;
4045  auto status = circt::firrtl::exportFIRFile(module, output, targetLineLength, DefaultFIRVersion_);
4046 
4047  if (status.failed())
4048  {
4049  throw util::Error("Exporting of FIRRTL failed");
4050  }
4051 
4052  output.close();
4053  std::cout << "\nWritten firrtl to " << fileName << "\n";
4054 }
4055 
4056 std::string
4057 RhlsToFirrtlConverter::toString(const circt::firrtl::CircuitOp circuit)
4058 {
4059  // Add the circuit to a top module
4060  auto module = mlir::ModuleOp::create(Builder_->getUnknownLoc());
4061  module.push_back(circuit);
4062 
4063  // Verify the module
4064  if (failed(mlir::verify(module)))
4065  {
4066  module.emitError("module verification error");
4067  module.print(::llvm::outs());
4068  throw std::logic_error("Verification of firrtl failed");
4069  }
4070 
4071  // Export FIRRTL to string
4072  std::string outputString;
4073  ::llvm::raw_string_ostream output(outputString);
4074 
4075  size_t targetLineLength = 100;
4076  auto status = circt::firrtl::exportFIRFile(module, output, targetLineLength, DefaultFIRVersion_);
4077  if (status.failed())
4078  throw std::logic_error("Exporting of firrtl failed");
4079 
4080  return outputString;
4081 }
4082 
4083 circt::firrtl::FExtModuleOp
4085 {
4086  // Generate a vector with all inputs and outputs of the module
4087  ::llvm::SmallVector<circt::firrtl::PortInfo> ports;
4088 
4089  // Clock and reset ports
4090  AddClockPort(&ports);
4091  AddResetPort(&ports);
4092  // Input bundle port
4093  for (size_t i = 0; i < node->ninputs(); ++i)
4094  {
4095  std::string name("i");
4096  name.append(std::to_string(i));
4097  AddBundlePort(
4098  &ports,
4099  circt::firrtl::Direction::In,
4100  name,
4101  GetFirrtlType(node->input(i)->Type().get()));
4102  }
4103  for (size_t i = 0; i < node->noutputs(); ++i)
4104  {
4105  std::string name("o");
4106  name.append(std::to_string(i));
4107  AddBundlePort(
4108  &ports,
4109  circt::firrtl::Direction::Out,
4110  name,
4111  GetFirrtlType(node->output(i)->Type().get()));
4112  }
4113 
4114  // Creat a name for the module
4115  auto nodeName = GetModuleName(node);
4116  mlir::StringAttr name = Builder_->getStringAttr(nodeName);
4117  // Create the module
4118  return Builder_->create<circt::firrtl::FExtModuleOp>(
4119  Builder_->getUnknownLoc(),
4120  name,
4121  circt::firrtl::ConventionAttr::get(
4122  Builder_->getContext(),
4123  circt::firrtl::Convention::Internal),
4124  ports);
4125 }
4126 } // namespace jlm::hls
std::vector< rvsdg::RegionArgument * > get_reg_args(const rvsdg::LambdaNode &lambda)
Definition: base-hls.hpp:112
std::vector< rvsdg::RegionResult * > get_reg_results(const rvsdg::LambdaNode &lambda)
Definition: base-hls.hpp:130
std::vector< rvsdg::RegionResult * > get_mem_reqs(const rvsdg::LambdaNode &lambda)
Definition: base-hls.hpp:93
static std::string get_port_name(jlm::rvsdg::Input *port)
Definition: base-hls.cpp:62
void create_node_names(rvsdg::Region *r)
Definition: base-hls.cpp:116
static int JlmSize(const jlm::rvsdg::Type *type)
Definition: base-hls.cpp:110
std::unordered_map< jlm::rvsdg::Output *, std::string > output_map
Definition: base-hls.hpp:45
std::string get_node_name(const rvsdg::Node *node)
Definition: base-hls.cpp:28
std::vector< rvsdg::RegionArgument * > get_mem_resps(const rvsdg::LambdaNode &lambda)
Definition: base-hls.hpp:75
std::size_t Capacity() const noexcept
Definition: hls.hpp:406
bool IsConstant() const noexcept
Definition: hls.hpp:189
rvsdg::Region * subregion() const noexcept
Definition: hls.hpp:725
const std::vector< std::shared_ptr< const rvsdg::Type > > * GetLoadTypes() const
Definition: hls.hpp:1357
circt::firrtl::InstanceOp AddInstanceOp(mlir::Block *circuitBody, jlm::rvsdg::Node *node)
circt::firrtl::GEQPrimOp AddGeqOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FModuleOp MlirGenBuffer(const jlm::rvsdg::SimpleNode *node)
mlir::BlockArgument GetPort(circt::firrtl::FModuleOp &module, std::string portName)
circt::firrtl::FModuleOp MlirGenHlsMemReq(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::EQPrimOp AddEqOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::LTPrimOp AddLtOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::BitsPrimOp AddBitsOp(mlir::Block *body, mlir::Value value, int high, int low)
void check_module(circt::firrtl::FModuleOp &module)
circt::firrtl::XorPrimOp AddXorOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FModuleOp MlirGenNDMux(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::BundleType GetBundleType(const circt::firrtl::FIRRTLBaseType &type)
circt::firrtl::BitsPrimOp DropMSBs(mlir::Block *body, mlir::Value value, int amount)
circt::firrtl::WireOp AddWireOp(mlir::Block *body, std::string name, int size)
void AddMemResPort(::llvm::SmallVector< circt::firrtl::PortInfo > *ports)
circt::firrtl::RemPrimOp AddRemOp(mlir::Block *body, mlir::Value first, mlir::Value second)
std::unordered_map< std::string, circt::firrtl::FModuleLike > modules
circt::firrtl::FModuleOp MlirGenFork(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::FModuleOp nodeToModule(const jlm::rvsdg::Node *node, bool mem=false)
void WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::Node *node)
rvsdg::Output * TraceStructuralOutput(rvsdg::StructuralOutput *out)
circt::firrtl::NodeOp AddNodeOp(mlir::Block *body, mlir::Value value, std::string name)
circt::firrtl::FModuleOp MlirGenTrigger(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::DShrPrimOp AddDShrOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::NEQPrimOp AddNeqOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FModuleOp MlirGenPrint(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::FModuleOp MlirGenHlsMemResp(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::BundleType::BundleElement GetReadyElement()
circt::firrtl::BundleType::BundleElement GetValidElement()
void AddResetPort(::llvm::SmallVector< circt::firrtl::PortInfo > *ports)
circt::firrtl::FModuleOp MlirGenStateGate(const jlm::rvsdg::SimpleNode *node)
const circt::firrtl::FIRVersion DefaultFIRVersion_
void InitializeMemReq(circt::firrtl::FModuleOp module)
circt::firrtl::FModuleOp MlirGenSink(const jlm::rvsdg::SimpleNode *node)
void Connect(mlir::Block *body, mlir::Value sink, mlir::Value source)
circt::firrtl::FModuleOp MlirGenHlsStore(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::FIRRTLBaseType GetFirrtlType(const jlm::rvsdg::Type *type)
void AddMemReqPort(::llvm::SmallVector< circt::firrtl::PortInfo > *ports)
circt::firrtl::DShlPrimOp AddDShlOp(mlir::Block *body, mlir::Value first, mlir::Value second)
void WriteCircuitToFile(const circt::firrtl::CircuitOp circuit, std::string name)
bool IsIdentityMapping(const rvsdg::MatchOperation &op)
circt::firrtl::CircuitOp MlirGen(const rvsdg::LambdaNode *lamdaNode)
circt::firrtl::FModuleOp MlirGenSimpleNode(const jlm::rvsdg::SimpleNode *node)
mlir::BlockArgument GetResetSignal(circt::firrtl::FModuleOp module)
std::string toString(const circt::firrtl::CircuitOp circuit)
circt::firrtl::ConstantOp GetConstant(mlir::Block *body, int size, int value)
circt::firrtl::InvalidValueOp GetInvalid(mlir::Block *body, int size)
mlir::BlockArgument GetOutPort(circt::firrtl::FModuleOp &module, size_t portNr)
circt::firrtl::FModuleOp MlirGenBranch(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::SubPrimOp AddSubOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::PadPrimOp AddPadOp(mlir::Block *body, mlir::Value value, int amount)
circt::firrtl::MuxPrimOp AddMuxOp(mlir::Block *body, mlir::Value select, mlir::Value high, mlir::Value low)
circt::firrtl::WhenOp AddWhenOp(mlir::Block *body, mlir::Value condition, bool elseStatment)
circt::firrtl::AsUIntPrimOp AddAsUIntOp(mlir::Block *body, mlir::Value value)
circt::firrtl::FModuleOp MlirGenPredicationBuffer(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::NotPrimOp AddNotOp(mlir::Block *body, mlir::Value first)
circt::firrtl::FModuleOp MlirGenAddrQueue(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::GTPrimOp AddGtOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FExtModuleOp MlirGenExtModule(const jlm::rvsdg::SimpleNode *node)
void AddClockPort(::llvm::SmallVector< circt::firrtl::PortInfo > *ports)
circt::firrtl::FModuleOp MlirGenHlsDLoad(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::FModuleOp MlirGenLoopConstBuffer(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::OrPrimOp AddOrOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::IntType GetIntType(int size)
circt::firrtl::MulPrimOp AddMulOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::LEQPrimOp AddLeqOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FModuleOp MlirGenHlsLocalMem(const jlm::rvsdg::SimpleNode *node)
std::string GetModuleName(const rvsdg::Node *node)
jlm::rvsdg::Output * TraceArgument(rvsdg::RegionArgument *arg)
std::unique_ptr<::mlir::OpBuilder > Builder_
circt::firrtl::AsSIntPrimOp AddAsSIntOp(mlir::Block *body, mlir::Value value)
circt::firrtl::FModuleOp MlirGenDMux(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::DivPrimOp AddDivOp(mlir::Block *body, mlir::Value first, mlir::Value second)
circt::firrtl::FModuleOp MlirGenHlsLoad(const jlm::rvsdg::SimpleNode *node)
circt::firrtl::AddPrimOp AddAddOp(mlir::Block *body, mlir::Value first, mlir::Value second)
mlir::BlockArgument GetInPort(circt::firrtl::FModuleOp &module, size_t portNr)
circt::firrtl::CvtPrimOp AddCvtOp(mlir::Block *body, mlir::Value value)
circt::firrtl::SubfieldOp GetSubfield(mlir::Block *body, mlir::Value value, int index)
mlir::OpResult GetInstancePort(circt::firrtl::InstanceOp &instance, std::string portName)
mlir::BlockArgument GetClockSignal(circt::firrtl::FModuleOp module)
circt::firrtl::AndPrimOp AddAndOp(mlir::Block *body, mlir::Value first, mlir::Value second)
void ConnectInvalid(mlir::Block *body, mlir::Value value)
circt::firrtl::FModuleOp MlirGenMem(const jlm::rvsdg::SimpleNode *node)
void AddBundlePort(::llvm::SmallVector< circt::firrtl::PortInfo > *ports, circt::firrtl::Direction direction, std::string name, circt::firrtl::FIRRTLBaseType type)
Lambda operation.
Definition: lambda.hpp:30
const std::string & name() const noexcept
Definition: lambda.hpp:41
UndefValueOperation class.
Definition: operators.hpp:992
Output * origin() const noexcept
Definition: node.hpp:58
size_t index() const noexcept
Definition: node.hpp:52
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:67
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
virtual const Operation & GetOperation() const noexcept=0
virtual std::string DebugString() const =0
NodeInput * input(size_t index) const noexcept
Definition: node.hpp:615
NodeOutput * output(size_t index) const noexcept
Definition: node.hpp:650
size_t ninputs() const noexcept
Definition: node.hpp:609
size_t noutputs() const noexcept
Definition: node.hpp:644
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
UsersRange Users()
Definition: node.hpp:354
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
size_t index() const noexcept
Definition: node.hpp:274
Represents the argument of a region.
Definition: region.hpp:41
StructuralInput * input() const noexcept
Definition: region.hpp:69
StructuralOutput * output() const noexcept
Definition: region.hpp:149
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
rvsdg::StructuralNode * node() const noexcept
Definition: region.hpp:369
size_t nresults() const noexcept
Definition: region.hpp:465
RegionArgument * argument(size_t index) const noexcept
Definition: region.hpp:437
size_t narguments() const noexcept
Definition: region.hpp:431
const SimpleOperation & GetOperation() const noexcept override
Definition: simple-node.cpp:48
std::string DebugString() const override
Definition: simple-node.cpp:79
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
StructuralOutput * output(size_t index) const noexcept
StructuralInput * input(size_t index) const noexcept
StructuralNode * node() const noexcept
constexpr Type() noexcept
Definition: type.hpp:46
virtual std::string debug_string() const =0
size_type size() const noexcept
Iterator begin() noexcept
#define JLM_ASSERT(x)
Definition: common.hpp:16
void check_may_not_depend_on(mlir::Value value, ::llvm::SmallPtrSet< mlir::Value, 16 > &forbiddenDependencies, ::llvm::SmallPtrSet< mlir::Value, 16 > &visited)
void check_oValids(::llvm::SmallVector< mlir::Value > &oReadys, ::llvm::SmallVector< mlir::Value > &oValids)
bool isForbiddenChar(char c)
Definition: base-hls.cpp:16
size_t GetPointerSizeInBits()
Definition: hls.cpp:396
static std::string type(const Node *n)
Definition: view.cpp:255
@ State
Designate a state type.
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:682
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35