604 lines
24 KiB
Text
604 lines
24 KiB
Text
/-
|
||
Copyright (c) 2024 Joseph Tooby-Smith. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Joseph Tooby-Smith
|
||
-/
|
||
import HepLean.Lorentz.ComplexTensor.Basic
|
||
/-!
|
||
|
||
# Elaboration of tensor trees
|
||
|
||
- Syntax in Lean allows us to represent tensor expressions in a way close to what we expect to
|
||
see on pen-and-paper.
|
||
- The elaborator turns this syntax into a tensor tree.
|
||
|
||
## Examples
|
||
|
||
- Suppose `T` and `T'` are tensors `S.F (OverColor.mk ![c1, c2])`.
|
||
- `{T | μ ν}ᵀ` is `tensorNode T`.
|
||
- We can also write e.g. `{T | μ ν}ᵀ.tensor` to get the tensor itself.
|
||
- `{- T | μ ν}ᵀ` is `neg (tensorNode T)`.
|
||
- `{T | 0 ν}ᵀ` is `eval 0 0 (tensorNode T)`.
|
||
- `{T | μ ν + T' | μ ν}ᵀ` is `add (tensorNode T) (perm _ (tensorNode T'))`, where
|
||
here `_` will be the identity permutation so does nothing.
|
||
- `{T | μ ν = T' | μ ν}ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T')).tensor`.
|
||
- If `a ∈ S.k` then `{a •ₜ T | μ ν}ᵀ` is `smul a (tensorNode T)`.
|
||
- If `g ∈ S.G` then `{g •ₐ T | μ ν}ᵀ` is `action g (tensorNode T)`.
|
||
- Suppose `T2` is a tensor `S.F (OverColor.mk ![c3])`.
|
||
Then `{T | μ ν ⊗ T2 | σ}ᵀ` is `prod (tensorNode T) (tensorNode T2)`.
|
||
- If `T3` is a tensor `S.F (OverColor.mk ![S.τ c1, S.τ c2])`, then
|
||
`{T | μ ν ⊗ T3 | μ σ}ᵀ` is `contr 0 1 _ (prod (tensorNode T) (tensorNode T3))`.
|
||
`{T | μ ν ⊗ T3 | μ ν }ᵀ` is
|
||
`contr 0 0 _ (contr 0 1 _ (prod (tensorNode T) (tensorNode T3)))`.
|
||
- If `T4` is a tensor `S.F (OverColor.mk ![c2, c1])` then
|
||
`{T | μ ν + T4 | ν μ }ᵀ` is `add (tensorNode T) (perm _ (tensorNode T4))` where `_`
|
||
is the permutation of the two indices of `T4`.
|
||
`{T | μ ν = T4 | ν μ }ᵀ` is `(tensorNode T).tensor = (perm _ (tensorNode T4)).tensor`, where `_` is the
|
||
permutation of the two indices of `T4`.
|
||
|
||
## Comments
|
||
|
||
- In all of these expressions `μ`, `ν` etc. are free. It does not matter what they are called,
|
||
Lean will elaborate them in the same way. I.e. `{T | μ ν ⊗ T3 | μ ν }ᵀ` is exactly the same
|
||
to Lean as `{T | α β ⊗ T3 | α β }ᵀ`.
|
||
- Note that compared to ordinary index notation, we do not raise or lower the indices.
|
||
This is for two reasons: 1) It is difficult to make this general for all tensor species,
|
||
2) It is a redundancy in ordinary index notation, since the tensor `T` itself already tells you
|
||
this information.
|
||
|
||
-/
|
||
open Lean
|
||
open Lean.Elab.Term
|
||
|
||
open Lean
|
||
open Lean.Meta
|
||
open Lean.Elab
|
||
open Lean.Elab.Term
|
||
open Lean Meta Elab Tactic
|
||
open IndexNotation
|
||
open complexLorentzTensor
|
||
namespace TensorTree
|
||
|
||
/-!
|
||
|
||
## Indices
|
||
|
||
-/
|
||
|
||
/-- The syntax category `indexExpr` for the indices appearing in tensor expressions. -/
|
||
declare_syntax_cat indexExpr
|
||
|
||
/-- A basic index is an identifier, e.g. `μ`. -/
|
||
syntax ident : indexExpr
|
||
|
||
/-- An index can be a numeric literal, which is used to evaluate the tensor at that index. -/
|
||
syntax num : indexExpr
|
||
|
||
/-- Notation `τ(i)` to describe the jiggle of a tensor index, i.e. an index to be dualed. -/
|
||
syntax "τ(" ident ")" : indexExpr
|
||
|
||
/-- Returns `true` exactly when `stx` is an `indexExpr` made of a numeric literal,
and `false` for every other piece of syntax. -/
def indexExprIsNum (stx : Syntax) : Bool :=
  if let `(indexExpr|$_:num) := stx then true else false
|
||
|
||
/-- For an index which is a numeric literal, returns the underlying natural number.

Throws if `stx` is not a numeric `indexExpr`, or if the literal cannot be read as a
natural number. -/
def indexToNum (stx : Syntax) : TermElabM Nat := do
  let `(indexExpr|$a:num) := stx
    | throwError "Unsupported tensor expression syntax in indexToNum: {stx}"
  let some n := a.raw.isNatLit?
    | throwError "Expected a natural number literal."
  return n
|
||
|
||
/-- For an index which is not a numeric literal, returns the identifier it is built from.
Both a bare identifier `i` and the form `τ(i)` yield `i`.

Throws for any other syntax (in particular for numeric indices). -/
def indexToIdent (stx : Syntax) : TermElabM Ident := do
  if let `(indexExpr|$a:ident) := stx then
    return a
  if let `(indexExpr| τ($a:ident)) := stx then
    return a
  throwError "Unsupported tensor expression syntax in indexToIdent: {stx}"
|
||
|
||
/-- Compares two position-labelled indices ``a b : ℕ × TSyntax `indexExpr``.
Returns `some (a.1, b.1)` when `a.1 < b.1` and both indices are built from an identifier
with the same name, and `none` otherwise.

Throws (via `indexToIdent`) if either index is not built from an identifier. -/
def indexPosEq (a b : ℕ × TSyntax `indexExpr) : TermElabM (Option (ℕ × ℕ)) := do
  let nameA := (← indexToIdent a.2).getId
  let nameB := (← indexToIdent b.2).getId
  if a.1 < b.1 ∧ nameA = nameB then
    pure (some (a.1, b.1))
  else
    pure none
|
||
|
||
/-- Returns `true` exactly when the index has the form `τ(i)`, that is, when it is marked
to be dualed. -/
def indexToDual (stx : Syntax) : Bool :=
  if let `(indexExpr| τ($_)) := stx then true else false
|
||
|
||
/-!
|
||
|
||
## Tensor expressions
|
||
|
||
-/
|
||
|
||
/-- The syntax category `tensorExpr` for tensor expressions. -/
|
||
declare_syntax_cat tensorExpr
|
||
|
||
/-- The syntax for a tensor node: a term followed by `|` and a list of indices. -/
|
||
syntax term "|" (ppSpace indexExpr)* : tensorExpr
|
||
|
||
/-- The syntax for an equality of two tensor expressions. -/
|
||
syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr
|
||
|
||
/-- The syntax for the tensor product of two tensor expressions. -/
|
||
syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr
|
||
|
||
/-- The syntax for the addition of two tensor expressions. -/
|
||
syntax tensorExpr "+" tensorExpr : tensorExpr
|
||
|
||
/-- Allows brackets to be used in a tensor expression. -/
|
||
syntax "(" tensorExpr ")" : tensorExpr
|
||
|
||
/-- The syntax for scalar multiplication of a tensor expression by a term. -/
|
||
syntax term "•ₜ" tensorExpr : tensorExpr
|
||
|
||
/-- The syntax for the group action of a term on a tensor expression. -/
|
||
syntax term "•ₐ" tensorExpr : tensorExpr
|
||
|
||
/-- The syntax for the negation of a tensor expression. -/
|
||
syntax "-" tensorExpr : tensorExpr
|
||
|
||
namespace TensorNode
|
||
|
||
/-!
|
||
|
||
## For tensor nodes.
|
||
|
||
The operations are done in the following order:
|
||
- evaluation.
|
||
- dualization.
|
||
- contraction.
|
||
|
||
We also want to ensure the number of indices is correct.
|
||
|
||
-/
|
||
|
||
/-- The indices written after the `|` of a tensor node, in order, before any contraction or
evaluation has been taken into account.

Throws if `stx` is not of the form `T | i₁ … iₙ`. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $_:term | $[$args]*) := stx
    | throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}"
  -- Every element of `args` already is an `indexExpr`, so no further matching is required.
  return args.toList
|
||
|
||
/-- Uses the elaborated type of the tensor `stx` to get the number of indices it carries.

- If the string form of the type mentions `CategoryTheory.MonoidalCategoryStruct.tensorObj`,
  the number of pieces obtained by splitting on that name (occurrences plus one) is returned.
- Otherwise, if the type has the shape `_ (_ (_ c))`, the type of `c` must be a function type
  whose domain is an application `_ a`, with `a` reducing under `whnf` to a natural-number
  literal; that literal is returned. Presumably `c : Fin n → _` is the colour map of the
  tensor — TODO confirm.
- Any other type is treated as carrying a single index.

Throws if `c` is found but the number of indices cannot be read off its type. -/
def getNoIndicesExact (stx : Syntax) : TermElabM ℕ := do
  let type ← inferType (← elabTerm stx none)
  let pieces :=
    (String.splitOn (toString type) "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
  -- `splitOn` returns a single piece exactly when the marker does not occur in the type.
  if pieces ≠ 1 then
    return pieces
  let .app _ (.app _ (.app _ c)) := type
    | return 1
  -- Built once and kept on a single line so that no raw line break ends up in the message.
  let failure := s!"Could not extract number of indices from tensor {stx} (getNoIndicesExact). "
  let .forallE _ (.app _ a) _ _ ← inferType c
    | throwError failure
  let .lit (.natVal n) ← whnf a
    | throwError failure
  return n
|
||
|
||
/-- Parses `str` as a term and elaborates it (with no expected type), returning the resulting
expression. Returns `none` when `str` does not parse; elaboration errors are not caught. -/
def stringToType (str : String) : TermElabM (Option Expr) := do
  match Parser.runParserCategory (← getEnv) `term str with
  | .ok parsed => return some (← elabTerm parsed none)
  | .error _ => return none
|
||
|
||
/-- Parses `str` as a term and returns the resulting syntax, without elaborating it.
Throws if `str` does not parse as a term. -/
def stringToTerm (str : String) : TermElabM Term := do
  let .ok parsed := Parser.runParserCategory (← getEnv) `term str
    | throwError "Could not create type from string (stringToTerm). "
  -- The parser was run in the `term` category, so the result can be used as a `Term`.
  return ⟨parsed⟩
|
||
|
||
/-- Specific types of tensors which we want to elaborate in specific ways.

Each entry pairs the string form of a type with a function which, given the syntax `T` of a
term of that type, builds the syntax of the corresponding node of the tensor tree for
`complexLorentzTensor`. -/
def specialTypes : List (String × (Term → Term)) :=
  let species : Term := mkIdent ``complexLorentzTensor
  let up : Term := mkIdent ``complexLorentzTensor.Color.up
  let down : Term := mkIdent ``complexLorentzTensor.Color.down
  let upL : Term := mkIdent ``complexLorentzTensor.Color.upL
  let upR : Term := mkIdent ``complexLorentzTensor.Color.upR
  let downL : Term := mkIdent ``complexLorentzTensor.Color.downL
  let downR : Term := mkIdent ``complexLorentzTensor.Color.downR
  -- Builders: each applies one node constructor to the species, the colours, and `T`.
  let vec (c : Term) : Term → Term := fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[species, c, T]
  let two (c1 c2 : Term) : Term → Term := fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[species, c1, c2, T]
  let constTwo (c1 c2 : Term) : Term → Term := fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[species, c1, c2, T]
  let constThree (c1 c2 c3 : Term) : Term → Term := fun T =>
    Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[species, c1, c2, c3, T]
  [ ("CoeSort.coe Lorentz.complexCo", vec down),
    ("CoeSort.coe Lorentz.complexContr", vec up),
    ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexCo).V", two up down),
    ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexContr).V", two up up),
    ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexCo).V", two down down),
    ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexContr).V", two down up),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo", constTwo down down),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Lorentz.complexContr", constTwo up up),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded",
      constThree up upL upR),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", constTwo upL upL),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altLeftHanded ⊗ Fermion.altLeftHanded",
      constTwo downL downL),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altRightHanded ⊗ Fermion.altRightHanded",
      constTwo downR downR),
    ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", constTwo upR upR) ]
|
||
|
||
/-- The syntax associated with a terminal node of a tensor tree.

The type of `T` is compared, up to definitional equality, with every type listed in
`specialTypes`. If exactly one entry matches, its builder is applied to `T`. Otherwise
`TensorTree.tensorNode` is used when the type has the shape `_ (_ (_ _))`, and
`TensorTree.vecNode` in all remaining cases. -/
def termNodeSyntax (T : Term) : TermElabM Term := do
  let type ← inferType (← elabTerm T none)
  let hits ← specialTypes.filterM fun (str, _) => do
    -- Entries whose string does not parse are skipped.
    let some candidate ← stringToType str
      | return false
    isDefEq type candidate
  if let [(_, build)] := hits then
    return build T
  match type with
  | .app _ (.app _ (.app _ _)) =>
    return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
  | _ =>
    return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T]
|
||
|
||
/-- Adjusts a list `List ℕ` by subtracting from each natural number the number of elements
before it in the list which are strictly less than it. For example `[0, 2, 3]` becomes
`[0, 1, 1]`. This is used to form the list of positions used for evaluating indices. -/
def evalAdjustPos (l : List ℕ) : List ℕ :=
  -- `acc.1` holds the elements seen so far, `acc.2` the adjusted values in reverse order.
  let step (acc : List ℕ × List ℕ) (x : ℕ) : List ℕ × List ℕ :=
    (x :: acc.1, (x - acc.1.countP (· < x)) :: acc.2)
  (l.foldl step ([], [])).2.reverse
|
||
|
||
/-- The positions of the indices of a tensor node which get evaluated (adjusted with
`evalAdjustPos`), each paired with the value the index is evaluated at. -/
partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let numeric := (← getIndices stx).enum.filter fun (_, idx) => indexExprIsNum idx
  let values ← numeric.mapM fun (_, idx) => indexToNum idx
  return (evalAdjustPos (numeric.map Prod.fst)).zip values
|
||
|
||
/-- Wraps the term `T` in one application of `TensorTree.eval` for each pair in `l`, working
from the head of the list outwards. The components of a pair become the first two arguments
of `TensorTree.eval`. -/
def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
  let wrap (inner : Term) (p : ℕ × ℕ) : Term :=
    Syntax.mkApp (mkIdent ``TensorTree.eval)
      #[Syntax.mkNumLit (toString p.1), Syntax.mkNumLit (toString p.2), inner]
  l.foldl wrap T
|
||
|
||
/-- The pairs of positions, among the non-numeric indices of a tensor node, which carry the
same name and therefore get contracted.

Throws if some position would take part in more than one contraction. -/
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let free := (← getIndices stx).filter fun idx => !indexExprIsNum idx
  let labelled := free.enum
  -- All ordered pairs of labelled indices; `indexPosEq` keeps those with equal names.
  let candidates := labelled.flatMap fun a => labelled.map fun b => (a, b)
  let pairs ← candidates.filterMapM fun (a, b) => indexPosEq a b
  unless (pairs.map Prod.fst).Nodup ∧ (pairs.map Prod.snd).Nodup do
    throwError "To many contractions"
  return pairs
|
||
|
||
/-- The list of indices left after contraction and evaluation: numeric indices are dropped,
and so is every index occurring more than once among the remaining ones. -/
def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do
  let free := ind.filter fun idx => !indexExprIsNum idx
  return free.filter fun idx => free.count idx ≤ 1
|
||
|
||
end TensorNode
|
||
|
||
/-- Takes a list and puts consecutive elements into pairs,
e.g. `[0, 1, 2, 3]` becomes `[(0, 1), (2, 3)]`.
A trailing unpaired element `x` of a list of odd length becomes `(x, 0)`. -/
def toPairs (l : List ℕ) : List (ℕ × ℕ) :=
  match l with
  | [] => []
  | [x] => [(x, 0)]
  | x1 :: x2 :: xs => (x1, x2) :: toPairs xs
|
||
|
||
/-- Adjusts a list `List (ℕ × ℕ)`: the pairs are flattened into a single list, from each
natural number is subtracted the number of elements before it in that list which are strictly
less than it, and the result is put back into pairs. This is used to form the list of pairs
used for contracting indices. -/
def contrListAdjust (l : List (ℕ × ℕ)) : List (ℕ × ℕ) :=
  let flat := l.flatMap fun p => [p.1, p.2]
  -- `acc.1` holds the elements seen so far, `acc.2` the adjusted values in reverse order.
  let step (acc : List ℕ × List ℕ) (x : ℕ) : List ℕ × List ℕ :=
    (x :: acc.1, (x - acc.1.countP (· < x)) :: acc.2)
  toPairs (flat.foldl step ([], [])).2.reverse
|
||
|
||
/-- Wraps the term `T` in one application of `TensorTree.contr` for each pair of
`contrListAdjust l`, working from the head of that list outwards. The components of a pair
become the first two arguments of `TensorTree.contr`, and `rfl` is passed as the third. -/
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
  let wrap (inner : Term) (p : ℕ × ℕ) : Term :=
    Syntax.mkApp (mkIdent ``TensorTree.contr)
      #[Syntax.mkNumLit (toString p.1), Syntax.mkNumLit (toString p.2), mkIdent ``rfl, inner]
  (contrListAdjust l).foldl wrap T
|
||
|
||
namespace ProdNode
|
||
|
||
/-!
|
||
|
||
## For product nodes.
|
||
|
||
For a product node we can take the tensor product, and then contract the indices.
|
||
|
||
-/
|
||
|
||
/-- Gets the indices associated with a product node: the indices of the left factor followed
by those of the right factor, each after removing its own contracted and numeric indices.
Brackets are looked through.

Throws for any expression which is not a tensor node, a product, or a bracketed expression. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) =>
    TensorNode.withoutContr (← TensorNode.getIndices stx)
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let left ← TensorNode.withoutContr (← getIndices a)
    let right ← TensorNode.withoutContr (← getIndices b)
    return left ++ right
  | `(tensorExpr| ($a:tensorExpr)) =>
    getIndices a
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||
|
||
/-- The pairs of positions, among the indices of a product node, which carry the same name
and therefore get contracted.

Throws if some position would take part in more than one contraction. -/
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
  let free := (← getIndices stx).filter fun idx => !indexExprIsNum idx
  let labelled := free.enum
  -- All ordered pairs of labelled indices; `indexPosEq` keeps those with equal names.
  let candidates := labelled.flatMap fun a => labelled.map fun b => (a, b)
  let pairs ← candidates.filterMapM fun (a, b) => indexPosEq a b
  unless (pairs.map Prod.fst).Nodup ∧ (pairs.map Prod.snd).Nodup do
    throwError "To many contractions"
  return pairs
|
||
|
||
/-- The list of indices of a product node after contraction: every index occurring more than
once among the non-numeric indices of the product is removed. -/
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let indices ← getIndices stx
  let free := indices.filter fun idx => !indexExprIsNum idx
  return indices.filter fun idx => free.count idx ≤ 1
|
||
|
||
/-- The syntax of the product of two tensor trees: `TensorTree.prod` applied to `T1` and
`T2`. -/
def prodSyntax (T1 T2 : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.prod
  Syntax.mkApp head #[T1, T2]
|
||
|
||
end ProdNode
|
||
|
||
/-!
|
||
|
||
## Permutation constructions
|
||
|
||
-/
|
||
/-- Given two lists of indices, returns the `List (ℕ)` representing how one list permutes
into the other: for each index of `l2`, in order, the position in `l1` of the first index
with the same name. Indices of `l2` with no counterpart in `l1` are skipped.

Throws (via `indexToIdent`) if some index is not built from an identifier. -/
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ)) := do
  let source := (← l1.mapM fun i => indexToIdent i).enum
  let target ← l2.mapM fun i => indexToIdent i
  return target.filterMap fun t =>
    (source.find? fun s => s.2.getId = t.getId).map Prod.fst
|
||
|
||
open HepLean.Fin
|
||
|
||
/-- Given two lists of indices, returns the syntax of the permutation between them:
`finMapToEquiv` applied to the vector literals `![…]` built from `getPermutation l1 l2` and
`getPermutation l2 l1`. -/
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
  -- Renders a list of positions as the source text of a vector literal, e.g. `![1, 0]`.
  let render (p : List ℕ) : String :=
    "![" ++ String.intercalate ", " (p.map toString) ++ "]"
  let forward ← getPermutation l1 l2
  let backward ← getPermutation l2 l1
  let P1 ← TensorNode.stringToTerm (render forward)
  let P2 ← TensorNode.stringToTerm (render backward)
  return Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
|
||
|
||
namespace negNode

/-- The syntax of the negation of a tensor tree: `TensorTree.neg` applied to `T1`. -/
def negSyntax (T1 : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.neg
  Syntax.mkApp head #[T1]

end negNode
|
||
|
||
/-- Returns the full list of indices of a tensor expression after contraction. Numeric
indices are dropped by `TensorNode.withoutContr`.

- For a tensor node or a product, contracted indices are removed.
- Brackets, negation, scalar multiplication `•ₜ` and the group action `•ₐ` leave the indices
  unchanged, so the indices of the underlying expression are returned.
- For a sum, the indices of the left-hand summand are returned.

Throws for any other expression (in particular for equalities). -/
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:term | $[$args]*) => do
    return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
  | `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
    return (← ProdNode.withoutContr stx)
  | `(tensorExpr| ($a:tensorExpr)) => do
    return (← getIndicesFull a)
  | `(tensorExpr| -$a:tensorExpr) => do
    return (← getIndicesFull a)
  | `(tensorExpr| $_:term •ₜ $a) => do
    return (← getIndicesFull a)
  -- `syntaxFull` accepts the group action, so its indices must be available here as well,
  -- otherwise a sum or an equality containing `g •ₐ T | …` cannot be elaborated.
  | `(tensorExpr| $_:term •ₐ $a) => do
    return (← getIndicesFull a)
  | `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    throwError "Unsupported tensor expression syntax in getIndicesFull: {stx}"
|
||
|
||
namespace SMul

/-- The syntax of the scalar multiplication of a tensor tree: `TensorTree.smul` applied to
the scalar `c` and the tree `T`. -/
def smulSyntax (c T : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.smul
  Syntax.mkApp head #[c, T]

end SMul
|
||
|
||
namespace Action

/-- The syntax of the group action on a tensor tree: `TensorTree.action` applied to the
group element `c` and the tree `T`. -/
def actionSyntax (c T : Term) : Term :=
  let head : Term := mkIdent ``TensorTree.action
  Syntax.mkApp head #[c, T]

end Action
|
||
|
||
namespace Add

/-- The indices, after contraction, of the left-hand summand of an addition.
Throws if `stx` is not an addition. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $a:tensorExpr + $_:tensorExpr) := stx
    | throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
  getIndicesFull a

/-- The indices, after contraction, of the right-hand summand of an addition.
Throws if `stx` is not an addition. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  let `(tensorExpr| $_:tensorExpr + $a:tensorExpr) := stx
    | throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}"
  getIndicesFull a

/-- The syntax for the addition of two tensor trees: `add` applied to `T1` and to `T2`
permuted by `TensorTree.perm` along `OverColor.equivToHomEq permSyntax`. -/
def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let hom := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
  let permuted := Syntax.mkApp (mkIdent ``TensorTree.perm) #[hom, T2]
  pure (Syntax.mkApp (mkIdent ``add) #[T1, permuted])

end Add
|
||
|
||
namespace Equality

/-!

## For equality.

-/

/-- Gets the indices, after contraction, associated with the LHS of an equality.
Throws if `stx` is not an equality. -/
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    -- The message names this function, following the convention of `Add.getIndicesLeft`.
    throwError "Unsupported tensor expression syntax in Equality.getIndicesLeft: {stx}"

/-- Gets the indices, after contraction, associated with the RHS of an equality.
Throws if `stx` is not an equality. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
  match stx with
  | `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
    return (← getIndicesFull a)
  | _ =>
    -- The message names this function, following the convention of `Add.getIndicesRight`.
    throwError "Unsupported tensor expression syntax in Equality.getIndicesRight: {stx}"

/-- The syntax for an equality of tensor trees: `Eq` applied to the tensor of `T1` and to the
tensor of `T2` permuted by `TensorTree.perm` along `OverColor.equivToHomEq permSyntax`. -/
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
  let X1 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1]
  let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
  let X2' := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
  let X2 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[X2']
  return Syntax.mkApp (mkIdent ``Eq) #[X1, X2]

end Equality
|
||
|
||
/-- Creates the term syntax of the tensor tree described by a tensor expression.

- A tensor node is checked to carry the expected number of indices, then turned into a
  terminal node to which evaluations and then contractions are applied.
- A product is the product of its factors with the contractions between them applied.
- Brackets are looked through; negation, `•ₜ` and `•ₐ` wrap the tree of their argument.
- For a sum or an equality the right-hand side is permuted to match the left-hand side.

Throws for any other syntax, and whenever one of the helpers used throws. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
  match stx with
  | `(tensorExpr| $T:term | $[$args]*) =>
    let indices ← TensorNode.getIndices stx
    let rawIndex ← TensorNode.getNoIndicesExact T
    unless indices.length = rawIndex do
      throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
    let node ← TensorNode.termNodeSyntax T
    let evaluated := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) node
    return contrSyntax (← TensorNode.getContrPos stx) evaluated
  | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) =>
    let product := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b)
    return contrSyntax (← ProdNode.getContrPos stx) product
  | `(tensorExpr| ($a:tensorExpr)) =>
    syntaxFull a
  | `(tensorExpr| -$a:tensorExpr) =>
    return negNode.negSyntax (← syntaxFull a)
  | `(tensorExpr| $c:term •ₜ $a:tensorExpr) =>
    return SMul.smulSyntax c (← syntaxFull a)
  | `(tensorExpr| $c:term •ₐ $a:tensorExpr) =>
    return Action.actionSyntax c (← syntaxFull a)
  | `(tensorExpr| $a + $b) =>
    let lhsIndices ← Add.getIndicesLeft stx
    let rhsIndices ← Add.getIndicesRight stx
    let perm ← getPermutationSyntax lhsIndices rhsIndices
    Add.addSyntax perm (← syntaxFull a) (← syntaxFull b)
  | `(tensorExpr| $a:tensorExpr = $b:tensorExpr) =>
    let lhsIndices ← Equality.getIndicesLeft stx
    let rhsIndices ← Equality.getIndicesRight stx
    let perm ← getPermutationSyntax lhsIndices rhsIndices
    Equality.equalSyntax perm (← syntaxFull a) (← syntaxFull b)
  | _ =>
    throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||
|
||
/-- An elaborator for tensor expressions: builds the term syntax with `syntaxFull` and
elaborates it with no expected type. This is to be generalized. -/
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
  let term ← syntaxFull stx
  elabTerm term none
|
||
|
||
/-- Syntax turning a tensor expression into a term: `{e}ᵀ` for a tensor expression `e`. -/
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term

elab_rules (kind:=tensorExprSyntax) : term
  | `(term| {$e:tensorExpr}ᵀ) => elaborateTensorNode e
|
||
/-!
|
||
|
||
## Test cases
|
||
|
||
-/
|
||
|
||
/-
|
||
variable {S : TensorSpecies} {c : Fin (Nat.succ (Nat.succ 0)) → S.C} {t : S.F.obj (OverColor.mk c)}
|
||
|
||
#check {t | α β}ᵀ
|
||
-/
|
||
end TensorTree
|