refactor: Index notation

This commit is contained in:
jstoobysmith 2024-10-08 07:26:23 +00:00
parent 341aea19c6
commit e5116d152c
3 changed files with 75 additions and 97 deletions

View file

@ -17,6 +17,11 @@ The first character `ᵘ` specifies the color of the index, and the subsequent c
Strings of indices e.g. `ᵘ¹²ᵤ₄₃`` are defined elsewhere.
## Note
Index notation is currently being refactored. Much of the content here will likely be replaced
or removed.
-/
open Lean

View file

@ -34,7 +34,7 @@ instance : Group S.G := S.G_group
end TensorStruct
inductive TensorTree (S : TensorStruct) : ∀ {n : }, (Fin n → S.C) → Type where
| tensorNode {n : } {c : Fin n → S.C} : S.F.obj (OverColor.mk c) TensorTree S c
| tensorNode {n : } {c : Fin n → S.C} (T: S.F.obj (OverColor.mk c)): TensorTree S c
| add {n : } {c : Fin n → S.C} : TensorTree S c → TensorTree S c → TensorTree S c
| perm {n m : } {c : Fin n → S.C} {c1 : Fin m → S.C}
(σ : (OverColor.mk c) ⟶ (OverColor.mk c1)) (t : TensorTree S c) : TensorTree S c1
@ -48,8 +48,8 @@ inductive TensorTree (S : TensorStruct) : ∀ {n : }, (Fin n → S.C) → Typ
(j : Fin n.succ) → TensorTree S c → TensorTree S (c ∘ Fin.succAbove i ∘ Fin.succAbove j)
| jiggle {n : } {c : Fin n → S.C} : (i : Fin n) → TensorTree S c →
TensorTree S (Function.update c i (S.τ (c i)))
| eval {n : } {c : Fin n.succ → S.C} : TensorTree S c →
(i : Fin n.succ) → (x : Fin (S.evalNo (c i))) →
| eval {n : } {c : Fin n.succ → S.C} :
(i : Fin n.succ) → (x : Fin (S.evalNo (c i))) → TensorTree S c →
TensorTree S (c ∘ Fin.succAbove i)
namespace TensorTree
@ -68,7 +68,7 @@ def size : ∀ {n : } {c : Fin n → S.C}, TensorTree S c → := fun
| mult _ _ t1 t2 => t1.size + t2.size + 1
| contr _ _ t => t.size + 1
| jiggle _ t => t.size + 1
| eval t _ _ => t.size + 1
| eval _ _ t => t.size + 1
noncomputable section

View file

@ -9,7 +9,7 @@ import Lean.Elab.Term
## Elaboration of tensor trees
This file turns
This file turns tensor expressions into tensor trees.
-/
open Lean
@ -49,6 +49,21 @@ def indexToNum (stx : Syntax) : TermElabM Nat :=
| _ =>
throwError "Unsupported tensor expression syntax (indexToNum): {stx}"
def indexToIdent (stx : Syntax) : TermElabM Ident :=
match stx with
| `(indexExpr|$a:ident) => return a
| `(indexExpr| τ($a:ident)) => return a
| _ =>
throwError "Unsupported tensor expression syntax (indexToIdent): {stx}"
def indexPosEq (a b : × TSyntax `indexExpr) : TermElabM (Option ( × )) := do
let a' ← indexToIdent a.2
let b' ← indexToIdent b.2
if a.1 < b.1 ∧ Lean.TSyntax.getId a' = Lean.TSyntax.getId b' then
return some (a.1, b.1)
else
return none
def indexToDual (stx : Syntax) : Bool :=
match stx with
| `(indexExpr| τ($_)) => true
@ -60,7 +75,7 @@ def indexToDual (stx : Syntax) : Bool :=
-/
declare_syntax_cat tensorExpr
syntax term (ppSpace indexExpr)* : tensorExpr
syntax term "|" (ppSpace indexExpr)* : tensorExpr
syntax tensorExpr "⊗" tensorExpr : tensorExpr
@ -70,127 +85,85 @@ syntax "(" tensorExpr ")" : tensorExpr
## For tensor nodes.
The operations are done in the following order:
- evaluation.
- dualization.
- contraction.
-/
namespace TensorNode
/-- The indices of a tensor node. Before contraction, dualisation, and evaluation. -/
partial def getIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:ident $[$args]*) => do
| `(tensorExpr| $_:term | $[$args]*) => do
let indices ← args.toList.mapM fun arg => do
match arg with
| `(indexExpr|$t:indexExpr) => pure t
return indices
| _ =>
throwError "Unsupported tensor expression syntax: {stx}"
throwError "Unsupported tensor expression syntax (getIndicesNode): {stx}"
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
partial def getEvalPos (stx : Syntax) : TermElabM (List ( × )) := do
let ind ← getIndicesNode stx
let indEnum := ind.enum
let evals := indEnum.filter (fun x => indexExprIsNum x.2)
println! "indEnum: {evals}"
let evals2 ← (evals.mapM (fun x => indexToNum x.2))
return List.zip (evals.map (fun x => x.1)) evals2
def evalSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
/-- The positions in getIndicesNode which get dualized. -/
partial def getDualPos (stx : Syntax) : TermElabM (List ) := do
let ind ← getIndicesNode stx
let indEnum := ind.enum
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
let indEnum := indFilt.enum
let duals := indEnum.filter (fun x => indexToDual x.2)
return duals.map (fun x => x.1)
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term -- Pattern 1: Identifier with terms
def dualSyntax (l : List ) (T : Term) : Term :=
l.foldl (fun T' x => Syntax.mkApp (mkIdent ``TensorTree.jiggle)
#[Syntax.mkNumLit (toString x), T']) T
partial def getContrPos (stx : Syntax) : TermElabM (List ( × )) := do
let ind ← getIndicesNode stx
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
let indEnum := indFilt.enum
let bind := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
throwError "To many contractions"
return filt
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
let ind ← getIndicesNode stx
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
return ind.filter (fun x => indFilt.count x ≤ 1)
def contrSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
let ind ← getIndicesNode stx
let evalPos ← getEvalPos stx
let dualPos ← getDualPos stx
match stx with
| `(tensorExpr| $T:ident $[$args]*) => do
let tensor ← elabTerm T none
return tensor
| `(tensorExpr| $T:term | $[$args]*) => do
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
let contrSyntax := contrSyntax (← getContrPos stx) dualSyntax
let tensorExpr ← elabTerm contrSyntax none
return tensorExpr
| _ =>
throwError "Unsupported tensor expression syntax: {stx}"
throwError "Unsupported tensor expression syntax (elaborateTensorNode): {stx}"
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term
elab_rules (kind:=tensorExprSyntax) : term
| `(term| {$e:tensorExpr}ᵀ) => do
let tensorTree ← elaborateTensorNode e
return tensorTree
open IndexNotation
example {S : TensorStruct} {n : } {c : Fin n → S.C} (T : S.F.obj (OverColor.mk c)) :
{T i j}ᵀ = T := by
sorry
#eval do
let stx ← `(tensorExpr| T τ(i) τ(k) 0)
let indices ← getIndicesNode stx
let evalPos ← getEvalPos stx
let dualPos ← getDualPos stx
IO.println s!"Indices: {indices},\nEval positions: {evalPos}\nDual positions: {dualPos}"
partial def dropEvalNode (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
let ind ← getIndicesNode stx
let indIndent := ind.filter (fun x => ¬ indexExprIsNum x)
partial def getContrPairsNode (stx : Syntax) : TermElabM (Array ( × )) := do
let ind ← getIndicesNode stx
let mut pairs : Array ( × ) := #[]
for i in [:ind.length] do
for j in [i+1:ind.length] do
if Option.map Lean.TSyntax.getId (ind.get? i) = Option.map Lean.TSyntax.getId (ind.get? j) then
pairs := pairs.push (i, j)
/- Check on pairs. -/
let x := pairs.toList
if ¬ ((x.map Prod.fst).Nodup ∧ (x.map Prod.snd).Nodup) then
throwError "To many contractions"
return pairs
def getContrIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
let ind ← getIndicesNode stx
let contrInd := ind.filter (fun x => ind.count x ≤ 1)
return contrInd
partial def getIndicesProd (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
match stx with
| `(tensorExpr| $_:ident $[$args]*) => do
getContrIndicesNode stx
| `(tensorExpr| $e1:tensorExpr ⊗ $e2:tensorExpr) => do
let ind1 ← getIndicesProd e1
let ind2 ← getIndicesProd e2
return ind1 ++ ind2
| `(tensorExpr| ($e:tensorExpr)) => do
getIndicesProd e
| _ =>
throwError "Unsupported tensor expression syntax: {stx}"
def getContrIndices (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
let ind ← getIndicesProd stx
let contrInd := ind.filter (fun x => ind.count x ≤ 1)
return contrInd
def getContrPairsProd (stx : Syntax) : TermElabM (Array ( × )) := do
let ind ← getIndicesProd stx
let mut pairs : Array ( × ) := #[]
for i in [:ind.length] do
for j in [i+1:ind.length] do
if Option.map Lean.TSyntax.getId (ind.get? i) = Option.map Lean.TSyntax.getId (ind.get? j) then
pairs := pairs.push (i, j)
/- Check on pairs. -/
let x := pairs.toList
if ¬ ((x.map Prod.fst).Nodup ∧ (x.map Prod.snd).Nodup) then
throwError "To many contractions"
return pairs
/-! Some test cases. -/
#eval do
let stx ← `(tensorExpr| (T i ⊗ B i i j ⊗ C k))
let indices ← getIndicesProd stx
let contrPairs ← getContrPairsProd stx
let contrInd ← getContrIndices stx
IO.println s!"Indices: {indices}\nContraction pairs: {contrPairs}\n Contraction list: {contrInd}"
variable (f : Fin 1 → Fin 4)
end TensorNode