feat: addition elab and node identities
This commit is contained in:
parent
ecb2c7778c
commit
6fbace33da
5 changed files with 165 additions and 70 deletions
|
@ -370,52 +370,11 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
|
||||
end ProdNode
|
||||
|
||||
namespace negNode
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
def negSyntax (T1 : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
|
||||
|
||||
end negNode
|
||||
|
||||
/-- Returns the full list of indices after contraction. TODO: Include evaluation. -/
|
||||
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return (← ProdNode.withoutContr stx)
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
namespace Equality
|
||||
|
||||
/-!
|
||||
|
||||
## For equality.
|
||||
## Permutation constructions
|
||||
|
||||
-/
|
||||
|
||||
/-- Gets the indices associated with the LHS of an equality. -/
|
||||
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
/-- Gets the indices associated with the RHS of an equality. -/
|
||||
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
/-- Given two lists of indices returns the `List (ℕ)` representing the how one list
|
||||
permutes into the other. -/
|
||||
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ)) := do
|
||||
|
@ -445,6 +404,82 @@ def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term :=
|
|||
let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
|
||||
return stx
|
||||
|
||||
|
||||
namespace negNode
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
def negSyntax (T1 : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
|
||||
|
||||
end negNode
|
||||
|
||||
/-- Returns the full list of indices after contraction. TODO: Include evaluation. -/
|
||||
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return (← ProdNode.withoutContr stx)
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
namespace Add
|
||||
|
||||
/-- Gets the indices associated with the LHS of an addition. -/
|
||||
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
|
||||
|
||||
|
||||
/-- Gets the indices associated with the RHS of an addition. -/
|
||||
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:tensorExpr + $a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}"
|
||||
|
||||
/-- The syntax for a equality of tensor trees. -/
|
||||
def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
|
||||
let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
|
||||
let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
|
||||
return Syntax.mkApp (mkIdent ``add) #[T1, RHS]
|
||||
|
||||
end Add
|
||||
|
||||
namespace Equality
|
||||
|
||||
/-!
|
||||
|
||||
## For equality.
|
||||
|
||||
-/
|
||||
|
||||
/-- Gets the indices associated with the LHS of an equality. -/
|
||||
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
/-- Gets the indices associated with the RHS of an equality. -/
|
||||
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
/-- The syntax for a equality of tensor trees. -/
|
||||
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
|
||||
let X1 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1]
|
||||
|
@ -453,6 +488,8 @@ def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
|
|||
let X2 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[X2']
|
||||
return Syntax.mkApp (mkIdent ``Eq) #[X1, X2]
|
||||
|
||||
end Equality
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
|
@ -464,11 +501,17 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
return (← syntaxFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
return negNode.negSyntax (← syntaxFull a)
|
||||
| `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do
|
||||
let indicesLeft ← getIndicesLeft stx
|
||||
let indicesRight ← getIndicesRight stx
|
||||
| `(tensorExpr| $a + $b) => do
|
||||
let indicesLeft ← Add.getIndicesLeft stx
|
||||
let indicesRight ← Add.getIndicesRight stx
|
||||
let permSyntax ← getPermutationSyntax indicesLeft indicesRight
|
||||
let equalSyntax ← equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
let addSyntax ← Add.addSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
return addSyntax
|
||||
| `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do
|
||||
let indicesLeft ← Equality.getIndicesLeft stx
|
||||
let indicesRight ← Equality.getIndicesRight stx
|
||||
let permSyntax ← getPermutationSyntax indicesLeft indicesRight
|
||||
let equalSyntax ← Equality.equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
return equalSyntax
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
@ -486,24 +529,4 @@ elab_rules (kind:=tensorExprSyntax) : term
|
|||
let tensorTree ← elaborateTensorNode e
|
||||
return tensorTree
|
||||
|
||||
variable {S : TensorSpecies} {c4 : Fin 4 → S.C} (T4 : S.F.obj (OverColor.mk c4))
|
||||
{c5 : Fin 5 → S.C} (T5 : S.F.obj (OverColor.mk c5)) (a : S.k)
|
||||
|
||||
variable (𝓣 : TensorTree S c4)
|
||||
|
||||
/-!
|
||||
|
||||
# Checks
|
||||
|
||||
-/
|
||||
|
||||
/-
|
||||
#tensor_dot {T4 | i j τ(l) d ⊗ T5 | i j k m m}ᵀ.dot
|
||||
|
||||
#check {T4 | i j l d ⊗ T5 | i j k a b}ᵀ
|
||||
|
||||
#check {(T4 | i j l a ⊗ T5 | i j k c d) ⊗ T5 | i1 i2 i3 e f}ᵀ
|
||||
-/
|
||||
end Equality
|
||||
|
||||
end TensorTree
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue