PhysLean/HepLean/Tensors/Tree/Elab.lean

537 lines
20 KiB
Text
Raw Normal View History

/-
Copyright (c) 2024 Joseph Tooby-Smith. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joseph Tooby-Smith
-/
import HepLean.Tensors.Tree.Basic
import Lean.Elab.Term
2024-10-12 09:00:08 +00:00
import HepLean.Tensors.Tree.Dot
2024-10-16 16:38:36 +00:00
import HepLean.Tensors.ComplexLorentz.Basic
/-!
## Elaboration of tensor trees

This file turns tensor expressions into tensor trees.
-/
open Lean
open Lean.Elab.Term
open Lean
open Lean.Meta
open Lean.Elab
open Lean.Elab.Term
2024-10-08 05:53:16 +00:00
open Lean Meta Elab Tactic
2024-10-12 09:00:08 +00:00
open IndexNotation
namespace TensorTree
2024-10-08 05:53:16 +00:00
/-!
## Indices
-/
/-- A syntax category for indices of tensor expressions. -/
declare_syntax_cat indexExpr

/-- A basic index is an ident. -/
syntax ident : indexExpr

/-- An index can be a num, which will be used to evaluate the tensor. -/
syntax num : indexExpr

/-- Notation to describe the jiggle of a tensor index. -/
syntax "τ(" ident ")" : indexExpr
/-- Bool which is true if an index is a num. -/
def indexExprIsNum (stx : Syntax) : Bool :=
  match stx with
  | `(indexExpr|$_:num) => true
  | _ => false
/-- If an index is a num, the underlying natural number.

Throws an error if `stx` is not a numeric index, or if the numeral is not a natural-number
literal. -/
def indexToNum (stx : Syntax) : TermElabM Nat :=
  match stx with
  | `(indexExpr|$a:num) =>
    match a.raw.isNatLit? with
    | some n => return n
    | none => throwError "Expected a natural number literal."
  | _ =>
    throwError "Unsupported tensor expression syntax in indexToNum: {stx}"
/-- When an index is not a num, the corresponding ident.

Both a plain index `i` and a dualized index `τ(i)` yield the ident `i`. Throws an error for
any other index syntax (e.g. a num). -/
def indexToIdent (stx : Syntax) : TermElabM Ident :=
  match stx with
  | `(indexExpr|$a:ident) => return a
  | `(indexExpr| τ($a:ident)) => return a
  | _ =>
    throwError "Unsupported tensor expression syntax in indexToIdent: {stx}"
/-- Takes a pair ``a b : ℕ × TSyntax `indexExpr``. If `a.1 < b.1` and `a.2 = b.2` then
outputs `some (a.1, b.1)`, otherwise `none`.

Two indices are compared by the name of their underlying identifier (via `indexToIdent`), so
`i` and `τ(i)` count as equal here. -/
def indexPosEq (a b : ℕ × TSyntax `indexExpr) : TermElabM (Option (ℕ × ℕ)) := do
  let a' ← indexToIdent a.2
  let b' ← indexToIdent b.2
  if a.1 < b.1 ∧ a'.getId = b'.getId then
    return some (a.1, b.1)
  else
    return none
/-- Bool which is true if an index is of the form `τ(i)`, that is, to be dualed. -/
def indexToDual (stx : Syntax) : Bool :=
  match stx with
  | `(indexExpr| τ($_)) => true
  | _ => false
/-!
## Tensor expressions
-/
/-- A syntax category for tensor expressions. -/
declare_syntax_cat tensorExpr

/-- The syntax for a tensor node. -/
syntax term "|" (ppSpace indexExpr)* : tensorExpr

/-- Equality. -/
syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr

/-- The syntax for the tensor product of two tensor expressions. -/
syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr

/-- The syntax for tensor addition. -/
syntax tensorExpr "+" tensorExpr : tensorExpr

/-- Allowing brackets to be used in a tensor expression. -/
syntax "(" tensorExpr ")" : tensorExpr

/-- Scalar multiplication for tensors. -/
syntax term "•ₜ" tensorExpr : tensorExpr

/-- Negation of a tensor tree. -/
syntax "-" tensorExpr : tensorExpr
namespace TensorNode
2024-10-08 05:53:16 +00:00
/-!
## For tensor nodes.

The operations are done in the following order:
- evaluation.
- dualization.
- contraction.

We also want to ensure the number of indices is correct.
-/
/-- The indices of a tensor node. Before contraction, and evaluation.

Throws an error if `stx` is not a tensor node `T | i₁ i₂ …`. -/
def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) =>
    args.toList.mapM fun arg => do
      match arg with
      | `(indexExpr|$t:indexExpr) => pure t
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}"
/-- Uses the structure of the tensor to get the number of indices.

The term `stx` is elaborated and its type converted to a string with `toString`. If that
string contains `k - 1 > 0` occurrences of `CategoryTheory.MonoidalCategoryStruct.tensorObj`,
the result is `k`.

Otherwise the type is inspected structurally. If it has the shape `_ (_ (_ c))` and the type
of `c` is a function type whose domain carries a natural-number literal `n` (presumably
`Fin n`), the result is `n`. If `c`'s type has another shape an error is thrown. Any type not
of the shape `_ (_ (_ c))` gives `1`. -/
def getNoIndicesExact (stx : Syntax) : TermElabM ℕ := do
  let expr ← elabTerm stx none
  let type ← inferType expr
  let strType := toString type
  let numPieces :=
    (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
  match numPieces with
  | 1 =>
    match type with
    | Expr.app _ (Expr.app _ (Expr.app _ c)) =>
      let typeC ← inferType c
      match typeC with
      | Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
        return n
      | _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
    | _ => return 1
  | k => return k
/-- Parses `str` as a `term` in the current environment and elaborates it (with no expected
type), returning the resulting expression.

Throws an error if `str` does not parse. -/
def stringToType (str : String) : TermElabM Expr := do
  match Parser.runParserCategory (← getEnv) `term str with
  | .ok parsed => elabTerm parsed none
  | .error _ => throwError "Could not create type from string (stringToType). "
/-- Parses `str` as a `term` in the current environment and returns the resulting syntax,
without elaborating it.

Throws an error if `str` does not parse. -/
def stringToTerm (str : String) : TermElabM Term := do
  match Parser.runParserCategory (← getEnv) `term str with
  | .error _ => throwError "Could not create term from string (stringToTerm). "
  -- The parser was run in the `term` category, so the result is already term syntax.
  | .ok stx => return ⟨stx⟩
/-- Specific types of tensors which appear which we want to elaborate in specific ways.

Each entry pairs a type, written as a string to be parsed and elaborated by `stringToType`,
with a function turning a term of that type into the syntax of the corresponding tensor-tree
node. Every string must parse, since `termNodeSyntax` elaborates all of them. -/
def specialTypes : List (String × (Term → Term)) := [
  ("CoeSort.coe Lorentz.complexCo", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.down, T]),
  ("CoeSort.coe Lorentz.complexContr", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.up, T]),
  ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexCo).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.up, mkIdent ``Fermion.Color.down, T]),
  ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexContr).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.up, mkIdent ``Fermion.Color.up, T]),
  ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexCo).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.down, mkIdent ``Fermion.Color.down, T]),
  ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexContr).V", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[
      mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.down,
      mkIdent ``Fermion.Color.up, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.down,
      mkIdent ``Fermion.Color.down, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded",
    fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
      mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.up,
      mkIdent ``Fermion.Color.upL,
      mkIdent ``Fermion.Color.upR, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.upL,
      mkIdent ``Fermion.Color.upL, T]),
  ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
      mkIdent ``Fermion.complexLorentzTensor,
      mkIdent ``Fermion.Color.upR,
      mkIdent ``Fermion.Color.upR, T])]
/-- The syntax associated with a terminal node of a tensor tree.

The type of `T` is compared, up to definitional equality, with each type in `specialTypes`.
If exactly one entry matches, that entry's syntax constructor is applied to `T`.

Otherwise the type of `T` is inspected:
- If it has the shape `_ (_ (_ c))` and the type of `c` is a function type whose domain
  carries a natural-number literal, the result is `TensorTree.tensorNode T`. If the type of
  `c` has another shape, an error is thrown.
- If it does not have the shape `_ (_ (_ c))`, the result is `TensorTree.vecNode T`.

Errors from parsing or elaborating any `specialTypes` string are propagated. -/
def termNodeSyntax (T : Term) : TermElabM Term := do
  let expr ← elabTerm T none
  let type ← inferType expr
  let defEqList ← specialTypes.filterM fun x => do
    let type' ← stringToType x.1
    isDefEq type type'
  match defEqList with
  | [(_, f)] =>
    return f T
  | _ =>
    match type with
    | Expr.app _ (Expr.app _ (Expr.app _ c)) =>
      let typeC ← inferType c
      match typeC with
      | Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal _))) _)) _ _ =>
        return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
      | _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
    | _ => return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T]
/-- Adjusts a list `List ℕ` by subtracting from each natural number the number
of elements before it in the list which are less than itself. This is used
to form a list of pairs which can be used for evaluating indices.

For example `[1, 0, 2]` becomes `[1, 0, 0]`. -/
def evalAdjustPos (l : List ℕ) : List ℕ :=
  -- The element at position `i` only looks at the prefix `l.take i` of earlier elements.
  l.enum.map fun (i, x) => x - (l.take i).countP (fun y => y < x)
/-- The positions in getIndicesNode which get evaluated, and the value they take.

Each pair is `(position, value)` for a numeric index. The positions are adjusted with
`evalAdjustPos`, so each one is reduced by the number of earlier evaluated positions that are
smaller than it. -/
def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let ind ← getIndices stx
  let evals := ind.enum.filter (fun x => indexExprIsNum x.2)
  let evalVals ← evals.mapM (fun x => indexToNum x.2)
  let pos := evalAdjustPos (evals.map (fun x => x.1))
  return List.zip pos evalVals
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.eval` to the given term.

The pairs are applied from left to right. Each pair `(i, v)` wraps the term built so far as
`TensorTree.eval i v _`. -/
def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
  l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
    #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
/-- The pairs of positions in getIndicesNode which get contracted.

Numeric (evaluated) indices are dropped first, so positions refer to the remaining indices.
A pair `(i, j)` with `i < j` is produced whenever the indices at `i` and `j` have the same
identifier name (see `indexPosEq`). Throws an error if some identifier occurs more than
twice. -/
def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let ind ← getIndices stx
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  let indEnum := indFilt.enum
  let pairs := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
  let filt ← pairs.filterMapM (fun x => indexPosEq x.1 x.2)
  -- A repeated first or second component means an index takes part in two contractions.
  if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
    throwError "Too many contractions"
  return filt
/-- The list of indices after contraction or evaluation.

Numeric (evaluated) indices are removed. Every index whose syntax occurs more than once
among the remaining indices (i.e. is contracted) is removed too.

NOTE(review): the comparison here is syntactic (`BEq` on syntax), so `i` and `τ(i)` count as
different indices, whereas `indexPosEq` treats them as the same identifier — confirm this is
intended. -/
def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  return indFilt.filter (fun x => indFilt.count x ≤ 1)
end TensorNode
2024-10-21 06:53:58 +00:00
/-- Takes a list and puts consecutive elements into pairs.
e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)].

If the list has odd length, the final element is paired with `0`. -/
def toPairs (l : List ℕ) : List (ℕ × ℕ) :=
  match l with
  | x1 :: x2 :: xs => (x1, x2) :: toPairs xs
  | [] => []
  | [x] => [(x, 0)]
/-- Adjusts a list `List (ℕ × ℕ)` by subtracting from each natural number the number
of elements before it in the list which are less than itself. This is used
to form a list of pairs which can be used for contracting indices.

The pairs are flattened to `[p₁.1, p₁.2, p₂.1, …]`, adjusted with
`TensorNode.evalAdjustPos` (which performs exactly this subtraction), and regrouped into
pairs with `toPairs`. -/
def contrListAdjust (l : List (ℕ × ℕ)) : List (ℕ × ℕ) :=
  toPairs (TensorNode.evalAdjustPos (l.bind fun p => [p.1, p.2]))
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term.

The positions are first adjusted with `contrListAdjust`. Each resulting pair `(i, j)` then
wraps the term built so far as `TensorTree.contr i j rfl _`, from left to right. -/
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
  (contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
    #[Syntax.mkNumLit (toString x0),
      Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
namespace ProdNode
/-!
## For product nodes.

For a product node we can take the tensor product, and then contract the indices.
-/
/-- Gets the indices associated with a product node.

- For a tensor node, these are its indices after `TensorNode.withoutContr`.
- For `a ⊗ b`, they are the indices of `a` followed by those of `b`, each reduced with
  `TensorNode.withoutContr`.
- Brackets are transparent.

Throws an error on any other syntax. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) =>
    TensorNode.withoutContr (← TensorNode.getIndices stx)
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let indicesA ← TensorNode.withoutContr (← getIndices a)
    let indicesB ← TensorNode.withoutContr (← getIndices b)
    return indicesA ++ indicesB
  | `(tensorExpr| ($a:tensorExpr)) =>
    getIndices a
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
/-- The pairs of positions, among the indices of a product node (as given by `getIndices`),
which get contracted.

Numeric indices are dropped first. A pair `(i, j)` with `i < j` is produced whenever the
indices at `i` and `j` have the same identifier name (see `indexPosEq`). Throws an error if
some identifier occurs more than twice. -/
def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let ind ← getIndices stx
  let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
  let indEnum := indFilt.enum
  let pairs := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
  let filt ← pairs.filterMapM (fun x => indexPosEq x.1 x.2)
  -- A repeated first or second component means an index takes part in two contractions.
  if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
    throwError "Too many contractions"
  return filt
/-- The list of indices of a product node after contraction.

Every index whose syntax appears more than once among the non-numeric indices is removed.
Numeric indices are kept as they are. -/
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let all ← getIndices stx
  let named := all.filter fun x => ¬ indexExprIsNum x
  return all.filter fun x => named.count x ≤ 1
2024-10-08 11:55:06 +00:00
/-- Builds the syntax `TensorTree.prod T1 T2` for the tensor product of two tensor trees. -/
def prodSyntax (T1 T2 : Term) : Term :=
  let prodFn := mkIdent ``TensorTree.prod
  Syntax.mkApp prodFn #[T1, T2]
end ProdNode
/-!
## Permutation constructions
-/
/-- Given two lists of indices returns the `List ℕ` representing how one list
permutes into the other.

For each index of `l2`, in order, the result holds the position in `l1` of the first index
with the same identifier name. Indices of `l2` with no match in `l1` are skipped. -/
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List ℕ) := do
  let l1' ← l1.mapM (fun x => indexToIdent x)
  let l2' ← l2.mapM (fun x => indexToIdent x)
  let l1enum := l1'.enum
  let l2'' := l2'.filterMap
    (fun x => l1enum.find? (fun y => y.2.getId = x.getId))
  return l2''.map fun x => x.1
/-- Takes two maps `Fin n → Fin m` and `Fin m → Fin n` and returns the equivalence they form.

The hypotheses `h` and `h'` state that the two maps are mutually inverse. By default they are
discharged by `decide`. -/
def finMapToEquiv (f1 : Fin n → Fin m) (f2 : Fin m → Fin n)
    (h : ∀ x, f1 (f2 x) = x := by decide)
    (h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin m where
  toFun := f1
  invFun := f2
  left_inv := h'
  right_inv := h
/-- Given two lists of indices returns the permutation between them based on `finMapToEquiv`.

The positions computed by `getPermutation` in both directions are rendered as `![…]` vector
literals and passed to `finMapToEquiv`. Its proof obligations are left to its default `decide`
arguments when the resulting syntax is elaborated. -/
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
  -- Renders a list of naturals as the string `![a, b, …]`.
  let vecLit (l : List ℕ) : String := "![" ++ String.intercalate ", " (l.map toString) ++ "]"
  let lPerm ← getPermutation l1 l2
  let l2Perm ← getPermutation l2 l1
  let P1 ← TensorNode.stringToTerm (vecLit lPerm)
  let P2 ← TensorNode.stringToTerm (vecLit l2Perm)
  return Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
namespace negNode
/-- The syntax associated with the negation of a tensor tree, `TensorTree.neg T1`. -/
def negSyntax (T1 : Term) : Term :=
  Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1]
end negNode
/-- Returns the full list of indices after contraction. TODO: Include evaluation.

- Tensor nodes use `TensorNode.withoutContr`; products use `ProdNode.withoutContr`.
- Brackets, negation and scalar multiplication do not change the indices.
- For an addition, the indices of the left-hand side are returned.

Throws an error on any other syntax. -/
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) =>
    TensorNode.withoutContr (← TensorNode.getIndices stx)
  | `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) =>
    ProdNode.withoutContr stx
  | `(tensorExpr| ($a:tensorExpr)) =>
    getIndicesFull a
  | `(tensorExpr| -$a:tensorExpr) =>
    getIndicesFull a
  | `(tensorExpr| $_:term •ₜ $a) =>
    getIndicesFull a
  | `(tensorExpr| $a:tensorExpr + $_:tensorExpr) =>
    getIndicesFull a
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesFull: {stx}"
namespace SMul
/-- The syntax associated with the scalar multiplication of tensors, `TensorTree.smul c T`. -/
def smulSyntax (c T : Term) : Term :=
  Syntax.mkApp (mkIdent ``TensorTree.smul) #[c, T]
end SMul
namespace Add
/-- Gets the indices associated with the LHS of an addition, computed by `getIndicesFull`.
Throws an error if `stx` is not an addition. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) :=
  match stx with
  | `(tensorExpr| $lhs:tensorExpr + $_:tensorExpr) => getIndicesFull lhs
  | _ => throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
/-- Gets the indices associated with the RHS of an addition, computed by `getIndicesFull`.
Throws an error if `stx` is not an addition. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) :=
  match stx with
  | `(tensorExpr| $_:tensorExpr + $rhs:tensorExpr) => getIndicesFull rhs
  | _ => throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}"
/-- The syntax for an addition of tensor trees.

The right-hand tree `T2` is wrapped as `TensorTree.perm (OverColor.equivToHomEq permSyntax) T2`
(presumably so that its indices line up with those of `T1`). The result is
`TensorTree.add T1 (TensorTree.perm _ T2)`. -/
def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
  let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
  return Syntax.mkApp (mkIdent ``TensorTree.add) #[T1, RHS]
end Add
namespace Equality
/-!
## For equality.
-/
/-- Gets the indices associated with the LHS of an equality, computed by `getIndicesFull`.
Throws an error if `stx` is not an equality. -/
def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $a:tensorExpr = $_:tensorExpr) =>
    getIndicesFull a
  | _ =>
    throwError "Unsupported tensor expression syntax in Equality.getIndicesLeft: {stx}"
/-- Gets the indices associated with the RHS of an equality, computed by `getIndicesFull`.
Throws an error if `stx` is not an equality. -/
def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:tensorExpr = $a:tensorExpr) =>
    getIndicesFull a
  | _ =>
    throwError "Unsupported tensor expression syntax in Equality.getIndicesRight: {stx}"
/-- The syntax for an equality of tensor trees:
`TensorTree.tensor T1 = TensorTree.tensor (TensorTree.perm (OverColor.equivToHomEq P) T2)`,
where `P` is `permSyntax`. -/
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let tensorOf (t : Term) : Term := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[t]
  let permuted := Syntax.mkApp (mkIdent ``TensorTree.perm)
    #[Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax], T2]
  return Syntax.mkApp (mkIdent ``Eq) #[tensorOf T1, tensorOf permuted]
end Equality
/-- Creates the syntax of the tensor tree associated with a tensor expression.

- A tensor node `T | i₁ … iₖ` first checks that the number of indices agrees with
  `TensorNode.getNoIndicesExact T`. It then applies the evaluations and contractions to the
  terminal-node syntax.
- A product `a ⊗ b` takes the product of the two trees and then contracts repeated indices.
- Brackets are transparent; negation and scalar multiplication wrap the subtree.
- For `a + b` and `a = b`, the right-hand tree is wrapped in a permutation computed from the
  indices of the two sides.

Throws an error on unsupported syntax. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
  match stx with
  | `(tensorExpr| $T:term | $[$args]*) =>
    let indices ← TensorNode.getIndices stx
    let rawIndex ← TensorNode.getNoIndicesExact T
    if indices.length ≠ rawIndex then
      throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
    let nodeStx ← TensorNode.termNodeSyntax T
    let evalStx := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) nodeStx
    return contrSyntax (← TensorNode.getContrPos stx) evalStx
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let prodStx := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b)
    return contrSyntax (← ProdNode.getContrPos stx) prodStx
  | `(tensorExpr| ($a:tensorExpr)) =>
    syntaxFull a
  | `(tensorExpr| -$a:tensorExpr) =>
    return negNode.negSyntax (← syntaxFull a)
  | `(tensorExpr| $c:term •ₜ $a:tensorExpr) =>
    return SMul.smulSyntax c (← syntaxFull a)
  | `(tensorExpr| $a + $b) =>
    let indicesLeft ← Add.getIndicesLeft stx
    let indicesRight ← Add.getIndicesRight stx
    let permSyntax ← getPermutationSyntax indicesLeft indicesRight
    Add.addSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
  | `(tensorExpr| $a:tensorExpr = $b:tensorExpr) =>
    let indicesLeft ← Equality.getIndicesLeft stx
    let indicesRight ← Equality.getIndicesRight stx
    let permSyntax ← getPermutationSyntax indicesLeft indicesRight
    Equality.equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
  | _ =>
    throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
/-- An elaborator for tensor nodes. This is to be generalized.

Builds the tensor-tree syntax of `stx` with `syntaxFull` and elaborates it with no expected
type. -/
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
  elabTerm (← syntaxFull stx) none
2024-10-08 07:52:55 +00:00
/-- Syntax turning a tensor expression into a term. -/
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term
-- Elaborates `{e}ᵀ` by building the tensor-tree syntax for `e` and elaborating it.
elab_rules (kind:=tensorExprSyntax) : term
  | `(term| {$e:tensorExpr}ᵀ) => elaborateTensorNode e
2024-10-08 05:53:16 +00:00
end TensorTree