refactor: Fix problem with elab and do lint

This commit is contained in:
jstoobysmith 2024-10-24 07:36:54 +00:00
parent 95857993b5
commit 1e8efdb16a
6 changed files with 217 additions and 152 deletions

View file

@ -152,8 +152,11 @@ def getNoIndicesExact (stx : Syntax) : TermElabM := do
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
let typeC ← inferType c
match typeC with
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
return n
| Expr.forallE _ (Expr.app _ a) _ _ =>
let a' ← whnf a
match a' with
| Expr.lit (Literal.natVal n) => return n
|_ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
| _ => return 1
| k => return k
@ -248,11 +251,7 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
| _ =>
match type with
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
let typeC ← inferType c
match typeC with
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal _))) _)) _ _ =>
return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
| _ => throwError "Could not create terminal node syntax (termNodeSyntax). "
| _ => return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T]
/-- Adjusts a list `List ` by subtracting from each natrual number the number
@ -547,5 +546,14 @@ elab_rules (kind:=tensorExprSyntax) : term
| `(term| {$e:tensorExpr}ᵀ) => do
let tensorTree ← elaborateTensorNode e
return tensorTree
/-!
## Test cases
-/
variable {S : TensorSpecies} {c : Fin (Nat.succ (Nat.succ 0)) → S.C} {t : S.F.obj (OverColor.mk c)}
/-
#check {t | α β}ᵀ
-/
end TensorTree