refactor: Elab file
This commit is contained in:
parent
6fbace33da
commit
6fe581f31c
2 changed files with 16 additions and 42 deletions
|
@ -263,6 +263,8 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
|||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return indFilt.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
end TensorNode
|
||||
|
||||
/-- Takes a list and puts conseutive elements into pairs.
|
||||
e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)]. -/
|
||||
def toPairs (l : List ℕ) : List (ℕ × ℕ) :=
|
||||
|
@ -288,23 +290,6 @@ def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
#[Syntax.mkNumLit (toString x0),
|
||||
Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $T:term | $[$args]*) => do
|
||||
let indices ← getIndices stx
|
||||
let rawIndex ← getNoIndicesExact T
|
||||
if indices.length ≠ rawIndex then
|
||||
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
|
||||
let tensorNodeSyntax ← termNodeSyntax T
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) evalSyntax
|
||||
return contrSyntax
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
end TensorNode
|
||||
|
||||
namespace ProdNode
|
||||
|
||||
/-!
|
||||
|
@ -346,28 +331,10 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
|||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
(TensorNode.contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x0), Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
def prodSyntax (T1 T2 : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2]
|
||||
|
||||
/-- The full term taking tensor syntax into a term for products and single tensor nodes. -/
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let prodSyntax := prodSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) prodSyntax
|
||||
return contrSyntax
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← syntaxFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
end ProdNode
|
||||
|
||||
/-!
|
||||
|
@ -404,7 +371,6 @@ def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term :=
|
|||
let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
|
||||
return stx
|
||||
|
||||
|
||||
namespace negNode
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
|
@ -439,7 +405,6 @@ partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)
|
|||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}"
|
||||
|
||||
|
||||
/-- Gets the indices associated with the RHS of an addition. -/
|
||||
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
|
@ -493,10 +458,19 @@ end Equality
|
|||
/-- Creates the syntax associated with a tensor node. -/
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) =>
|
||||
ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return ← ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| $T:term | $[$args]*) =>
|
||||
let indices ← TensorNode.getIndices stx
|
||||
let rawIndex ← TensorNode.getNoIndicesExact T
|
||||
if indices.length ≠ rawIndex then
|
||||
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
|
||||
let tensorNodeSyntax ← TensorNode.termNodeSyntax T
|
||||
let evalSyntax := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) tensorNodeSyntax
|
||||
let contrSyntax := contrSyntax (← TensorNode.getContrPos stx) evalSyntax
|
||||
return contrSyntax
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let prodSyntax := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
let contrSyntax := contrSyntax (← ProdNode.getContrPos stx) prodSyntax
|
||||
return contrSyntax
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← syntaxFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
|
|
|
@ -130,7 +130,7 @@ lemma add_assoc (t1 t2 t3 : TensorTree S c) :
|
|||
can be moved out of the addition. -/
|
||||
lemma add_perm {n : ℕ} {c : Fin n → S.C} {c1 : Fin n → S.C}
|
||||
(σ : (OverColor.mk c) ⟶ (OverColor.mk c1)) (t t1 : TensorTree S c) :
|
||||
(add (perm σ t) (perm σ t1)).tensor = (perm σ (add t t1)).tensor := by
|
||||
(add (perm σ t) (perm σ t1)).tensor = (perm σ (add t t1)).tensor := by
|
||||
simp only [add_tensor, perm_tensor, map_add]
|
||||
|
||||
/-- When the same evaluation acts on both arguments of an addition, the evaluation
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue