A while ago, I wrote about how Chisel’s early optimization breaks common heuristics that other tools use to detect finite state machine patterns. Here is a workaround.

Instead of following Chisel’s “advised way”, you need to wrap every next-state expression in a dontTouch optimization barrier. Small wrapper objects make this step more elegant, such as the following MuxDontTouch:

import chisel3._

object MuxDontTouch {
  /** A 2:1 `Mux` whose two data inputs are each held in a `dontTouch`-protected wire.
    *
    * Each branch value is first copied into a `WireInit`, because CIRCT does not
    * accept `dontTouch` on constants; the wire gives the annotation a hardware
    * target. The local names `conWire` / `altWire` are kept on purpose: they
    * become part of the generated wire names.
    *
    * @param cond select signal; `con` is chosen when it is true, `alt` otherwise
    * @param con  value driven when `cond` is true
    * @param alt  value driven when `cond` is false
    * @return the multiplexed value, of the same type as `con` and `alt`
    */
  def apply[T <: Data](cond: Bool, con: T, alt: T): T = {
    val conWire = dontTouch(WireInit(con))
    val altWire = dontTouch(WireInit(alt))
    Mux(cond, conWire, altWire)
  }
}

Since CIRCT won’t be happy if you apply dontTouch annotations to constants, two Wires are added to ensure that both branches are hardware.

Let’s take the 5-bit burst detector as an example once more. Make the following edits to apply the workaround:

diff --git a/Top.old.scala b/Top.scala
index 1e66954..2df2a1a 100644
--- a/Top.old.scala
+++ b/Top.scala
@@ -1,6 +1,14 @@
 import chisel3._
 import chisel3.util._
 
+object MuxDontTouch {
+  def apply[T <: Data](cond: Bool, con: T, alt: T): T = {
+    val conWire = WireInit(con)
+    val altWire = WireInit(alt)
+    Mux(cond, dontTouch(conWire), dontTouch(altWire))
+  }
+}
+
 class Burst5Detector extends Module {
   class Port extends Bundle {
     val din  = Input(Bool())
@@ -21,12 +29,12 @@ class Burst5Detector extends Module {
   private val y = RegInit(S_Idle)
   y := MuxLookup(y, S_Idle)(
     Seq(
-      S_Idle  -> Mux(io.din, S_1High, S_Idle),
-      S_1High -> Mux(io.din, S_2High, S_Idle),
-      S_2High -> Mux(io.din, S_3High, S_Idle),
-      S_3High -> Mux(io.din, S_4High, S_Idle),
-      S_4High -> Mux(io.din, S_5High, S_Idle),
-      S_5High -> Mux(io.din, S_5High, S_Idle)
+      S_Idle  -> MuxDontTouch(io.din, S_1High, S_Idle),
+      S_1High -> MuxDontTouch(io.din, S_2High, S_Idle),
+      S_2High -> MuxDontTouch(io.din, S_3High, S_Idle),
+      S_3High -> MuxDontTouch(io.din, S_4High, S_Idle),
+      S_4High -> MuxDontTouch(io.din, S_5High, S_Idle),
+      S_5High -> MuxDontTouch(io.din, S_5High, S_Idle)
     )
   )

After elaboration, CIRCT generates beautiful FSM state transitions whose leaves are all constants.

// Generated by CIRCT firtool-1.77.0

// ...
module Burst5Detector(
  input  clock,
         reset,
         io_din,
  output io_dout
);

  wire [2:0] y_conWire_3 = 3'h4;
  wire [2:0] y_conWire_4 = 3'h5;
  wire [2:0] y_conWire_5 = 3'h5;
  wire [2:0] y_conWire = 3'h1;
  wire [2:0] y_conWire_1 = 3'h2;
  wire [2:0] y_conWire_2 = 3'h3;
  wire [2:0] y_altWire = 3'h0;
  wire [2:0] y_altWire_1 = 3'h0;
  wire [2:0] y_altWire_2 = 3'h0;
  wire [2:0] y_altWire_3 = 3'h0;
  wire [2:0] y_altWire_4 = 3'h0;
  wire [2:0] y_altWire_5 = 3'h0;
  reg  [2:0] y;
  reg  [2:0] casez_tmp;
  always_comb begin
    casez (y)
      3'b000:
        casez_tmp = io_din ? y_conWire : y_altWire;
      3'b001:
        casez_tmp = io_din ? y_conWire_1 : y_altWire_1;
      3'b010:
        casez_tmp = io_din ? y_conWire_2 : y_altWire_2;
      3'b011:
        casez_tmp = io_din ? y_conWire_3 : y_altWire_3;
      3'b100:
        casez_tmp = io_din ? y_conWire_4 : y_altWire_4;
      3'b101:
        casez_tmp = io_din ? y_conWire_5 : y_altWire_5;
      3'b110:
        casez_tmp = 3'h0;
      default:
        casez_tmp = 3'h0;
    endcase
  end // always_comb
  always @(posedge clock) begin
    if (reset)
      y <= 3'h0;
    else
      y <= casez_tmp;
  end // always @(posedge)
  // ...
  assign io_dout = y == 3'h5;
endmodule
// ...

And Vivado is also happy about this change!

INFO: [Synth 8-802] inferred FSM for state register 'y_reg' in module 'Burst5Detector'
---------------------------------------------------------------------------------------------------
                   State |                     New Encoding |                Previous Encoding 
---------------------------------------------------------------------------------------------------
                 iSTATE4 |                           000001 |                              000
                 iSTATE0 |                           000010 |                              001
                 iSTATE1 |                           000100 |                              010
                 iSTATE2 |                           001000 |                              011
                 iSTATE3 |                           010000 |                              100
                 iSTATE5 |                           100000 |                              101
---------------------------------------------------------------------------------------------------
INFO: [Synth 8-3354] encoded FSM with state register 'y_reg' using encoding 'one-hot' in module 'Burst5Detector'
---------------------------------------------------------------------------------
Finished RTL Optimization Phase 2 : Time (s): cpu = 00:00:18 ; elapsed = 00:00:20 . Memory (MB): peak = 1577.703 ; gain = 618.145
---------------------------------------------------------------------------------

So, for now, I will start wrapping every FSM with these helper objects. After all, the synthesizer can optimize as well as (or even better than) source-level compilers!

Afterword: if you don’t care about the names of the generated wires, you can simplify the wrapper object into a single line.

import chisel3._

object MuxDontTouch {
  /** Compact variant: a 2:1 `Mux` whose inputs sit in anonymous `dontTouch` wires.
    *
    * Behaves like the named-wire version, except that the wires carry no
    * user-chosen names in the generated output.
    *
    * @param cond select signal; `con` is chosen when it is true, `alt` otherwise
    * @param con  value driven when `cond` is true
    * @param alt  value driven when `cond` is false
    * @return the multiplexed value, of the same type as `con` and `alt`
    */
  def apply[T <: Data](cond: Bool, con: T, alt: T): T =
    Mux(cond, dontTouch(WireInit(con)), dontTouch(WireInit(alt)))
}

Writing similar wrappers for MuxLookup, MuxCase, and Mux1H is rather straightforward thanks to Scala’s functional programming facilities. For example, here is a MuxCaseDontTouch:

import chisel3._

object MuxCaseDontTouch {
  // MuxCase lives in chisel3.util, which `import chisel3._` alone does not bring in.
  import chisel3.util.MuxCase

  /** A `MuxCase` whose default and every case value are held in `dontTouch`-protected wires.
    *
    * Each value is first copied into a `WireInit`, because CIRCT does not accept
    * `dontTouch` on constants. Only the values are wrapped; the select
    * conditions are passed to `MuxCase` unchanged.
    *
    * @param default value driven when no condition in `mapping` is true
    * @param mapping (condition, value) pairs, forwarded to `MuxCase` in order
    * @return the selected value, of the same type as `default`
    */
  def apply[T <: Data](default: T, mapping: Seq[(Bool, T)]): T =
    MuxCase(
      dontTouch(WireInit(default)),
      mapping.map { case (cond, value) => cond -> dontTouch(WireInit(value)) }
    )
}