MST

星途 面试题库

面试题:Python元类在避免复杂实参错误场景中的设计

假设你正在开发一个数学计算库,其中有多个类的方法对实参的要求非常复杂,例如要求实参必须是特定维度的矩阵(可以用嵌套列表模拟),并且矩阵元素要满足一定的数值范围。请设计一个元类,通过元类来控制这些类的方法实参的正确性,确保在方法调用前就能检测并处理实参错误,给出元类设计思路及关键代码实现。
30.7万 热度难度
编程语言Python

知识考点

AI 面试

面试题答案

一键面试

元类设计思路

  1. 定义元类:创建一个元类,该元类将用于包装目标类的方法。
  2. 方法包装:在元类中,对于目标类的每个方法,创建一个新的包装函数。这个包装函数在调用原始方法之前,先对实参进行检查。
  3. 参数检查逻辑:在包装函数中,编写逻辑来检查实参是否为特定维度的矩阵(嵌套列表),并且矩阵元素是否在指定的数值范围内。
  4. 错误处理:如果实参不符合要求,抛出适当的异常或进行相应的错误处理。

关键代码实现

class MatrixMeta(type):
    def __new__(cls, name, bases, attrs):
        for attr_name, attr_value in attrs.items():
            if callable(attr_value):
                def wrapper(func):
                    def inner(self, *args, **kwargs):
                        # 这里以第一个位置参数为矩阵参数为例
                        matrix = args[0] if args else None
                        if matrix is None or not isinstance(matrix, list) or any(not isinstance(row, list) for row in matrix):
                            raise ValueError("实参必须是嵌套列表表示的矩阵")
                        # 检查矩阵维度
                        dim = len(matrix[0])
                        if any(len(row) != dim for row in matrix[1:]):
                            raise ValueError("矩阵必须是二维且每行长度一致")
                        # 检查矩阵元素数值范围,假设范围是0到100
                        for row in matrix:
                            for value in row:
                                if not (0 <= value <= 100):
                                    raise ValueError("矩阵元素必须在0到100之间")
                        return func(self, *args, **kwargs)
                    return inner
                attrs[attr_name] = wrapper(attr_value)
        return super().__new__(cls, name, bases, attrs)


class MathCalculationLibrary(metaclass=MatrixMeta):
    def some_method(self, matrix):
        # 实际的方法逻辑
        print(f"处理矩阵: {matrix}")


你可以使用以下方式测试:

try:
    lib = MathCalculationLibrary()
    valid_matrix = [[1, 2], [3, 4]]
    lib.some_method(valid_matrix)
    invalid_matrix = [[1, 2], [3, 4, 5]]  # 维度错误
    lib.some_method(invalid_matrix)
except ValueError as e:
    print(f"错误: {e}")