feat: addition elab and node identities

This commit is contained in:
jstoobysmith 2024-10-22 11:49:58 +00:00
parent ecb2c7778c
commit 6fbace33da
5 changed files with 165 additions and 70 deletions

View file

@ -370,52 +370,11 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
end ProdNode
namespace negNode
/-- The syntax associated with a product of tensors. -/
def negSyntax (T1 : Term) : Term :=
Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
end negNode
/-- Returns the full list of indices after contraction. TODO: Include evaluation. -/
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
return (← TensorNode.withoutContr stx)
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
return (← ProdNode.withoutContr stx)
| `(tensorExpr| ($a:tensorExpr)) => do
return (← getIndicesFull a)
| `(tensorExpr| -$a:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
namespace Equality
/-!
## For equality.
## Permutation constructions
-/
/-- Gets the indices associated with the LHS of an equality. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- Gets the indices associated with the RHS of an equality. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- Given two lists of indices returns the `List ()` representing the how one list
permutes into the other. -/
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List ()) := do
@ -445,6 +404,82 @@ def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term :=
let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
return stx
namespace negNode
/-- The syntax associated with a product of tensors. -/
def negSyntax (T1 : Term) : Term :=
Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
end negNode
/-- Returns the full list of indices after contraction. TODO: Include evaluation. -/
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
return (← TensorNode.withoutContr stx)
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
return (← ProdNode.withoutContr stx)
| `(tensorExpr| ($a:tensorExpr)) => do
return (← getIndicesFull a)
| `(tensorExpr| -$a:tensorExpr) => do
return (← getIndicesFull a)
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
namespace Add
/-- Gets the indices associated with the LHS of an addition. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
/-- Gets the indices associated with the RHS of an addition. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:tensorExpr + $a:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}"
/-- The syntax for a equality of tensor trees. -/
def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
return Syntax.mkApp (mkIdent ``add) #[T1, RHS]
end Add
namespace Equality
/-!
## For equality.
-/
/-- Gets the indices associated with the LHS of an equality. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- Gets the indices associated with the RHS of an equality. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- The syntax for a equality of tensor trees. -/
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
let X1 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1]
@ -453,6 +488,8 @@ def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
let X2 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[X2']
return Syntax.mkApp (mkIdent ``Eq) #[X1, X2]
end Equality
/-- Creates the syntax associated with a tensor node. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
@ -464,11 +501,17 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
return (← syntaxFull a)
| `(tensorExpr| -$a:tensorExpr) => do
return negNode.negSyntax (← syntaxFull a)
| `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do
let indicesLeft ← getIndicesLeft stx
let indicesRight ← getIndicesRight stx
| `(tensorExpr| $a + $b) => do
let indicesLeft ← Add.getIndicesLeft stx
let indicesRight ← Add.getIndicesRight stx
let permSyntax ← getPermutationSyntax indicesLeft indicesRight
let equalSyntax ← equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
let addSyntax ← Add.addSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
return addSyntax
| `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do
let indicesLeft ← Equality.getIndicesLeft stx
let indicesRight ← Equality.getIndicesRight stx
let permSyntax ← getPermutationSyntax indicesLeft indicesRight
let equalSyntax ← Equality.equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
return equalSyntax
| _ =>
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
@ -486,24 +529,4 @@ elab_rules (kind:=tensorExprSyntax) : term
let tensorTree ← elaborateTensorNode e
return tensorTree
variable {S : TensorSpecies} {c4 : Fin 4 → S.C} (T4 : S.F.obj (OverColor.mk c4))
{c5 : Fin 5 → S.C} (T5 : S.F.obj (OverColor.mk c5)) (a : S.k)
variable (𝓣 : TensorTree S c4)
/-!
# Checks
-/
/-
#tensor_dot {T4 | i j τ(l) d ⊗ T5 | i j k m m}ᵀ.dot
#check {T4 | i j l d ⊗ T5 | i j k a b}ᵀ
#check {(T4 | i j l a ⊗ T5 | i j k c d) ⊗ T5 | i1 i2 i3 e f}ᵀ
-/
end Equality
end TensorTree