三种序列化方法

在网络上找到的三种序列化方法如下:

  • 在模型定义的时候给序列化的方法
  • 继承改写 flask 里面的 JSONEncoder类以及default方法
  • 使用Marshmallow

为什么要进行序列化(将数据库取出的对象转化为JSON格式)

在Flask中,我们常常会选择采用RESTful设计风格,即在各个资源对应的GET、POST、PUT方法中,返回一个JSON格式的数据(资源)给前端使用。这就要求我们在(如get)方法中return一个dict,flask-restful会自动帮我们返回为一个JSON格式的数据。

而通过flask-SQLAlchemy这一ORM工具所构建的数据库表模型,通过其语句所取出的数据通常是object类型的,这一类型并不能直接在方法中return返回一个JSON格式,因此需要先对从数据库中取出的数据进行序列化,然后再return给前端。


第一种序列化方法:使用dict在模型内部构建一个规则方法,定义资源的模式

该方法是由我们自己手动实现资源的序列化
下面就是一个完整的Model模型的构建,以及序列化实现:

  1. class Test(db.Model):
  2. # 表的字段构建
  3. id = db.Column(db.BIGINT, primary_key=True, autoincrement=True)
  4. station_id = db.Column(db.String(20), nullable=False)
  5. datetime = db.Column(db.DateTime, nullable=False)
  6. m0 = db.Column(db.Float)
  7. # 模型的资源序列化函数(方法)
  8. # 在该函数中所返回的dict的keys,将是我们从test表里所序列化的字段
  9. def test_schema(self):
  10. return {
  11. 'id': self.id,
  12. 'station_id': self.station_id,
  13. 'datetime': self.datetime,
  14. 'm0': self.m0
  15. }

上述代码便完成了准备工作,下面我们将展示如何应用,即在RESTful API中以JSON格式返回数据库表中取得的数据:

  1. class HelloWorld(Resource):
  2. def get(self):
  3. data = Test.query.first() # 取第一条数据
  4. data_serialize = data.test_schema() # 通过我们之前在模型类里定义的序列化函数对取得数据进行序列化,此时 data_serialize 的类型是 dict
  5. return jsonify(data_serialize)

以上就是第一种序列化方法的实现过程
给出完整代码:

  1. from flask import Flask as _Flask
  2. from flask import jsonify
  3. from flask_migrate import Migrate
  4. from flask_sqlalchemy import SQLAlchemy
  5. from flask_restful import Api, Resource
  6. from flask_marshmallow import Marshmallow
  7. app = _Flask(__name__)
  8. api = Api(app)
  9. ma = Marshmallow(app)
  10. # 配置数据库的地址
  11. app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://XXX:XXX@localhost:3306/test'
  12. app.config['SQLALCHEMY_COMMIT_TEARDOWN'] = True
  13. app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
  14. app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN'] = True
  15. db = SQLAlchemy(app)
  16. migrate = Migrate(app, db)
  17. class Test(db.Model):
  18. # 表的字段构建
  19. id = db.Column(db.BIGINT, primary_key=True, autoincrement=True)
  20. station_id = db.Column(db.String(20), nullable=False)
  21. datetime = db.Column(db.DateTime, nullable=False)
  22. m0 = db.Column(db.Float)
  23. # 模型的资源序列化函数(方法)
  24. # 在该函数中所返回的dict的keys,将是我们从test表里所序列化的字段
  25. def test_schema(self):
  26. return {
  27. 'id': self.id,
  28. 'station_id': self.station_id,
  29. 'datetime': self.datetime,
  30. 'm0': self.m0
  31. }
  32. class HelloWorld(Resource):
  33. def get(self):
  34. data = Test.query.first() # 取第一条数据
  35. data_serialize = data.test_schema() # 通过我们之前在模型类里定义的序列化函数对取得数据进行序列化,此时 data_serialize 的类型是 dict
  36. return jsonify(data_serialize)
  37. api.add_resource(HelloWorld, '/')
  38. if __name__ == '__main__':
  39. app.run(debug=True)

最终结果展示:
ip地址:http://localhost:5000/
返回JSON字符串:

  1. {
  2. "datetime": "Sun, 01 Jul 2018 00:00:00 GMT",
  3. "id": 1,
  4. "station_id": "54511"
  5. }

第二种序列化方法:重写JSONEncoder实现自定义序列化

1.为什么要重写JSONencoder?

本方法中我们采用json.dumps()来将取出的数据序列化成json字符串

通过查看json包中的def dumps()源码我们可以得知,我们在调用json.dumps()方法时默认会使用JSONEncoder进行序列化,传入cls参数后可以使用自定义的序列化方法(即我们重写的JSONEncoder)。源码如下(json/init.py):

  1. def dumps(obj, *, skipkeys=False, ensure_ascii=True, check_circular=True,
  2. allow_nan=True, cls=None, indent=None, separators=None,
  3. default=None, sort_keys=False, **kw):
  4. # 以下是关于cls参数的使用介绍
  5. """
  6. To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the
  7. ``.default()`` method to serialize additional types), specify it with
  8. the ``cls`` kwarg; otherwise ``JSONEncoder`` is used.
  9. 翻译:
  10. 如果要使用自定义的“JSONEncoder”子类
  11. (例如,重写 default() 方法来序列化其他类型的数据),
  12. 请使用使用“cls”关键字参数;否则默认使用“JSONEncoder”。
  13. """
  14. # 部分源码:可以看出如果传入了 cls ,就会按照你重写的方法执行
  15. if cls is None:
  16. cls = JSONEncoder
  17. return cls(
  18. skipkeys=skipkeys, ensure_ascii=ensure_ascii,
  19. check_circular=check_circular, allow_nan=allow_nan, indent=indent,
  20. separators=separators, default=default, sort_keys=sort_keys,
  21. **kw).encode(obj)

而原有JSONEncoderdefault()方法是不能够对对象进行序列化的,因此需要我们自己重写default()方法,然后在调用json.dumps()时通过cls参数传进去。


2.如何重写JSONEncoder?

先看完整代码:

  1. from flask import Flask as _Flask
  2. from flask_migrate import Migrate
  3. from flask_sqlalchemy import SQLAlchemy
  4. from flask_restful import Api, Resource
  5. from flask_marshmallow import Marshmallow
  6. from flask.json import JSONEncoder as _JSONEncoder
  7. from datetime import date
  8. import json
  9. app = _Flask(__name__)
  10. api = Api(app)
  11. ma = Marshmallow(app)
  12. # 配置数据库的地址
  13. app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://xxx:xxxxxx@localhost:3306/test'
  14. app.config['SQLALCHEMY_COMMIT_TEARDOWN'] = True
  15. app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
  16. app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN'] = True
  17. db = SQLAlchemy(app)
  18. migrate = Migrate(app, db)
  19. class JSONEncoder(_JSONEncoder):
  20. def default(self, o):
  21. if hasattr(o, 'keys') and hasattr(o, '__getitem__'):
  22. return dict(o)
  23. if isinstance(o, date):
  24. return o.strftime('%Y-%m-%d %H:%M:%S')
  25. return json.JSONEncoder.default(self, o)
  26. class Flask(_Flask):
  27. json_encoder = JSONEncoder
  28. class Test(db.Model):
  29. id = db.Column(db.BIGINT, primary_key=True, autoincrement=True)
  30. name = db.Column(db.String(20), nullable=False)
  31. def keys(self):
  32. return ['id', 'name']
  33. def __getitem__(self, item):
  34. return getattr(self, item)
  35. class HelloWorld(Resource):
  36. def get(self):
  37. test_data = Test.query.all()
  38. data_json = json.loads(json.dumps(test_data, cls=JSONEncoder))
  39. return data_json
  40. api.add_resource(HelloWorld, '/')
  41. if __name__ == '__main__':
  42. app.run(debug=True)

代码分析

JSONEncoder类的default与原先default写法类似,继承原JSONEncoder类,添加一个针对对象的if处理即可:

  1. if hasattr(o, 'keys') and hasattr(o, '__getitem__'):
  2. return dict(o)

判断:如果传入的对象o存在keys__getitem__属性(即我们在模型Test类中定义的两个方法),则表明传入对象o是模型对象,把对象o传给dict()
dict函数的特殊之处在于,当一个对象传入后,dict会去调用keys函数(模型Test类中定义的方法),keys方法的目的是拿到我们自定义的所有字典里的键。
dict会以中括号的形式来拿到对应键的值,如o[‘id’],但是默认不能这样访问,这就是__getitem__方法的作用了。


用法

  1. class HelloWorld(Resource):
  2. def get(self):
  3. test_data = Test.query.all()
  4. data_json = json.loads(json.dumps(test_data, cls=JSONEncoder))
  5. return data_json

在这里,我们对序列化后的数据又进行了一次json.loads,原因在于:
如果不loads,最终返回的JSON字符串是这样的:

  1. "[{\"id\": 1, \"name\": \"qq\"}, {\"id\": 2, \"name\": \"ww\"}, {\"id\": 3, \"name\": \"ee\"}]"

我猜测可能是进行了2次序列化后的结果,因此再loads一次,就可以恢复正常了:
最终返回结果如下

  1. [
  2. {
  3. "id": 1,
  4. "name": "qq"
  5. },
  6. {
  7. "id": 2,
  8. "name": "ww"
  9. },
  10. {
  11. "id": 3,
  12. "name": "ee"
  13. }
  14. ]

注意
由于我们重写了JSONEncoder类的defualt方法,所以需要将自己的JSONEncoder类顶替掉Flask.JSON下的原有类。
相关代码:

  1. from flask import Flask as _Flask
  2. from flask.json import JSONEncoder as _JSONEncoder
  3. class Flask(_Flask):
  4. json_encoder = JSONEncoder