refactor: Index notation
This commit is contained in:
parent
341aea19c6
commit
e5116d152c
3 changed files with 75 additions and 97 deletions
|
@ -17,6 +17,11 @@ The first character `ᵘ` specifies the color of the index, and the subsequent c
|
|||
|
||||
Strings of indices e.g. `ᵘ¹²ᵤ₄₃`` are defined elsewhere.
|
||||
|
||||
## Note
|
||||
|
||||
Index notation is currently being refactored. Much of the content here will likely be replaced
|
||||
or removed.
|
||||
|
||||
-/
|
||||
|
||||
open Lean
|
||||
|
|
|
@ -34,7 +34,7 @@ instance : Group S.G := S.G_group
|
|||
end TensorStruct
|
||||
|
||||
inductive TensorTree (S : TensorStruct) : ∀ {n : ℕ}, (Fin n → S.C) → Type where
|
||||
| tensorNode {n : ℕ} {c : Fin n → S.C} : S.F.obj (OverColor.mk c) → TensorTree S c
|
||||
| tensorNode {n : ℕ} {c : Fin n → S.C} (T: S.F.obj (OverColor.mk c)): TensorTree S c
|
||||
| add {n : ℕ} {c : Fin n → S.C} : TensorTree S c → TensorTree S c → TensorTree S c
|
||||
| perm {n m : ℕ} {c : Fin n → S.C} {c1 : Fin m → S.C}
|
||||
(σ : (OverColor.mk c) ⟶ (OverColor.mk c1)) (t : TensorTree S c) : TensorTree S c1
|
||||
|
@ -48,8 +48,8 @@ inductive TensorTree (S : TensorStruct) : ∀ {n : ℕ}, (Fin n → S.C) → Typ
|
|||
(j : Fin n.succ) → TensorTree S c → TensorTree S (c ∘ Fin.succAbove i ∘ Fin.succAbove j)
|
||||
| jiggle {n : ℕ} {c : Fin n → S.C} : (i : Fin n) → TensorTree S c →
|
||||
TensorTree S (Function.update c i (S.τ (c i)))
|
||||
| eval {n : ℕ} {c : Fin n.succ → S.C} : TensorTree S c →
|
||||
(i : Fin n.succ) → (x : Fin (S.evalNo (c i))) →
|
||||
| eval {n : ℕ} {c : Fin n.succ → S.C} :
|
||||
(i : Fin n.succ) → (x : Fin (S.evalNo (c i))) → TensorTree S c →
|
||||
TensorTree S (c ∘ Fin.succAbove i)
|
||||
|
||||
namespace TensorTree
|
||||
|
@ -68,7 +68,7 @@ def size : ∀ {n : ℕ} {c : Fin n → S.C}, TensorTree S c → ℕ := fun
|
|||
| mult _ _ t1 t2 => t1.size + t2.size + 1
|
||||
| contr _ _ t => t.size + 1
|
||||
| jiggle _ t => t.size + 1
|
||||
| eval t _ _ => t.size + 1
|
||||
| eval _ _ t => t.size + 1
|
||||
|
||||
|
||||
noncomputable section
|
||||
|
|
|
@ -9,7 +9,7 @@ import Lean.Elab.Term
|
|||
|
||||
## Elaboration of tensor trees
|
||||
|
||||
This file turns
|
||||
This file turns tensor expressions into tensor trees.
|
||||
|
||||
-/
|
||||
open Lean
|
||||
|
@ -49,6 +49,21 @@ def indexToNum (stx : Syntax) : TermElabM Nat :=
|
|||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax (indexToNum): {stx}"
|
||||
|
||||
def indexToIdent (stx : Syntax) : TermElabM Ident :=
|
||||
match stx with
|
||||
| `(indexExpr|$a:ident) => return a
|
||||
| `(indexExpr| τ($a:ident)) => return a
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax (indexToIdent): {stx}"
|
||||
|
||||
def indexPosEq (a b : ℕ × TSyntax `indexExpr) : TermElabM (Option (ℕ × ℕ)) := do
|
||||
let a' ← indexToIdent a.2
|
||||
let b' ← indexToIdent b.2
|
||||
if a.1 < b.1 ∧ Lean.TSyntax.getId a' = Lean.TSyntax.getId b' then
|
||||
return some (a.1, b.1)
|
||||
else
|
||||
return none
|
||||
|
||||
def indexToDual (stx : Syntax) : Bool :=
|
||||
match stx with
|
||||
| `(indexExpr| τ($_)) => true
|
||||
|
@ -60,7 +75,7 @@ def indexToDual (stx : Syntax) : Bool :=
|
|||
-/
|
||||
declare_syntax_cat tensorExpr
|
||||
|
||||
syntax term (ppSpace indexExpr)* : tensorExpr
|
||||
syntax term "|" (ppSpace indexExpr)* : tensorExpr
|
||||
|
||||
syntax tensorExpr "⊗" tensorExpr : tensorExpr
|
||||
|
||||
|
@ -70,127 +85,85 @@ syntax "(" tensorExpr ")" : tensorExpr
|
|||
|
||||
## For tensor nodes.
|
||||
|
||||
The operations are done in the following order:
|
||||
- evaluation.
|
||||
- dualization.
|
||||
- contraction.
|
||||
-/
|
||||
|
||||
namespace TensorNode
|
||||
|
||||
/-- The indices of a tensor node. Before contraction, dualisation, and evaluation. -/
|
||||
partial def getIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:ident $[$args]*) => do
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
let indices ← args.toList.mapM fun arg => do
|
||||
match arg with
|
||||
| `(indexExpr|$t:indexExpr) => pure t
|
||||
return indices
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax: {stx}"
|
||||
throwError "Unsupported tensor expression syntax (getIndicesNode): {stx}"
|
||||
|
||||
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
|
||||
partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let indEnum := ind.enum
|
||||
|
||||
let evals := indEnum.filter (fun x => indexExprIsNum x.2)
|
||||
println! "indEnum: {evals}"
|
||||
let evals2 ← (evals.mapM (fun x => indexToNum x.2))
|
||||
return List.zip (evals.map (fun x => x.1)) evals2
|
||||
|
||||
def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
|
||||
|
||||
/-- The positions in getIndicesNode which get dualized. -/
|
||||
partial def getDualPos (stx : Syntax) : TermElabM (List ℕ) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let indEnum := ind.enum
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let duals := indEnum.filter (fun x => indexToDual x.2)
|
||||
return duals.map (fun x => x.1)
|
||||
|
||||
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term -- Pattern 1: Identifier with terms
|
||||
def dualSyntax (l : List ℕ) (T : Term) : Term :=
|
||||
l.foldl (fun T' x => Syntax.mkApp (mkIdent ``TensorTree.jiggle)
|
||||
#[Syntax.mkNumLit (toString x), T']) T
|
||||
|
||||
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let bind := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
|
||||
let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
|
||||
if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
|
||||
throwError "To many contractions"
|
||||
return filt
|
||||
|
||||
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
|
||||
|
||||
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
|
||||
let ind ← getIndicesNode stx
|
||||
let evalPos ← getEvalPos stx
|
||||
let dualPos ← getDualPos stx
|
||||
match stx with
|
||||
| `(tensorExpr| $T:ident $[$args]*) => do
|
||||
let tensor ← elabTerm T none
|
||||
return tensor
|
||||
| `(tensorExpr| $T:term | $[$args]*) => do
|
||||
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) dualSyntax
|
||||
let tensorExpr ← elabTerm contrSyntax none
|
||||
return tensorExpr
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax: {stx}"
|
||||
throwError "Unsupported tensor expression syntax (elaborateTensorNode): {stx}"
|
||||
|
||||
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term
|
||||
|
||||
elab_rules (kind:=tensorExprSyntax) : term
|
||||
| `(term| {$e:tensorExpr}ᵀ) => do
|
||||
let tensorTree ← elaborateTensorNode e
|
||||
return tensorTree
|
||||
|
||||
open IndexNotation
|
||||
|
||||
example {S : TensorStruct} {n : ℕ} {c : Fin n → S.C} (T : S.F.obj (OverColor.mk c)) :
|
||||
{T i j}ᵀ = T := by
|
||||
sorry
|
||||
#eval do
|
||||
let stx ← `(tensorExpr| T τ(i) τ(k) 0)
|
||||
let indices ← getIndicesNode stx
|
||||
let evalPos ← getEvalPos stx
|
||||
let dualPos ← getDualPos stx
|
||||
IO.println s!"Indices: {indices},\nEval positions: {evalPos}\nDual positions: {dualPos}"
|
||||
|
||||
partial def dropEvalNode (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let indIndent := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
|
||||
partial def getContrPairsNode (stx : Syntax) : TermElabM (Array (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let mut pairs : Array (ℕ × ℕ) := #[]
|
||||
for i in [:ind.length] do
|
||||
for j in [i+1:ind.length] do
|
||||
if Option.map Lean.TSyntax.getId (ind.get? i) = Option.map Lean.TSyntax.getId (ind.get? j) then
|
||||
pairs := pairs.push (i, j)
|
||||
/- Check on pairs. -/
|
||||
let x := pairs.toList
|
||||
if ¬ ((x.map Prod.fst).Nodup ∧ (x.map Prod.snd).Nodup) then
|
||||
throwError "To many contractions"
|
||||
return pairs
|
||||
|
||||
def getContrIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let contrInd := ind.filter (fun x => ind.count x ≤ 1)
|
||||
return contrInd
|
||||
|
||||
partial def getIndicesProd (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:ident $[$args]*) => do
|
||||
getContrIndicesNode stx
|
||||
| `(tensorExpr| $e1:tensorExpr ⊗ $e2:tensorExpr) => do
|
||||
let ind1 ← getIndicesProd e1
|
||||
let ind2 ← getIndicesProd e2
|
||||
return ind1 ++ ind2
|
||||
| `(tensorExpr| ($e:tensorExpr)) => do
|
||||
getIndicesProd e
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax: {stx}"
|
||||
|
||||
def getContrIndices (stx : Syntax) : TermElabM (List (TSyntax `ident)) := do
|
||||
let ind ← getIndicesProd stx
|
||||
let contrInd := ind.filter (fun x => ind.count x ≤ 1)
|
||||
return contrInd
|
||||
|
||||
def getContrPairsProd (stx : Syntax) : TermElabM (Array (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesProd stx
|
||||
let mut pairs : Array (ℕ × ℕ) := #[]
|
||||
for i in [:ind.length] do
|
||||
for j in [i+1:ind.length] do
|
||||
if Option.map Lean.TSyntax.getId (ind.get? i) = Option.map Lean.TSyntax.getId (ind.get? j) then
|
||||
pairs := pairs.push (i, j)
|
||||
/- Check on pairs. -/
|
||||
let x := pairs.toList
|
||||
if ¬ ((x.map Prod.fst).Nodup ∧ (x.map Prod.snd).Nodup) then
|
||||
throwError "To many contractions"
|
||||
return pairs
|
||||
|
||||
/-! Some test cases. -/
|
||||
|
||||
|
||||
|
||||
|
||||
#eval do
|
||||
let stx ← `(tensorExpr| (T i ⊗ B i i j ⊗ C k))
|
||||
let indices ← getIndicesProd stx
|
||||
let contrPairs ← getContrPairsProd stx
|
||||
let contrInd ← getContrIndices stx
|
||||
IO.println s!"Indices: {indices}\nContraction pairs: {contrPairs}\n Contraction list: {contrInd}"
|
||||
|
||||
variable (f : Fin 1 → Fin 4)
|
||||
end TensorNode
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue