from object_packaging.field import Field
from object_packaging.my_database import create_pool


# 定义一个元类,控制model对象的创建
class ModelMetaClass(type):
    def __new__(cls, table_name, bases, attrs):
        if table_name == "Model":
            return super(ModelMetaClass, cls).__new__(cls, table_name, bases, attrs)
        mappings = dict()
        for k, v in attrs.items():
            # 保存类属性和列的映射关系到mappings字典中
            if isinstance(v, Field):
                mappings[k] = v  # 这样,这个mappings就存放了  属性名称和字段名称及列的名字
        for k in mappings.keys():
            attrs.pop(k)   # 只有在实例中才可以进行访问。也就是说使用类名.属性就不能访问了
        # 把表名转换为小写,表名就是类名
        # 这个处理略显粗糙,应该继续做一些健壮性的判断,比如ModelMetaClass 变成表名应该是model_meta_class
        attrs['__table__'] = table_name.lower()
        # 保存属性和列的映射关系,创建类时添加一个__mappings__类的属性
        attrs['__mappings__'] = mappings

        return super(ModelMetaClass, cls).__new__(cls, table_name, bases, attrs)



# 编写一个model子类,这个类用于被具体的model对象继承。 它实现了具体的增删改查方法
# 这样以后的每一个model就都有了这些方法
# 这里边我又继承了dict,那么也就是说这个Model也可以使用dict的方法

class Model(dict, metaclass=ModelMetaClass):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError("这个Model没有这个属性,这个属性是%s" % item)

    def __setattr__(self, key, value):
        self[key] = value



    def insert(self, column_list, param_list):
        print("调用了insert方法")

        fields = []

        for k, v in self.__mappings__.items():
            fields.append(k)

        for key in column_list:
            if key not in fields:
                raise RuntimeError("field not found")

        args = self.__check_params(param_list)

        sql = 'insert into %s (%s) values (%s)' % (self.__table__, ','.join(column_list), ','.join(args))

        res = self.__do_execute(sql)
        print(res)

    def __check_params(self, param_list):

        args = []
        #  insert into   ....   values "va"lue"
        for param in param_list:
            if "\"" in param:
                param = param.replace("\"", "\\\"")

            param = "\"" + param + "\""
            args.append(param)
        return args

    def __do_execute(self,sql):
        conn = create_pool()
        cur = conn.cursor()
        print(sql)

        if "select" in sql:
            cur.execute(sql)
            rs = cur.fetchall()
        else:
            rs = cur.execute(sql)

        conn.commit()
        cur.close()

        return rs
# 下边是作业

    # get  select

    def select(self, column_list, where_list):
        print("调用了select方法")
        args = []
        fields = []

        for k,v in self.__mappings__.items():
            fields.append(k)

        for key in column_list:
            if key not in fields:
                raise RuntimeError("field not found")

        for key in where_list:
            args.append(key)

        sql = 'select %s from %s where %s' % (','.join(column_list), self.__table__, ' and '.join(args))

        res = self.__do_execute(sql)
        return res

    # update
    def update(self, set_column_list, where_list):
        print("调用了update方法")
        args = []
        fields = []


        for k,v in self.__mappings__.items():
            fields.append(k)

        for key in set_column_list:
            if key not in fields:
                raise RuntimeError("field not found")

        for key in where_list:
            args.append(key)

        sql = 'update %s set %s where %s' % (self.__table__, ','.join(set_column_list), ' and '.join(args))

        print(sql)
        res = self.__do_execute(sql)

        return res




    # delete

    def delete(self, where_list):
        print("调用delete方法")
        args = []
        for key in where_list:
            args.append(key)

        sql = 'delete from %s where %s' % (self.__table__, ' and '.join(args))
        res = self.__do_execute(sql)

        return res