refactor: Elab file

This commit is contained in:
jstoobysmith 2024-10-22 12:10:55 +00:00
parent 6fbace33da
commit 6fe581f31c
2 changed files with 16 additions and 42 deletions

View file

@ -263,6 +263,8 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
return indFilt.filter (fun x => indFilt.count x ≤ 1)
end TensorNode
/-- Takes a list and puts conseutive elements into pairs.
e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)]. -/
def toPairs (l : List ) : List ( × ) :=
@ -288,23 +290,6 @@ def contrSyntax (l : List ( × )) (T : Term) : Term :=
#[Syntax.mkNumLit (toString x0),
Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
/-- Creates the syntax associated with a tensor node. -/
def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $T:term | $[$args]*) => do
let indices ← getIndices stx
let rawIndex ← getNoIndicesExact T
if indices.length ≠ rawIndex then
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
let tensorNodeSyntax ← termNodeSyntax T
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
let contrSyntax := contrSyntax (← getContrPos stx) evalSyntax
return contrSyntax
| _ =>
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
end TensorNode
namespace ProdNode
/-!
@ -346,28 +331,10 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
return ind.filter (fun x => indFilt.count x ≤ 1)
/-- For each element of `l : List ( × )` applies `TensorTree.contr` to the given term. -/
def contrSyntax (l : List ( × )) (T : Term) : Term :=
(TensorNode.contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
#[Syntax.mkNumLit (toString x0), Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
/-- The syntax associated with a product of tensors. -/
def prodSyntax (T1 T2 : Term) : Term :=
Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2]
/-- The full term taking tensor syntax into a term for products and single tensor nodes. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
let prodSyntax := prodSyntax (← syntaxFull a) (← syntaxFull b)
let contrSyntax := contrSyntax (← getContrPos stx) prodSyntax
return contrSyntax
| `(tensorExpr| ($a:tensorExpr)) => do
return (← syntaxFull a)
| _ =>
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
end ProdNode
/-!
@ -404,7 +371,6 @@ def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term :=
let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
return stx
namespace negNode
/-- The syntax associated with a product of tensors. -/
@ -439,7 +405,6 @@ partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)
| _ =>
throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
/-- Gets the indices associated with the RHS of an addition. -/
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
@ -493,10 +458,19 @@ end Equality
/-- Creates the syntax associated with a tensor node. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) =>
ProdNode.syntaxFull stx
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
return ← ProdNode.syntaxFull stx
| `(tensorExpr| $T:term | $[$args]*) =>
let indices ← TensorNode.getIndices stx
let rawIndex ← TensorNode.getNoIndicesExact T
if indices.length ≠ rawIndex then
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
let tensorNodeSyntax ← TensorNode.termNodeSyntax T
let evalSyntax := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) tensorNodeSyntax
let contrSyntax := contrSyntax (← TensorNode.getContrPos stx) evalSyntax
return contrSyntax
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
let prodSyntax := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b)
let contrSyntax := contrSyntax (← ProdNode.getContrPos stx) prodSyntax
return contrSyntax
| `(tensorExpr| ($a:tensorExpr)) => do
return (← syntaxFull a)
| `(tensorExpr| -$a:tensorExpr) => do