refactor: Index notation
This commit is contained in:
parent
d9f6760541
commit
ec69deaff2
12 changed files with 299 additions and 865 deletions
|
@ -6,6 +6,7 @@ Authors: Joseph Tooby-Smith
|
|||
import HepLean.Tensors.Tree.Basic
|
||||
import Lean.Elab.Term
|
||||
import HepLean.Tensors.Tree.Dot
|
||||
import HepLean.Tensors.ComplexLorentz.Basic
|
||||
/-!
|
||||
|
||||
## Elaboration of tensor trees
|
||||
|
@ -82,6 +83,7 @@ def indexToDual (stx : Syntax) : Bool :=
|
|||
match stx with
|
||||
| `(indexExpr| τ($_)) => true
|
||||
| _ => false
|
||||
|
||||
/-!
|
||||
|
||||
## Tensor expressions
|
||||
|
@ -136,15 +138,72 @@ partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) :=
|
|||
def getNoIndicesExact (stx : Syntax) : TermElabM ℕ := do
|
||||
let expr ← elabTerm stx none
|
||||
let type ← inferType expr
|
||||
match type with
|
||||
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
|
||||
let typeC ← inferType c
|
||||
match typeC with
|
||||
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
|
||||
return n
|
||||
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
|
||||
| _ =>
|
||||
throwError "Could not extract number of indices from tensor (getNoIndicesExact)."
|
||||
let strType := toString type
|
||||
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
|
||||
match n with
|
||||
| 1 =>
|
||||
match type with
|
||||
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
|
||||
let typeC ← inferType c
|
||||
match typeC with
|
||||
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
|
||||
return n
|
||||
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
|
||||
| _ => return 1
|
||||
| k => return k
|
||||
|
||||
/-- The construction of an expression corresponding to the type of a given string once parsed. -/
|
||||
def stringToType (str : String) : TermElabM Expr := do
|
||||
let env ← getEnv
|
||||
let stx := Parser.runParserCategory env `term str
|
||||
match stx with
|
||||
| Except.error _ => throwError "Could not create type from string (stringToType). "
|
||||
| Except.ok stx => elabTerm stx none
|
||||
|
||||
/-- The syntax associated with a terminal node of a tensor tree. -/
|
||||
def termNodeSyntax (T : Term) : TermElabM Term := do
|
||||
let expr ← elabTerm T none
|
||||
let type ← inferType expr
|
||||
let strType := toString type
|
||||
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
|
||||
let const := (String.splitOn strType "Quiver.Hom").length
|
||||
match n, const with
|
||||
| 1, 1 =>
|
||||
match type with
|
||||
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
|
||||
let typeC ← inferType c
|
||||
match typeC with
|
||||
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal _))) _)) _ _ =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
|
||||
| _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
|
||||
| _ => return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T]
|
||||
| 2, 1 =>
|
||||
match ← isDefEq type (← stringToType "CoeSort.coe leftHanded ⊗ CoeSort.coe Lorentz.complexContr") with
|
||||
| true => return Syntax.mkApp (mkIdent ``TensorTree.twoNodeE)
|
||||
#[mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.upL, mkIdent ``Fermion.Color.up, T]
|
||||
| _ => return Syntax.mkApp (mkIdent ``TensorTree.twoNode) #[T]
|
||||
| 3, 1 => return Syntax.mkApp (mkIdent ``TensorTree.threeNode) #[T]
|
||||
| 1, 2 => return Syntax.mkApp (mkIdent ``TensorTree.constVecNode) #[T]
|
||||
| 2, 2 =>
|
||||
match ← isDefEq type (← stringToType
|
||||
"𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo") with
|
||||
| true =>
|
||||
println! "here"
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.down,
|
||||
mkIdent ``Fermion.Color.down, T]
|
||||
| _ => return Syntax.mkApp (mkIdent ``TensorTree.constTwoNode) #[T]
|
||||
| 3, 2 =>
|
||||
/- Specific types. -/
|
||||
match ← isDefEq type (← stringToType
|
||||
"𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded") with
|
||||
| true =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.up,
|
||||
mkIdent ``Fermion.Color.upL, mkIdent ``Fermion.Color.upR, T]
|
||||
| _ =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNode) #[T]
|
||||
| _, _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
|
||||
|
||||
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
|
||||
partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
|
@ -159,19 +218,6 @@ def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T
|
||||
|
||||
/-- The positions in getIndicesNode which get dualized. -/
|
||||
partial def getDualPos (stx : Syntax) : TermElabM (List ℕ) := do
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let duals := indEnum.filter (fun x => indexToDual x.2)
|
||||
return duals.map (fun x => x.1)
|
||||
|
||||
/-- For each element of `l : List ℕ` applies `TensorTree.jiggle` to the given term. -/
|
||||
def dualSyntax (l : List ℕ) (T : Term) : Term :=
|
||||
l.foldl (fun T' x => Syntax.mkApp (mkIdent ``TensorTree.jiggle)
|
||||
#[Syntax.mkNumLit (toString x), T']) T
|
||||
|
||||
/-- The pairs of positions in getIndicesNode which get contracted. -/
|
||||
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndices stx
|
||||
|
@ -192,7 +238,7 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
|||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), mkIdent ``rfl, T']) T
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), mkIdent `sorry, mkIdent ``rfl, T']) T
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
|
@ -202,10 +248,9 @@ def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
let rawIndex ← getNoIndicesExact T
|
||||
if indices.length ≠ rawIndex then
|
||||
throwError "The number of indices does not match the tensor {T}."
|
||||
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
|
||||
let tensorNodeSyntax ← termNodeSyntax T
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) dualSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) evalSyntax
|
||||
return contrSyntax
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue