feat: Get evaluation working.
This commit is contained in:
parent
e180d4cca9
commit
8fa8a51367
4 changed files with 56 additions and 13 deletions
|
@ -96,8 +96,11 @@ declare_syntax_cat tensorExpr
|
|||
/-- The syntax for a tensor node. -/
|
||||
syntax term "|" (ppSpace indexExpr)* : tensorExpr
|
||||
|
||||
/-- Equality. -/
|
||||
syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr
|
||||
|
||||
/-- The syntax for tensor prod two tensor nodes. -/
|
||||
syntax tensorExpr "⊗" tensorExpr : tensorExpr
|
||||
syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr
|
||||
|
||||
/-- The syntax for tensor addition. -/
|
||||
syntax tensorExpr "+" tensorExpr : tensorExpr
|
||||
|
@ -111,9 +114,6 @@ syntax term "•" tensorExpr : tensorExpr
|
|||
/-- Negation of a tensor tree. -/
|
||||
syntax "-" tensorExpr : tensorExpr
|
||||
|
||||
/-- Equality. -/
|
||||
syntax tensorExpr "=" tensorExpr : tensorExpr
|
||||
|
||||
namespace TensorNode
|
||||
|
||||
/-!
|
||||
|
@ -129,7 +129,7 @@ We also want to ensure the number of indices is correct.
|
|||
|
||||
-/
|
||||
|
||||
/-- The indices of a tensor node. Before contraction, dualisation, and evaluation. -/
|
||||
/-- The indices of a tensor node. Before contraction, and evaluation. -/
|
||||
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
|
@ -183,6 +183,16 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
|
|||
let strType := toString type
|
||||
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
|
||||
let const := (String.splitOn strType "Quiver.Hom").length
|
||||
match ← isDefEq type (← stringToType "CoeSort.coe Lorentz.complexCo") with
|
||||
| true =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.down, T]
|
||||
| _ =>
|
||||
match ← isDefEq type (← stringToType "CoeSort.coe Lorentz.complexContr") with
|
||||
| true =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.up, T]
|
||||
| _ =>
|
||||
match n, const with
|
||||
| 1, 1 =>
|
||||
match type with
|
||||
|
@ -235,13 +245,24 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
|
|||
return Syntax.mkApp (mkIdent ``TensorTree.constThreeNode) #[T]
|
||||
| _, _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
|
||||
|
||||
/-- Adjusts a list `List ℕ` by subtracting from each natrual number the number
|
||||
of elements before it in the list which are less then itself. This is used
|
||||
to form a list of pairs which can be used for evaluating indices. -/
|
||||
def evalAdjustPos (l : List ℕ) : List ℕ :=
|
||||
let l' := List.mapAccumr
|
||||
(fun x (prev : List ℕ) =>
|
||||
let e := prev.countP (fun y => y < x)
|
||||
(x :: prev, x - e)) l.reverse []
|
||||
l'.2.reverse
|
||||
|
||||
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
|
||||
partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndices stx
|
||||
let indEnum := ind.enum
|
||||
let evals := indEnum.filter (fun x => indexExprIsNum x.2)
|
||||
let evals2 ← (evals.mapM (fun x => indexToNum x.2))
|
||||
return List.zip (evals.map (fun x => x.1)) evals2
|
||||
let pos := evalAdjustPos (evals.map (fun x => x.1))
|
||||
return List.zip pos evals2
|
||||
|
||||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.eval` to the given term. -/
|
||||
def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
|
@ -259,11 +280,11 @@ partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
|||
throwError "To many contractions"
|
||||
return filt
|
||||
|
||||
/-- The list of indices after contraction. -/
|
||||
/-- The list of indices after contraction or evaluation. -/
|
||||
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
return indFilt.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
/-- Takes a list and puts conseutive elements into pairs.
|
||||
e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)]. -/
|
||||
|
@ -460,7 +481,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) =>
|
||||
ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return ← ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← syntaxFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
|
@ -476,6 +498,7 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
|
||||
/-- An elaborator for tensor nodes. This is to be generalized. -/
|
||||
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
|
||||
println! "{(← syntaxFull stx)}"
|
||||
let tensorExpr ← elabTerm (← syntaxFull stx) none
|
||||
return tensorExpr
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue