Rewriting

Pattern Matching

Collect

# Build a random DFT(8) ruletree; s holds its Sigma-SPL formula and c the code generated from s.
opts := SpiralDefaults;
s := SumsRuleTree(RandomRuleTree(DFT(8), opts), opts);
c := CodeSums(s, opts);

# Collect(obj, pattern) returns a list of the subobjects of obj that match pattern;
# in the simplest case the pattern is just a class name.
Collect(s, Scat);               # get list of scatter operations
Set(Collect(s, Value)); # get all unique values

Simple Patterns

# @(id, target) matches any node whose class is in target; an optional third
# argument is a predicate the matched node must also satisfy.
Collect(c, @(1, [add, sub, neg, mul])); # get all arith ops...
Collect(c, @(1, [add, sub, neg, mul], x -> x.t = TReal)); #...on reals

# all loop variables: collect every ISum, then read its var field
List(Collect(s, @(1, ISum)), node -> node.var);
# @@ predicates also receive a context record; here it is used to keep only
# the values that sit inside Blk objects
Set(Collect(s, @@(1, Value,
        (v, ctx) -> IsBound(ctx.Blk) and Length(ctx.Blk) > 0)));

Subtree Patterns

# A list pattern [head, p1, p2, ...] matches a node of class head whose children
# match p1, p2, ... in order; "..." appears to stand for any run of children, and
# @(n, ...) binds the matched child so predicates can refer to it via @(n).val.
# NOTE(review): deref exposes a single rewritable child (rChildren := [self.loc]
# in its class definition), so [deref, add, sub] may never match -- confirm intent.
Collect(c, [deref, add, sub]);
Collect(c, [mul, @(1), sub]);   # mul whose 2nd child is a sub
Collect(c, [mul, Value, ...]);  # mul whose 1st child is a Value
Collect(c, [mul, @(1), [sub, deref, @(2)]]);    # mul(_, sub(deref, _))
Collect(c, [mul, @(1), [sub, @(2, deref, e->X in e.free()), @(3)]]); # ...and X is free in the deref

Substitutions

SubstTopDown/SubstBottomUp

opts := SpiralDefaults;
c := CodeSums(
        SumsRuleTree(RandomRuleTree(DFT(8), opts), opts),
        opts);

# Ordered substitution: traversal order can matter greatly.
# Each call is handed Copy(c), so both start from the same code.
SubstTopDown(Copy(c), @(1, Value, v -> v.v = 1), v -> V(25));
SubstBottomUp(Copy(c), @(1, Value, v -> v.v = 1), v -> V(-25));

Variable Substitutions

# all the real variables
vars := Collect(c, @(1, var, v -> v.t = TReal));
# substitute one: map the first variable's id to the constant 1.1
SubstVars(Copy(c), rec((vars[1].id) := V(1.1)));

# record of assignment of consecutive numbers to all real variables:
# pair each variable with its position, then fold the pairs into one record
substrec := FoldR(Zip2(vars, [1..Length(vars)]),
        (acc, pair) -> CopyFields(acc, rec((pair[1].id) := V(pair[2]))),
        rec());
SubstVars(Copy(c), substrec);   # substitute them

# loop unrolling example: for every value in the loop's range, substitute that
# value for the loop variable in a fresh copy of the body, then chain the copies
i := Ind(4);
c2 := loop(i, 4, assign(nth(X, i), i)); # loop to be unrolled
chain(List(c2.range,
        k -> SubstVars(Copy(c2.cmd), rec((c2.var.id) := V(k)))));

Rules

Simple Rules

# double negation: neg(neg(x)) -> x
Rule([neg, [neg, @1]], r -> @1.val);
# fold an add of two Values into one Value typed like the first operand
Rule([add, Value, Value],
        r -> Value.new(r.args[1].t, r.args[1].v + r.args[2].v));
# im(conj(z)) -> -im(z)
Rule([im, [conj, @(1)]], r -> -im(@(1).val));
# an IF node whose two branch children are both skip becomes skip
Rule([IF, @(1), skip, skip], r -> skip());

# push RC into the factors of a Compose, and into a Gath's index function
Rule([RC, @(1, Compose)], r -> Compose(List(@(1).val.children(), RC)));
Rule([RC, @(1, Gath)], r -> Gath(fTensor(@(1).val.func, fId(2))));

# a Tensor with an O factor anywhere becomes an O of the same dimensions
Rule([Tensor, ..., @(1,O), ...], r -> O(Rows(r), Cols(r)));

Complex Rules

# pattern @(0) accepting a noneExp, or a Value for which isValueZero holds
_v0none := @(0).target([ Value, noneExp ]).cond(
        x -> Cond(ObjId(x) = noneExp, true, isValueZero(x)));
# replacement of type typ: noneExp(typ) if @(0) bound a noneExp, else typ.zero()
_0noneOrZero := typ -> When(
        ObjId(@(0).val) = noneExp, noneExp(typ), typ.zero());
# a mul with such a factor anywhere among its arguments collapses
Rule([mul, ..., _v0none, ...], r -> _0noneOrZero(r.t));

# distribute a mul over an add in its 2nd argument, but only for a mul whose
# context contains a non-empty nth entry
Rule([@@(0, mul, (m, ctx) -> IsBound(ctx.nth) and ctx.nth <> []), @(1), @(2,add)],
        r -> ApplyFunc(add, List(@(2).val.args, t -> @(1).val * t)));

# im(cxpack(a, b) * conj(cxpack(a, b))) is zero; the .cond predicates force
# the second cxpack's parts to equal the first one's
Rule( [im, [mul, [cxpack, @(1), @(2)],
                 [conj, [cxpack, @(3).cond(x -> x = @(1).val),
                                 @(4).cond(x -> x = @(2).val)]]]],
        r -> r.t.zero() );

Associative Rules

Simple Rules

# ARule patterns match a run of consecutive arguments of an associative
# operator; the action returns the list of arguments that replaces that run.

# flatten nested adds: an add argument is replaced by its own args
ARule(add, [ @(1,add) ], e -> @1.val.args);
# flatten nested fTensors the same way
ARule(fTensor,  [@(1, fTensor) ], e -> @(1).val.children());
# drop an fId that directly follows another function in an fCompose
ARule(fCompose, [@(1), fId ], e -> [@(1).val]);

# merge two adjacent Prm's into one Prm (note the operand order in fCompose)
ARule(Compose, [ @(1, Prm), @(2, Prm) ],
         e -> [ Prm(fCompose(@(2).val.func, @(1).val.func)) ]);
# merge a Gath followed by a Gath or Prm into a single Gath
ARule(Compose, [ @(1, Gath), @(2, [Gath, Prm]) ],
         e -> [ Gath(fCompose(@(2).val.func, @(1).val.func)) ]);

Complex Rules

# In a chained leq, a constant c <= 0 immediately followed by k*i (constant
# k > 0, i a loop index) is dropped and only the product is kept -- presumably
# because loop indices are non-negative, so that comparison always holds.
ARule(leq, [@(1, Value, x->x.v<=0), [@(0,mul), @(2, Value, x->x.v>0),
        @(3,var,IsLoopIndex)]], e -> [@(0).val]);

# Push a left factor (Prm, Scat, ScatAcc or a Conj variant) into every child of
# one of the listed container classes: the container keeps its class, each
# child is premultiplied by the factor, and the dimensions become
# [rows of the factor, cols of the container].
ARule( Compose, [ @(1, [Prm, Scat, ScatAcc, Conj, ConjL, ConjR, ConjLR]),
        @(2, [RecursStep, Grp, BB, SUM, ISum, Data, COND]) ],
        e -> [ CopyFields(@(2).val, rec(
                         _children :=  List(@(2).val._children, c -> @(1).val * c),
                         dimensions := [Rows(@(1).val), Cols(@(2).val)] )) ]);

# fCompose(L, fTensor(f, rest...)) -> fTensor(rest..., f) when f has domain 1
# and its range equals L's second parameter.
ARule(fCompose, [ @(1, L), [ @(3, fTensor),
        @(2).cond(e->range(e) = @(1).val.params[2] and domain(e)=1), ... ] ],
        e->[ fTensor(Copy(Drop(@(3).val.children(), 1)), Copy(@(2).val)) ]);

Rule Sets

Define a Rule Set

# spiral-core\namespaces\spiral\code\sreduce.gi
# Declare the rule set class, then register its rules by name.
Class(RulesStrengthReduce, RuleSet);
RewriteRules(RulesStrengthReduce, rec(
         # an leq with a single argument is always true
         leq_single := Rule([leq, @(1)], r -> V_true),
         # flatten nested adds
         add_assoc  := ARule(add, [ @(1,add) ], r -> @1.val.args),
# hundreds of rules
));

Add Rules to Existing Rule Set

# somewhere else in the source code: RewriteRules can be called again on an
# existing rule set to register more named rules
RewriteRules(RulesStrengthReduce, rec(  # add more rules
        # a logic_and / logic_or with a single argument reduces to that argument.
        # The operator and its argument use distinct pattern ids; binding id 1 to
        # both would make @(1).val depend on which of the two was matched last.
        logic_single := Rule([@(1, [logic_and, logic_or]), @(2)], e->@(2).val)
));

Using Rule Sets

# individual rules are reachable by name under the rule set's .rules field
RulesStrengthReduce.rules.leq_single;
opts := SpiralDefaults;
s := SPLRuleTree(RandomRuleTree(DFT(8), opts)).sums();
# a rule set can be applied through Rewrite(obj, ruleset, opts) ...
s := Rewrite(s, RulesSums, opts);
s := Rewrite(s, RulesDiag, opts);
# ... or by calling the rule set directly on the object
s := RulesDiagStandalone(s);

Rule Strategies

Define a Rule Strategy

# spiral-core\namespaces\spiral\code\sreduce.gi
# A rule strategy is a list of rule sets (compare the [RulesStrengthReduce]
# list passed to ApplyStrategy in the SReduce shortcut).
# NOTE(review): StandardSumsRules is assigned by the MergedRuleSet call shown
# under "Combining Rule Sets"; when running these snippets top to bottom,
# evaluate that assignment first.
LibStrategy := [ StandardSumsRules, HfuncSumsRules ];

Combining Rule Sets

# MergedRuleSet combines several existing rule sets into a single rule set,
# which can then be used as one entry of a strategy list such as LibStrategy.
StandardSumsRules := MergedRuleSet(
        RulesSums, RulesFuncSimp, RulesDiag, RulesDiagStandalone,
        RulesStrengthReduce, RulesRC, RulesII, OLRules
);

Use of Rule Strategies

# strategies shipped with the default options
SpiralDefaults.formulaStrategies.sigmaSpl;
SpiralDefaults.formulaStrategies.rc;
# handy shortcut: apply only RulesStrengthReduce, with BUA as traversal argument
SReduce := (code, o) ->
        ApplyStrategy(code, [RulesStrengthReduce], BUA, o);

opts := SpiralDefaults;
s := SumsRuleTree(RandomRuleTree(DFT(4), opts), opts);
c := DefaultCodegen(s, Y, X, opts);
# unroll every loop node in the generated code
c := SubstTopDown(c, @(1, loop), lp -> lp.unroll());
Collect(c, mul);        # multiplications before ...
c := SReduce(c, opts);
Collect(c, mul);        # ... and after strength reduction

General Mechanics

Interface for Rewriting

# spiral-core\namespaces\spiral\code\ir.gi
# deref(loc) is an nth whose index is fixed to TInt.value(0); unlike a general
# nth it exposes only loc as a rewritable child.
Class(deref, nth, rec(
        __call__ := (self, loc) >> Inherited(loc, TInt.value(0)),
        rChildren := self >> [self.loc],        # the only rewritable field
        rSetChild := rSetChildFields("loc"),    # rewritable child <-> field loc
));
# not defined in the rec above, so it comes from nth or a base class;
# presumably rebuilds an object from a list of rChildren -- TODO confirm
deref.from_rChildren;

A Unified Interface Across All Rewritable Objects

# rChildren() lists an object's rewritable fields; rSetChild(n, v) replaces one.
c := deref(X);  # code objects
c.rChildren();  # get rewritable fields
c.rSetChild(1, Y);      # change a rewritable field

# The same rChildren interface exists at every representation level:
DFT.rChildren;          # Transform level
RandomRuleTree(DFT(4), SpiralDefaults).rChildren;       # ruletree level
F.rChildren;            # SPL level
L.rChildren;            # Permutations
Gath.rChildren; # Sigma-SPL
Lambda.rChildren;       # Lambda function
fId.rChildren;          # symbolic functions
add.rChildren;          # expressions
T_Real.rChildren;       # data types

Implementing Recursive Descent: Visitors

Simple Example

# spiral-core\namespaces\spiral\rewrite\visitor.gi
# LispGen: a small Visitor that prints an expression in Lisp-style prefix
# notation. Handler fields are named after the node classes they handle;
# binary operators print their symbol and recurse into both arguments.
Class(LispGen, Visitor, rec(
        add := (self, node) >> Print("(+ ",
                self(node.args[1]), " ", self(node.args[2]), ")"),
        mul := (self, node) >> Print("(* ",
                self(node.args[1]), " ", self(node.args[2]), ")"),
        sub := (self, node) >> Print("(- ",
                self(node.args[1]), " ", self(node.args[2]), ")"),
        var := (self, node) >> Print("(var ", node.id, ")"),
        Value := (self, node) >> Print("(value ", node.v, ")")
));
LispGen(4*X+2);

Visitors Used in the Standard Translation Flow

# Each translation stage is reachable as an opts field; each field is shown next
# to the default visitor it presumably holds.
opts := SpiralDefaults;
opts.sumsgen;
DefaultSumsGen;
opts.codegen;
DefaultCodegen;
opts.unparser;
CUnparser;

SumsGen

The DefaultSumsGen Visitor

# spiral-core\namespaces\spiral\sigma\sumsgen.gi
# all the fields (system record fields filtered out), then the entry point
Filtered(RecFields(DefaultSumsGen), i->not IsSystemRecField(i));
Print(DefaultSumsGen.__call__);

# the recursive definitions needed for DFT
DefaultSumsGen.Compose;
Print(DefaultSumsGen.Tensor);
DefaultSumsGen.I;
DefaultSumsGen.F;
DefaultSumsGen.Diag;
DefaultSumsGen.L;

DefaultSumsGen in the Standard Translation Flow

opts := SpiralDefaults;
s := SPLRuleTree(RandomRuleTree(DFT(8), opts));
# SumsSPL(s, opts) and a direct call of the opts.sumsgen visitor are shown side
# by side; presumably they produce the same Sigma-SPL -- compare the outputs
SumsSPL(s, opts);
opts.sumsgen(s, opts);

# legacy and backwards compatibility framework
F(2).sums();
Tensor(F(2), I(2)).sums();

CodeGen

The DefaultCodegen Visitor

# spiral-core\namespaces\spiral\compiler\codegen.gi
# all the fields (system record fields filtered out), then the entry point
Filtered(RecFields(DefaultCodegen), i->not IsSystemRecField(i));
Print(DefaultCodegen.__call__);

# Some of the fields
Print(DefaultCodegen.Formula);
Print(DefaultCodegen.Compose);
DefaultCodegen.ISum;
Print(DefaultCodegen.Gath);
Print(DefaultCodegen.Scat);
DefaultCodegen.Diag;

DefaultCodegen in the Standard Translation Flow

# Generate code for a Sigma-SPL formula s through the opts.codegen visitor.
opts := SpiralDefaults;
s := SumsRuleTree(RandomRuleTree(DFT(8), opts), opts);
# only translate Sigma-SPL to icode
opts.codegen(s, Y, X, opts);
# also invoke the basic block compiler
opts.codegen(Formula(s), Y, X, opts);

C Pretty Printer

The CUnparser Visitor

# spiral-core\namespaces\spiral\compiler\unparse.gi
# all the fields (system record fields filtered out), for CUnparser and for
# CUnparserBase, then the gen entry point
Filtered(RecFields(CUnparser), i->not IsSystemRecField(i));
Filtered(RecFields(CUnparserBase), i->not IsSystemRecField(i));
Print(CUnparser.gen);

# Some of the fields
Print(CUnparser.loop);
CUnparser.deref;
CUnparser.add;
CUnparser.Value;
Print(CUnparser.decl);
CUnparser.chain;

CUnparser in the Standard Translation Flow

opts := SpiralDefaults;
c := CodeRuleTree(RandomRuleTree(DFT(8), opts), opts);
# Print full header etc.
PrintCode("dft8", c, opts);
# unparser needs opts as context
# Unparse one nested command of c; the index path assumes the generated code
# has this nesting -- TODO confirm. The trailing 0, 1 are presumably the start
# indentation and the indentation step.
Unparse(c.cmds[1].cmds[2].cmd,
        CopyFields(CUnparser, rec(opts:=opts)), 0, 1);