feat: Get evaluation working.

This commit is contained in:
jstoobysmith 2024-10-22 06:42:06 +00:00
parent e180d4cca9
commit 8fa8a51367
4 changed files with 56 additions and 13 deletions

View file

@ -96,8 +96,11 @@ declare_syntax_cat tensorExpr
/-- The syntax for a tensor node. -/
syntax term "|" (ppSpace indexExpr)* : tensorExpr
/-- Equality. -/
syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr
/-- The syntax for tensor prod two tensor nodes. -/
syntax tensorExpr "⊗" tensorExpr : tensorExpr
syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr
/-- The syntax for tensor addition. -/
syntax tensorExpr "+" tensorExpr : tensorExpr
@ -111,9 +114,6 @@ syntax term "•" tensorExpr : tensorExpr
/-- Negation of a tensor tree. -/
syntax "-" tensorExpr : tensorExpr
/-- Equality. -/
syntax tensorExpr "=" tensorExpr : tensorExpr
namespace TensorNode
/-!
@ -129,7 +129,7 @@ We also want to ensure the number of indices is correct.
-/
/-- The indices of a tensor node. Before contraction, dualisation, and evaluation. -/
/-- The indices of a tensor node. Before contraction, and evaluation. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
@ -183,6 +183,16 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
let strType := toString type
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
let const := (String.splitOn strType "Quiver.Hom").length
match ← isDefEq type (← stringToType "CoeSort.coe Lorentz.complexCo") with
| true =>
return Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
mkIdent ``Fermion.Color.down, T]
| _ =>
match ← isDefEq type (← stringToType "CoeSort.coe Lorentz.complexContr") with
| true =>
return Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
mkIdent ``Fermion.Color.up, T]
| _ =>
match n, const with
| 1, 1 =>
match type with
@ -235,13 +245,24 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNode) #[T]
| _, _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
/-- Adjusts a list `List ` by subtracting from each natrual number the number
of elements before it in the list which are less then itself. This is used
to form a list of pairs which can be used for evaluating indices. -/
def evalAdjustPos (l : List ) : List :=
let l' := List.mapAccumr
(fun x (prev : List ) =>
let e := prev.countP (fun y => y < x)
(x :: prev, x - e)) l.reverse []
l'.2.reverse
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
partial def getEvalPos (stx : Syntax) : TermElabM (List ( × )) := do
let ind ← getIndices stx
let indEnum := ind.enum
let evals := indEnum.filter (fun x => indexExprIsNum x.2)
let evals2 ← (evals.mapM (fun x => indexToNum x.2))
return List.zip (evals.map (fun x => x.1)) evals2
let pos := evalAdjustPos (evals.map (fun x => x.1))
return List.zip pos evals2
/-- For each element of `l : List ( × )` applies `TensorTree.eval` to the given term. -/
def evalSyntax (l : List ( × )) (T : Term) : Term :=
@ -259,11 +280,11 @@ partial def getContrPos (stx : Syntax) : TermElabM (List ( × )) := do
throwError "To many contractions"
return filt
/-- The list of indices after contraction. -/
/-- The list of indices after contraction or evaluation. -/
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
let ind ← getIndices stx
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
return ind.filter (fun x => indFilt.count x ≤ 1)
return indFilt.filter (fun x => indFilt.count x ≤ 1)
/-- Takes a list and puts conseutive elements into pairs.
e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)]. -/
@ -460,7 +481,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) =>
ProdNode.syntaxFull stx
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => ProdNode.syntaxFull stx
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
return ← ProdNode.syntaxFull stx
| `(tensorExpr| ($a:tensorExpr)) => do
return (← syntaxFull a)
| `(tensorExpr| -$a:tensorExpr) => do
@ -476,6 +498,7 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
/-- An elaborator for tensor nodes. This is to be generalized. -/
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
println! "{(← syntaxFull stx)}"
let tensorExpr ← elabTerm (← syntaxFull stx) none
return tensorExpr