refactor: Index notation

This commit is contained in:
jstoobysmith 2024-10-16 16:38:36 +00:00
parent d9f6760541
commit ec69deaff2
12 changed files with 299 additions and 865 deletions

View file

@ -6,6 +6,7 @@ Authors: Joseph Tooby-Smith
import HepLean.Tensors.Tree.Basic
import Lean.Elab.Term
import HepLean.Tensors.Tree.Dot
import HepLean.Tensors.ComplexLorentz.Basic
/-!
## Elaboration of tensor trees
@ -82,6 +83,7 @@ def indexToDual (stx : Syntax) : Bool :=
match stx with
| `(indexExpr| τ($_)) => true
| _ => false
/-!
## Tensor expressions
@ -136,15 +138,72 @@ partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) :=
def getNoIndicesExact (stx : Syntax) : TermElabM := do
let expr ← elabTerm stx none
let type ← inferType expr
match type with
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
let typeC ← inferType c
match typeC with
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
return n
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
| _ =>
throwError "Could not extract number of indices from tensor (getNoIndicesExact)."
let strType := toString type
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
match n with
| 1 =>
match type with
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
let typeC ← inferType c
match typeC with
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
return n
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
| _ => return 1
| k => return k
/-- The construction of an expression corresponding to the type of a given string once parsed. -/
def stringToType (str : String) : TermElabM Expr := do
let env ← getEnv
let stx := Parser.runParserCategory env `term str
match stx with
| Except.error _ => throwError "Could not create type from string (stringToType). "
| Except.ok stx => elabTerm stx none
/-- The syntax associated with a terminal node of a tensor tree. -/
def termNodeSyntax (T : Term) : TermElabM Term := do
let expr ← elabTerm T none
let type ← inferType expr
let strType := toString type
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
let const := (String.splitOn strType "Quiver.Hom").length
match n, const with
| 1, 1 =>
match type with
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
let typeC ← inferType c
match typeC with
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal _))) _)) _ _ =>
return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
| _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
| _ => return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T]
| 2, 1 =>
match ← isDefEq type (← stringToType "CoeSort.coe leftHanded ⊗ CoeSort.coe Lorentz.complexContr") with
| true => return Syntax.mkApp (mkIdent ``TensorTree.twoNodeE)
#[mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.upL, mkIdent ``Fermion.Color.up, T]
| _ => return Syntax.mkApp (mkIdent ``TensorTree.twoNode) #[T]
| 3, 1 => return Syntax.mkApp (mkIdent ``TensorTree.threeNode) #[T]
| 1, 2 => return Syntax.mkApp (mkIdent ``TensorTree.constVecNode) #[T]
| 2, 2 =>
match ← isDefEq type (← stringToType
"𝟙_ (Rep SL(2, )) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo") with
| true =>
println! "here"
return Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.down,
mkIdent ``Fermion.Color.down, T]
| _ => return Syntax.mkApp (mkIdent ``TensorTree.constTwoNode) #[T]
| 3, 2 =>
/- Specific types. -/
match ← isDefEq type (← stringToType
"𝟙_ (Rep SL(2, )) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded") with
| true =>
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.up,
mkIdent ``Fermion.Color.upL, mkIdent ``Fermion.Color.upR, T]
| _ =>
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNode) #[T]
| _, _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
partial def getEvalPos (stx : Syntax) : TermElabM (List ( × )) := do
@ -159,19 +218,6 @@ def evalSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
/-- The positions in getIndicesNode which get dualized. -/
partial def getDualPos (stx : Syntax) : TermElabM (List ) := do
let ind ← getIndices stx
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
let indEnum := indFilt.enum
let duals := indEnum.filter (fun x => indexToDual x.2)
return duals.map (fun x => x.1)
/-- For each element of `l : List ` applies `TensorTree.jiggle` to the given term. -/
def dualSyntax (l : List ) (T : Term) : Term :=
l.foldl (fun T' x => Syntax.mkApp (mkIdent ``TensorTree.jiggle)
#[Syntax.mkNumLit (toString x), T']) T
/-- The pairs of positions in getIndicesNode which get contracted. -/
partial def getContrPos (stx : Syntax) : TermElabM (List ( × )) := do
let ind ← getIndices stx
@ -192,7 +238,7 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
/-- For each element of `l : List ( × )` applies `TensorTree.contr` to the given term. -/
def contrSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), mkIdent ``rfl, T']) T
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), mkIdent `sorry, mkIdent ``rfl, T']) T
/-- Creates the syntax associated with a tensor node. -/
def syntaxFull (stx : Syntax) : TermElabM Term := do
@ -202,10 +248,9 @@ def syntaxFull (stx : Syntax) : TermElabM Term := do
let rawIndex ← getNoIndicesExact T
if indices.length ≠ rawIndex then
throwError "The number of indices does not match the tensor {T}."
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
let tensorNodeSyntax ← termNodeSyntax T
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
let contrSyntax := contrSyntax (← getContrPos stx) dualSyntax
let contrSyntax := contrSyntax (← getContrPos stx) evalSyntax
return contrSyntax
| _ =>
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"