/- Source: PhysLean/HepLean/Tensors/Tree/Elab.lean (snapshot of 2024-12-20 16:46:11 +00:00).
   This file intentionally uses Unicode characters (e.g. `ℕ`, `ℂ`, `τ`, `⊗`, `•ₜ`, `•ₐ`, `ᵀ`). -/
/-
Copyright (c) 2024 Joseph Tooby-Smith. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joseph Tooby-Smith
-/
import HepLean.Lorentz.ComplexTensor.Basic
/-!
# Elaboration of tensor trees
- Syntax in Lean allows us to represent tensor expressions in a way close to what we expect to
see on pen-and-paper.
- The elaborator turns this syntax into a tensor tree.
## Examples
- Suppose `T` and `T'` are tensors `S.F (OverColor.mk ![c1, c2])`.
- `{T | μ ν}ᵀ` is `tensorNode T`.
- We can also write e.g. `{T | μ ν}ᵀ.tensor` to get the tensor itself.
- `{- T | μ ν}ᵀ` is `neg (tensorNode T)`.
- `{T | 0 ν}ᵀ` is `eval 0 0 (tensorNode T)`.
- `{T | μ ν + T' | μ ν}ᵀ` is `add (tensorNode T) (perm _ (tensorNode T'))`, where
here `_` is the identity permutation, so it does nothing.
- `{T | μ ν = T' | μ ν}ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T')).tensor`.
- If `a ∈ S.k` then `{a •ₜ T | μ ν}ᵀ` is `smul a (tensorNode T)`.
- If `g ∈ S.G` then `{g •ₐ T | μ ν}ᵀ` is `action g (tensorNode T)`.
- Suppose `T2` is a tensor `S.F (OverColor.mk ![c3])`.
Then `{T | μ ν ⊗ T2 | σ}ᵀ` is `prod (tensorNode T) (tensorNode T2)`.
- If `T3` is a tensor `S.F (OverColor.mk ![S.τ c1, S.τ c2])`, then
`{T | μ ν ⊗ T3 | μ σ}ᵀ` is `contr 0 1 _ (prod (tensorNode T) (tensorNode T3))`, and
`{T | μ ν ⊗ T3 | μ ν }ᵀ` is
`contr 0 0 _ (contr 0 1 _ (prod (tensorNode T) (tensorNode T3)))`.
- If `T4` is a tensor `S.F (OverColor.mk ![c2, c1])` then
`{T | μ ν + T4 | ν μ }ᵀ` is `add (tensorNode T) (perm _ (tensorNode T4))`, where `_`
is the permutation of the two indices of `T4`.
`{T | μ ν = T4 | ν μ }ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T4)).tensor`, where `_`
is again the permutation of the two indices of `T4`.
## Comments
- In all of these expressions `μ`, `ν` etc. are free. It does not matter what they are called;
Lean will elaborate them in the same way. I.e. `{T | μ ν ⊗ T3 | μ ν }ᵀ` is exactly the same
to Lean as `{T | α β ⊗ T3 | α β }ᵀ`.
- Note that compared to ordinary index notation, we do not raise or lower the indices.
This is for two reasons: 1) it is difficult to make this general for all tensor species;
2) it is a redundancy in ordinary index notation, since the tensor `T` itself already tells you
this information.
-/
open Lean
open Lean.Elab.Term
open Lean
open Lean.Meta
open Lean.Elab
open Lean.Elab.Term
open Lean Meta Elab Tactic
open IndexNotation
open complexLorentzTensor
namespace TensorTree
/-!
## Indices
-/
/-- A syntax category for indices of tensor expressions. -/
declare_syntax_cat indexExpr
/-- A basic index is an ident. -/
syntax ident : indexExpr
/-- An index can be a num, which will be used to evaluate the tensor. -/
syntax num : indexExpr
/-- Notation to describe the jiggle of a tensor index, written `τ(i)`. -/
syntax "τ(" ident ")" : indexExpr
/-- Decides whether an index is a numeric literal, i.e. an index at which the tensor is to be
evaluated. Identifier indices and `τ(i)` indices give `false`. -/
def indexExprIsNum (stx : Syntax) : Bool :=
  if let `(indexExpr| $_:num) := stx then true else false
/-- Reads the natural number written in a numeric index.
Throws if `stx` is not a numeric index, or if its literal cannot be read as a natural number. -/
def indexToNum (stx : Syntax) : TermElabM Nat := do
  let `(indexExpr|$lit:num) := stx
    | throwError "Unsupported tensor expression syntax in indexToNum: {stx}"
  let some value := lit.raw.isNatLit?
    | throwError "Expected a natural number literal."
  return value
/-- Extracts the identifier naming an index. A plain index `i` gives `i`, and a dual-marked
index `τ(i)` also gives `i`. Any other index form (e.g. a numeric one) throws. -/
def indexToIdent (stx : Syntax) : TermElabM Ident := do
  match stx with
  | `(indexExpr|$name:ident) => pure name
  | `(indexExpr| τ($name:ident)) => pure name
  | _ => throwError "Unsupported tensor expression syntax in indexToIdent: {stx}"
/-- Takes two enumerated indices `a = (i, x)` and `b = (j, y)` of type
``Nat × TSyntax `indexExpr``.
Returns `some (i, j)` when `i < j` and `x`, `y` carry the same identifier. The comparison uses
`indexToIdent`, so a `τ(·)` wrapper is ignored. Otherwise returns `none`.
Throws (via `indexToIdent`) if either index is numeric. -/
def indexPosEq (a b : Nat × TSyntax `indexExpr) : TermElabM (Option (Nat × Nat)) := do
  let a' ← indexToIdent a.2
  let b' ← indexToIdent b.2
  if a.1 < b.1 ∧ a'.getId = b'.getId then
    return some (a.1, b.1)
  else
    return none
/-- Tests whether an index is written in the dual-marked form `τ(i)`. Plain identifier indices
and numeric indices give `false`. -/
def indexToDual (stx : Syntax) : Bool :=
  if let `(indexExpr| τ($_)) := stx then true else false
/-!
## Tensor expressions
-/
/-- A syntax category for tensor expressions. -/
declare_syntax_cat tensorExpr
/-- The syntax for a tensor node: a term, then `|`, then its indices. -/
syntax term "|" (ppSpace indexExpr)* : tensorExpr
/-- Equality of two tensor expressions. -/
syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr
/-- The syntax for the tensor product of two tensor expressions. -/
syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr
/-- The syntax for tensor addition. -/
syntax tensorExpr "+" tensorExpr : tensorExpr
/-- Allowing brackets to be used in a tensor expression. -/
syntax "(" tensorExpr ")" : tensorExpr
/-- Scalar multiplication for tensors. -/
syntax term "•ₜ" tensorExpr : tensorExpr
/-- Group action for tensors. -/
syntax term "•ₐ" tensorExpr : tensorExpr
/-- Negation of a tensor tree. -/
syntax "-" tensorExpr : tensorExpr
namespace TensorNode
/-!
## For tensor nodes.
The operations are done in the following order:
- evaluation.
- dualization.
- contraction.
We also want to ensure the number of indices is correct.
-/
/-- The raw indices written after `|` in a tensor node `T | i₁ … iₙ`, in order. Contraction
and evaluation are not yet taken into account. Throws if `stx` is not a tensor node. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $_:term | $[$idxs]*) := stx
    | throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}"
  let collected ← idxs.mapM fun idx => do
    match idx with
    | `(indexExpr|$t:indexExpr) => pure t
  return collected.toList
/-- Uses the structure of the type of the tensor `stx` to get its number of indices.
- Let `n` be one more than the number of occurrences of
  `CategoryTheory.MonoidalCategoryStruct.tensorObj` in the printed type. If `n ≠ 1`, `n` is
  returned.
- Otherwise, if the type has the shape `_ (_ (_ c))`, the type of `c` is inspected. When it
  is a function type whose domain has the shape `_ a` (presumably `Fin a → _`) and `a` reduces
  under `whnf` to a natural-number literal, that literal is returned. Otherwise an error is
  thrown.
- Any other type is treated as having a single index. -/
def getNoIndicesExact (stx : Syntax) : TermElabM Nat := do
  let expr ← elabTerm stx none
  let type ← inferType expr
  let notFound : TermElabM Nat :=
    throwError s!"Could not extract number of indices from tensor\n{stx} (getNoIndicesExact). "
  -- Splitting on the separator gives one more piece than there are occurrences of it.
  let n := ((toString type).splitOn "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
  if n ≠ 1 then
    return n
  match type with
  | Expr.app _ (Expr.app _ (Expr.app _ c)) =>
    match (← inferType c) with
    | Expr.forallE _ (Expr.app _ a) _ _ =>
      match (← whnf a) with
      | Expr.lit (Literal.natVal k) => return k
      | _ => notFound
    | _ => notFound
  | _ => return 1
/-- Parses `str` as a Lean term and elaborates it (without an expected type).
Returns `none` when `str` does not parse. Elaboration errors are not caught. -/
def stringToType (str : String) : TermElabM (Option Expr) := do
  match Parser.runParserCategory (← getEnv) `term str with
  | .ok parsed => some <$> elabTerm parsed none
  | .error _ => pure none
/-- Parses `str` as a Lean term and returns the resulting syntax (nothing is elaborated).
Throws if `str` does not parse. -/
def stringToTerm (str : String) : TermElabM Term := do
  let .ok parsed := Parser.runParserCategory (← getEnv) `term str
    | throwError "Could not create type from string (stringToTerm). "
  match parsed with
  | `(term| $e) => return e
/-- Specific types of tensors which appear and which we want to elaborate in specific ways.
Each entry pairs a string with a builder:
- The string is parsed and elaborated as a Lean term, then compared up to definitional
  equality with a tensor's type.
- The builder applies `vecNodeE`, `twoNodeE`, `constTwoNodeE` or `constThreeNodeE` to
  `complexLorentzTensor`, the colors of the indices and the tensor.

Note: a string that fails to parse is silently skipped by `stringToType`. Every string must
therefore be valid Lean syntax, including the `ℂ` in `Rep ℂ SL(2, ℂ)`. -/
def specialTypes : List (String × (Term → Term)) := [
  ("CoeSort.coe Lorentz.complexCo", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.down, T]),
  ("CoeSort.coe Lorentz.complexContr", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.up, T]),
  ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexCo).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.up, mkIdent ``complexLorentzTensor.Color.down, T]),
  ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexContr).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.up, mkIdent ``complexLorentzTensor.Color.up, T]),
  ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexCo).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.down, mkIdent ``complexLorentzTensor.Color.down, T]),
  ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexContr).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.down,
      mkIdent ``complexLorentzTensor.Color.up, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.down,
      mkIdent ``complexLorentzTensor.Color.down, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Lorentz.complexContr", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.up,
      mkIdent ``complexLorentzTensor.Color.up, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded",
    fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
      mkIdent ``complexLorentzTensor, mkIdent ``complexLorentzTensor.Color.up,
      mkIdent ``complexLorentzTensor.Color.upL,
      mkIdent ``complexLorentzTensor.Color.upR, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.upL,
      mkIdent ``complexLorentzTensor.Color.upL, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altLeftHanded ⊗ Fermion.altLeftHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.downL,
      mkIdent ``complexLorentzTensor.Color.downL, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altRightHanded ⊗ Fermion.altRightHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.downR,
      mkIdent ``complexLorentzTensor.Color.downR, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``complexLorentzTensor,
      mkIdent ``complexLorentzTensor.Color.upR,
      mkIdent ``complexLorentzTensor.Color.upR, T])]
/-- Builds the leaf of a tensor tree for the term `T`.
1. `T` is elaborated, and its type is compared with `isDefEq`, in list order, against every
   entry of `specialTypes` whose string can be parsed.
2. If exactly one entry matches, that entry's builder is used.
3. Otherwise `T` becomes `TensorTree.tensorNode T` when its type has the shape `_ (_ (_ _))`.
4. In all other cases `T` becomes `TensorTree.vecNode T`. -/
def termNodeSyntax (T : Term) : TermElabM Term := do
  let type ← inferType (← elabTerm T none)
  let matching ← specialTypes.filterM fun (str, _) => do
    match (← stringToType str) with
    | some candidate => isDefEq type candidate
    | none => pure false
  match matching with
  | [(_, build)] => return build T
  | _ =>
    let ctor :=
      if type matches Expr.app _ (Expr.app _ (Expr.app _ _)) then ``TensorTree.tensorNode
      else ``TensorTree.vecNode
    return Syntax.mkApp (mkIdent ctor) #[T]
/-- Adjusts a list `l : List Nat` by subtracting from each natural number the number of
elements *before* it in the list which are less than it. This is used to form a list of
positions which can be used for evaluating indices one after another. For example,
`evalAdjustPos [1, 3] = [1, 2]`. -/
def evalAdjustPos (l : List Nat) : List Nat :=
  go l []
where
  /-- Walks `l` left to right; `seen` holds the elements already visited. -/
  go : List Nat → List Nat → List Nat
    | [], _ => []
    | x :: xs, seen => (x - seen.countP (fun y => y < x)) :: go xs (x :: seen)
/-- The evaluations requested by a tensor node. For each numeric index (in order) this gives
the pair `(position, value)`:
- `value` is the number written at that index.
- `position` is the index's position among the node's raw indices, adjusted with
  `evalAdjustPos` so that the evaluations can be applied one after another. -/
def getEvalPos (stx : Syntax) : TermElabM (List (Nat × Nat)) := do
  let ind ← getIndices stx
  let evals := ind.enum.filter (fun x => indexExprIsNum x.2)
  let values ← evals.mapM (fun x => indexToNum x.2)
  let pos := evalAdjustPos (evals.map (fun x => x.1))
  return List.zip pos values
/-- For each pair `(pos, val)` of `l : List (Nat × Nat)`, in order, wraps the term in
`TensorTree.eval pos val`. The first pair is therefore applied innermost. -/
def evalSyntax (l : List (Nat × Nat)) (T : Term) : Term :=
  l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
    #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
/-- The pairs of positions, among the non-numeric indices of a tensor node, that get
contracted. The result contains `(i, j)` with `i < j` whenever the `i`-th and `j`-th such
indices carry the same identifier (see `indexPosEq`).
Throws "Too many contractions" when a position would be paired more than once, e.g. when an
identifier occurs three or more times. -/
def getContrPos (stx : Syntax) : TermElabM (List (Nat × Nat)) := do
  let ind ← getIndices stx
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  let indEnum := indFilt.enum
  -- All ordered pairs of enumerated indices; `indexPosEq` keeps the matching ones with i < j.
  let bind := indEnum.flatMap (fun a => indEnum.map (fun b => (a, b)))
  let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
  if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
    throwError "Too many contractions"
  return filt
/-- The list of indices left after evaluation and contraction. Order is preserved.
- Numeric indices are dropped.
- Every index whose identifier occurs more than once is dropped, ignoring a `τ(·)` wrapper.
  These are exactly the indices paired up by `getContrPos`. -/
def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  -- Compare by identifier, as `indexPosEq` does when choosing contractions, so that e.g.
  -- `μ` and `τ(μ)` are also recognised here as a contracted pair.
  let named ← indFilt.mapM fun x => do return (x, (← indexToIdent x).getId)
  return (named.filter fun (_, n) => named.countP (fun p => p.2 == n) ≤ 1).map Prod.fst
end TensorNode
/-- Takes a list and puts consecutive elements into pairs, e.g. `[0, 1, 2, 3]` becomes
`[(0, 1), (2, 3)]`. A trailing unpaired element `x` becomes `(x, 0)`. -/
def toPairs (l : List Nat) : List (Nat × Nat) :=
  match l with
  | x1 :: x2 :: xs => (x1, x2) :: toPairs xs
  | [] => []
  | [x] => [(x, 0)]
/-- Adjusts a list of contraction pairs `l : List (Nat × Nat)` so that the contractions can be
applied one after another:
1. The pairs are flattened into a single list of positions.
2. Each position is reduced by the number of earlier positions in that list which are smaller
   than it (this is `TensorNode.evalAdjustPos`).
3. The result is paired back up.
For example, `[(0, 2), (1, 3)]` becomes `[(0, 1), (0, 0)]`. -/
def contrListAdjust (l : List (Nat × Nat)) : List (Nat × Nat) :=
  toPairs (TensorNode.evalAdjustPos (l.flatMap fun p => [p.1, p.2]))
/-- Adjusts `l : List (Nat × Nat)` with `contrListAdjust`. Then, for each resulting pair
`(i, j)` in order, wraps the term in `TensorTree.contr i j rfl`; the first pair is therefore
applied innermost. -/
def contrSyntax (l : List (Nat × Nat)) (T : Term) : Term :=
  (contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
    #[Syntax.mkNumLit (toString x0),
      Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
namespace ProdNode
/-!
## For product nodes.
For a product node we can take the tensor product, and then contract the indices.
-/
/-- The indices of a product expression that remain once the contractions inside each factor
are removed. Contractions *between* the two factors of `a ⊗ b` are not removed here.
- A tensor node contributes `TensorNode.withoutContr` of its indices.
- Brackets are transparent.
- Any other syntax throws. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) =>
    TensorNode.withoutContr (← TensorNode.getIndices stx)
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let lhs ← TensorNode.withoutContr (← getIndices a)
    let rhs ← TensorNode.withoutContr (← getIndices b)
    pure (lhs ++ rhs)
  | `(tensorExpr| ($a:tensorExpr)) => getIndices a
  | _ => throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- The pairs of positions, among the (non-numeric) indices returned by `getIndices` for a
product, that get contracted. The result contains `(i, j)` with `i < j` whenever the `i`-th
and `j`-th indices carry the same identifier (see `indexPosEq`).
Throws "Too many contractions" when a position would be paired more than once. -/
def getContrPos (stx : Syntax) : TermElabM (List (Nat × Nat)) := do
  let ind ← getIndices stx
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  let indEnum := indFilt.enum
  -- All ordered pairs of enumerated indices; `indexPosEq` keeps the matching ones with i < j.
  let bind := indEnum.flatMap (fun a => indEnum.map (fun b => (a, b)))
  let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
  if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
    throwError "Too many contractions"
  return filt
/-- The list of indices of a product after contraction. Order is preserved.
- Every index whose identifier occurs more than once is dropped, ignoring a `τ(·)` wrapper.
  These are exactly the indices paired up by `getContrPos`.
- Numeric indices are dropped too. -/
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let ind ← getIndices stx
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  -- Compare by identifier, as `indexPosEq` does when choosing contractions, so that e.g.
  -- `μ` and `τ(μ)` are also recognised here as a contracted pair.
  let named ← indFilt.mapM fun x => do return (x, (← indexToIdent x).getId)
  return (named.filter fun (_, n) => named.countP (fun p => p.2 == n) ≤ 1).map Prod.fst
/-- Builds the tensor-tree term `TensorTree.prod T1 T2` for the product of two trees. -/
def prodSyntax (T1 T2 : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.prod
  Syntax.mkApp head #[T1, T2]
end ProdNode
/-!
## Permutation constructions
-/
/-- Given two lists of indices, maps each index of `l2` (in order) to the position in `l1` of
the first index with the same identifier. Identifiers are compared via `indexToIdent`, so a
`τ(·)` wrapper is ignored. Indices of `l2` with no match in `l1` are skipped.
Throws if any index is numeric. -/
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List Nat) := do
  let names1 ← l1.mapM fun x => TSyntax.getId <$> indexToIdent x
  let names2 ← l2.mapM fun x => TSyntax.getId <$> indexToIdent x
  return names2.filterMap fun n => names1.findIdx? (· == n)
open HepLean.Fin
/-- Given two lists of indices, builds the term `finMapToEquiv P1 P2`, where:
- `P1 = ![…]` lists `getPermutation l1 l2`;
- `P2 = ![…]` lists `getPermutation l2 l1`. -/
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
  let forward ← getPermutation l1 l2
  let backward ← getPermutation l2 l1
  -- Render the positions as `![a, b, …]` source text, then parse that text back as a term.
  let render (xs : List Nat) : String := "![" ++ ", ".intercalate (xs.map toString) ++ "]"
  let P1 ← TensorNode.stringToTerm (render forward)
  let P2 ← TensorNode.stringToTerm (render backward)
  return Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
namespace negNode
/-- The syntax associated with the negation of a tensor tree: `TensorTree.neg T1`. -/
def negSyntax (T1 : Term) : Term :=
  Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
end negNode
/-- Returns the full list of free indices of a tensor expression after contraction.
- Numeric indices of tensor nodes are dropped.
- Brackets, negation, scalar multiplication (`•ₜ`) and the group action (`•ₐ`) keep the
  indices of their argument.
- For a sum, the indices of the left summand are returned.
- Any other syntax throws.
TODO: Include evaluation. -/
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) => do
    return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
  | `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
    return (← ProdNode.withoutContr stx)
  | `(tensorExpr| ($a:tensorExpr)) => do
    return (← getIndicesFull a)
  | `(tensorExpr| -$a:tensorExpr) => do
    return (← getIndicesFull a)
  | `(tensorExpr| $_:term •ₜ $a:tensorExpr) => do
    return (← getIndicesFull a)
  -- Without this case a sum or equality involving `g •ₐ T` could not be elaborated.
  | `(tensorExpr| $_:term •ₐ $a:tensorExpr) => do
    return (← getIndicesFull a)
  | `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesFull: {stx}"
namespace SMul
/-- Builds `TensorTree.smul c T`, the scalar multiple of the tree `T` by `c`. -/
def smulSyntax (c T : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.smul
  Syntax.mkApp head #[c, T]
end SMul
namespace Action
/-- Builds `TensorTree.action c T`, the group element `c` acting on the tree `T`. -/
def actionSyntax (c T : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.action
  Syntax.mkApp head #[c, T]
end Action
namespace Add
/-- The indices of the left summand of a sum `A + B`, as computed by `getIndicesFull`.
Throws if `stx` is not a sum. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $lhs:tensorExpr + $_:tensorExpr) := stx
    | throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
  getIndicesFull lhs
/-- The indices of the right summand of a sum `A + B`, as computed by `getIndicesFull`.
Throws if `stx` is not a sum. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $_:tensorExpr + $rhs:tensorExpr) := stx
    | throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}"
  getIndicesFull rhs
/-- The syntax for the addition of two tensor trees: `add T1 (perm P T2)`, where
`P := OverColor.equivToHomEq permSyntax`. The permutation is meant to line up the indices of
`T2` with those of `T1`. -/
def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
  let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
  return Syntax.mkApp (mkIdent ``add) #[T1, RHS]
end Add
namespace Equality
/-!
## For equality.
-/
/-- The indices of the left-hand side of an equality `A = B`, as computed by `getIndicesFull`.
Throws if `stx` is not an equality. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    throwError "Unsupported tensor expression syntax in Equality.getIndicesLeft: {stx}"
/-- The indices of the right-hand side of an equality `A = B`, as computed by
`getIndicesFull`. Throws if `stx` is not an equality. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    throwError "Unsupported tensor expression syntax in Equality.getIndicesRight: {stx}"
/-- Forms the equation `T1.tensor = (perm P T2).tensor` between the tensors underlying two
tensor trees, where `P := OverColor.equivToHomEq permSyntax`. -/
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let lhs := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1]
  let permuted := Syntax.mkApp (mkIdent ``TensorTree.perm)
    #[Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax], T2]
  let rhs := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[permuted]
  pure (Syntax.mkApp (mkIdent ``Eq) #[lhs, rhs])
end Equality
/-- Translates a tensor expression into the term of the corresponding tensor tree.
- `T | i₁ … iₙ`:
  - checks that `n` equals `TensorNode.getNoIndicesExact T`;
  - builds the leaf with `TensorNode.termNodeSyntax`;
  - applies the evaluations for numeric indices, then the contractions.
- `A ⊗ B`: the product of the translated factors, followed by the contractions between them.
- `(A)`: the translation of `A`.
- `-A`, `c •ₜ A`, `g •ₐ A`: the negation, scalar multiple or group action of the translation
  of `A`.
- `A + B` and `A = B`: a permutation built from the indices of both sides (via
  `getPermutationSyntax`) is applied to the right-hand side, and then `add` or an equation is
  formed.
Throws on any other syntax. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
  match stx with
  | `(tensorExpr| $T:term | $[$args]*) =>
    let indices ← TensorNode.getIndices stx
    let expected ← TensorNode.getNoIndicesExact T
    unless indices.length = expected do
      throwError "The expected number of indices {expected} does not match the tensor {T}."
    let leaf ← TensorNode.termNodeSyntax T
    let evaluated := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) leaf
    return contrSyntax (← TensorNode.getContrPos stx) evaluated
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let product := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b)
    return contrSyntax (← ProdNode.getContrPos stx) product
  | `(tensorExpr| ($a:tensorExpr)) => syntaxFull a
  | `(tensorExpr| -$a:tensorExpr) => return negNode.negSyntax (← syntaxFull a)
  | `(tensorExpr| $c:term •ₜ $a:tensorExpr) => return SMul.smulSyntax c (← syntaxFull a)
  | `(tensorExpr| $c:term •ₐ $a:tensorExpr) => return Action.actionSyntax c (← syntaxFull a)
  | `(tensorExpr| $a + $b) =>
    let perm ← getPermutationSyntax (← Add.getIndicesLeft stx) (← Add.getIndicesRight stx)
    Add.addSyntax perm (← syntaxFull a) (← syntaxFull b)
  | `(tensorExpr| $a:tensorExpr = $b:tensorExpr) =>
    let perm ← getPermutationSyntax (← Equality.getIndicesLeft stx)
      (← Equality.getIndicesRight stx)
    Equality.equalSyntax perm (← syntaxFull a) (← syntaxFull b)
  | _ =>
    throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
/-- Elaborates a tensor expression. It is first translated by `syntaxFull` into the term of a
tensor tree, which is then elaborated without an expected type. This is to be generalized. -/
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
  elabTerm (← syntaxFull stx) none
/-- Syntax turning a tensor expression into a term. -/
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term
-- `{e}ᵀ` is elaborated by handing the tensor expression `e` to `elaborateTensorNode`.
elab_rules (kind:=tensorExprSyntax) : term
  | `(term| {$e:tensorExpr}ᵀ) => elaborateTensorNode e
/-!
## Test cases
-/
/-
variable {S : TensorSpecies} {c : Fin (Nat.succ (Nat.succ 0)) → S.C} {t : S.F.obj (OverColor.mk c)}
#check {t | α β}ᵀ
-/
end TensorTree