lpaggen/Python-Dimension-Static-Type-Checker

GitHub: lpaggen/Python-Dimension-Static-Type-Checker

基于 Z3 SMT 求解器的 PyTorch 张量维度静态检查器,在编译期验证张量形状与线性代数操作的维度正确性。

Stars: 0 | Forks: 0

# 警告:这是一个正在进行中的项目,README 的内容超前于代码的实际功能 # Torch Shape Checker (由 Z3 驱动) ![Python](https://img.shields.io/badge/python-3.14+-blue.svg) ![Z3](https://img.shields.io/badge/Z3-SMT%20Solver-green.svg) ![状态](https://img.shields.io/badge/status-experimental-orange.svg) 本项目实现了 Microsoft 的 Z3 SMT 求解器 (https://www.microsoft.com/en-us/research/project/z3-3/)。 你是否在模型训练了几个小时后遇到过简单的 runtime 错误?如果你的错误与线性代数有关,或者是由于对 Pytorch tensor 维度的疏忽造成的,那么这个工具就是为你准备的! Torch Shape Checker 是一个用 Python 编写的静态类型检查器,它能在编译时而非运行时检查 Pytorch tensor 的维度有效性。它通过利用 Python 3.14 的类型注解来实现这一点,所有功能都在原生 Python 3.14 上运行。因为是在原生 Python 上运行,所以在你的程序中应用此工具几乎是毫不费力的,请看下文! ## 为什么使用 Z3 Z3 是 Microsoft 的开源 SMT 求解器。SMT 求解器接收一组约束条件,并输出它们是否全部可行,令人惊讶的是,这可以直接应用于像 Python 这样的编程语言。借助 Z3,该工具让你能够在进入模型训练数小时之前,就检测出 Pytorch tensor 上的维度不匹配问题。 ## 常规 Pytorch 程序示例 ``` n = 13 m = 3 k = 3 A = torch.tensor([[1, 2, 3]]) B = torch.tensor([[1, 2, 3]]) C = torch.matmul(A, B) ``` ## 带有注解的相同程序 ``` n: int = 1 m: int = 3 k: int = 3 A: torch.Tensor[n, m] = torch.tensor([[1, 2, 3]]) # the tool verifies (n, m) matches actual shape B: torch.Tensor[m, k] = torch.tensor([[1], [2], [3]]]) C: torch.Tensor[n, k] = torch.matmul(A, B) # tool verifies A and B can be multiplied and (n, m) matches shape(A dot B) out -> VALID ``` 该工具使用 Z3 SMT 求解器收集整数类型和 tensor 类型提示,并在编译时强制执行适用于 tensor 声明和线性代数操作的规则。这意味着你不需要运行代码就能发现细微的错误,该工具会检测你的错误并报告给你。查看以下示例和输出: ``` n: int = 1 m: int = 3 k: int = 1 A: torch.Tensor[n, m] = torch.tensor([[1, 2, 3, 4]]) # A's type annotation and its actual shape differ B: torch.Tensor[m, k] = torch.randn(3, 1) C: torch.Tensor[n, k] = torch.matmul(A, B) out -> DeclarationError: tensor A was declared with shape(rows=1, cols=4), but expected shape(rows=1, cols=3) ``` # 工具架构 1. 使用 Python 的 _ast_ 模块将 Python 源代码转换为 AST 2. 自定义访问器遍历 AST,并将整数维度和包含 tensor 的节点转换为表示类型和形状的 IR 3. Z3 wrapper 对 IR 完成一次遍历,并根据类型和线性代数规则应用约束 4. 程序通知用户是否犯了任何维度错误 ``` flowchart LR A["Python Source Code"] --> B["Python AST"] B --> C["Semantic IR"] C --> D["Z3 Constraint Layer"] D --> E{"SAT ?"} E -->|SAT| F["Runtime / Valid State"] E -->|UNSAT| G["Invalid Path"] %% Styling style A fill:#2d3436,color:#fff,stroke:#636e72 style B fill:#6c5ce7,color:#fff,stroke:#4834d4 style C fill:#00b894,color:#fff,stroke:#019875 style D fill:#0984e3,color:#fff,stroke:#0767b1 style E fill:#b2bec3,color:#2d3436,stroke:#636e72 style F fill:#00cec9,color:#fff,stroke:#00a8a8 style G fill:#d63031,color:#fff,stroke:#a61e1e ``` # 如何使用 ## 安装依赖项 你只需要 Python 3.14+ 和 Z3 即可运行该工具,只有在执行你的代码时才需要 Pytorch ``` python3.14 -m pip install z3-solver torch pip install z3-solver ``` ## 运行工具 ``` torchdimchecker **your_file** --verbose ``` # 目前支持的功能 ``` torch.matmul torch.tensor torch.randn ```
标签:Python, PyTorch, SOC Prime, Z3求解器, 凭据扫描, 开发工具, 无后门, 维度校验, 静态类型检查