Finite state machine pitfalls with Chisel: Revisited
A while ago, I wrote about how Chisel’s early optimization breaks common heuristics used by other tools to detect finite state machine patterns. Here is a workaround.
Instead of following Chisel’s “advised way”, you need to wrap every next-state expression with a dontTouch optimization barrier. A few wrappers make this step more elegant, such as the following MuxDontTouch:
import chisel3._
/** A `Mux` whose two branches are each routed through a `dontTouch` wire.
  *
  * Both alternatives are first copied into a `Wire`, so that the `dontTouch`
  * barrier is applied to a wire rather than directly to the (possibly
  * constant) branch value.
  *
  * @param cond selects `con` when true, `alt` otherwise
  * @param con  value chosen when `cond` is true
  * @param alt  value chosen when `cond` is false
  * @return the multiplexer output
  */
object MuxDontTouch {
  def apply[T <: Data](cond: Bool, con: T, alt: T): T = {
    // Keep these val names: they appear as `*_conWire` / `*_altWire`
    // in the generated Verilog.
    val conWire = dontTouch(WireInit(con))
    val altWire = dontTouch(WireInit(alt))
    Mux(cond, conWire, altWire)
  }
}
Since CIRCT won’t be happy if you apply dontTouch annotations to constants, two Wires are added to ensure that both branches are hardware.
Let’s take the 5-bit burst detector as an example once more. Make the following edits to apply the workaround:
diff --git a/Top.old.scala b/Top.scala
index 1e66954..2df2a1a 100644
--- a/Top.old.scala
+++ b/Top.scala
@@ -1,6 +1,14 @@
import chisel3._
import chisel3.util._
+object MuxDontTouch {
+ def apply[T <: Data](cond: Bool, con: T, alt: T): T = {
+ val conWire = WireInit(con)
+ val altWire = WireInit(alt)
+ Mux(cond, dontTouch(conWire), dontTouch(altWire))
+ }
+}
+
class Burst5Detector extends Module {
class Port extends Bundle {
val din = Input(Bool())
@@ -21,12 +29,12 @@ class Burst5Detector extends Module {
private val y = RegInit(S_Idle)
y := MuxLookup(y, S_Idle)(
Seq(
- S_Idle -> Mux(io.din, S_1High, S_Idle),
- S_1High -> Mux(io.din, S_2High, S_Idle),
- S_2High -> Mux(io.din, S_3High, S_Idle),
- S_3High -> Mux(io.din, S_4High, S_Idle),
- S_4High -> Mux(io.din, S_5High, S_Idle),
- S_5High -> Mux(io.din, S_5High, S_Idle)
+ S_Idle -> MuxDontTouch(io.din, S_1High, S_Idle),
+ S_1High -> MuxDontTouch(io.din, S_2High, S_Idle),
+ S_2High -> MuxDontTouch(io.din, S_3High, S_Idle),
+ S_3High -> MuxDontTouch(io.din, S_4High, S_Idle),
+ S_4High -> MuxDontTouch(io.din, S_5High, S_Idle),
+ S_5High -> MuxDontTouch(io.din, S_5High, S_Idle)
)
)
After elaboration, CIRCT generates beautiful FSM state transitions with all-const leaves.
// Generated by CIRCT firtool-1.77.0
// ...
module Burst5Detector(
input clock,
reset,
io_din,
output io_dout
);
wire [2:0] y_conWire_3 = 3'h4;
wire [2:0] y_conWire_4 = 3'h5;
wire [2:0] y_conWire_5 = 3'h5;
wire [2:0] y_conWire = 3'h1;
wire [2:0] y_conWire_1 = 3'h2;
wire [2:0] y_conWire_2 = 3'h3;
wire [2:0] y_altWire = 3'h0;
wire [2:0] y_altWire_1 = 3'h0;
wire [2:0] y_altWire_2 = 3'h0;
wire [2:0] y_altWire_3 = 3'h0;
wire [2:0] y_altWire_4 = 3'h0;
wire [2:0] y_altWire_5 = 3'h0;
reg [2:0] y;
reg [2:0] casez_tmp;
always_comb begin
casez (y)
3'b000:
casez_tmp = io_din ? y_conWire : y_altWire;
3'b001:
casez_tmp = io_din ? y_conWire_1 : y_altWire_1;
3'b010:
casez_tmp = io_din ? y_conWire_2 : y_altWire_2;
3'b011:
casez_tmp = io_din ? y_conWire_3 : y_altWire_3;
3'b100:
casez_tmp = io_din ? y_conWire_4 : y_altWire_4;
3'b101:
casez_tmp = io_din ? y_conWire_5 : y_altWire_5;
3'b110:
casez_tmp = 3'h0;
default:
casez_tmp = 3'h0;
endcase
end // always_comb
always @(posedge clock) begin
if (reset)
y <= 3'h0;
else
y <= casez_tmp;
end // always @(posedge)
// ...
assign io_dout = y == 3'h5;
endmodule
// ...
And Vivado is also happy about this change!
INFO: [Synth 8-802] inferred FSM for state register 'y_reg' in module 'Burst5Detector'
---------------------------------------------------------------------------------------------------
State | New Encoding | Previous Encoding
---------------------------------------------------------------------------------------------------
iSTATE4 | 000001 | 000
iSTATE0 | 000010 | 001
iSTATE1 | 000100 | 010
iSTATE2 | 001000 | 011
iSTATE3 | 010000 | 100
iSTATE5 | 100000 | 101
---------------------------------------------------------------------------------------------------
INFO: [Synth 8-3354] encoded FSM with state register 'y_reg' using encoding 'one-hot' in module 'Burst5Detector'
---------------------------------------------------------------------------------
Finished RTL Optimization Phase 2 : Time (s): cpu = 00:00:18 ; elapsed = 00:00:20 . Memory (MB): peak = 1577.703 ; gain = 618.145
---------------------------------------------------------------------------------
So, for now, I will start wrapping every FSM with these helper objects. After all, the synthesizer can optimize as well as (and even better than) source-level compilers!
Epilogue. If you don’t care about the names of the generated wires, you can simplify the body of the wrapper object to a single line.
import chisel3._
/** Compact `Mux` wrapper: each branch is copied into a wire that is marked
  * `dontTouch` before being multiplexed.
  *
  * No local vals are introduced, so the intermediate wires do not get the
  * `conWire` / `altWire` names of the longer variant.
  *
  * @param cond selects `con` when true, `alt` otherwise
  * @param con  value chosen when `cond` is true
  * @param alt  value chosen when `cond` is false
  * @return the multiplexer output
  */
object MuxDontTouch {
  def apply[T <: Data](cond: Bool, con: T, alt: T): T =
    Mux(cond, dontTouch(WireInit(con)), dontTouch(WireInit(alt)))
}
Writing other wrappers for MuxLookups, MuxCases, and Mux1Hs is rather straightforward thanks to Scala’s rich functional programming facilities. For example, this is a MuxCaseDontTouch:
import chisel3._
/** A `MuxCase` whose default and every candidate value are routed through a
  * `dontTouch` wire.
  *
  * @param default value passed to `MuxCase` as its default
  * @param mapping `(condition, value)` pairs, passed to `MuxCase` in the
  *                given order
  * @return the `MuxCase` output
  */
object MuxCaseDontTouch {
  def apply[T <: Data](default: T, mapping: Seq[(Bool, T)]): T = {
    // `MuxCase` is defined in chisel3.util, which `import chisel3._` alone
    // does not bring into scope; import it locally so this listing compiles
    // on its own.
    import chisel3.util.MuxCase

    MuxCase(
      dontTouch(WireInit(default)),
      mapping.map { case (cond, value) => cond -> dontTouch(WireInit(value)) }
    )
  }
}