基本单位

  1. import math
  2. import numpy as np
  3. import matplotlib.units as units
  4. import matplotlib.ticker as ticker
  5. from matplotlib.cbook import iterable
  6. class ProxyDelegate(object):
  7. def __init__(self, fn_name, proxy_type):
  8. self.proxy_type = proxy_type
  9. self.fn_name = fn_name
  10. def __get__(self, obj, objtype=None):
  11. return self.proxy_type(self.fn_name, obj)
  12. class TaggedValueMeta(type):
  13. def __init__(self, name, bases, dict):
  14. for fn_name in self._proxies:
  15. try:
  16. dummy = getattr(self, fn_name)
  17. except AttributeError:
  18. setattr(self, fn_name,
  19. ProxyDelegate(fn_name, self._proxies[fn_name]))
  20. class PassThroughProxy(object):
  21. def __init__(self, fn_name, obj):
  22. self.fn_name = fn_name
  23. self.target = obj.proxy_target
  24. def __call__(self, *args):
  25. fn = getattr(self.target, self.fn_name)
  26. ret = fn(*args)
  27. return ret
  28. class ConvertArgsProxy(PassThroughProxy):
  29. def __init__(self, fn_name, obj):
  30. PassThroughProxy.__init__(self, fn_name, obj)
  31. self.unit = obj.unit
  32. def __call__(self, *args):
  33. converted_args = []
  34. for a in args:
  35. try:
  36. converted_args.append(a.convert_to(self.unit))
  37. except AttributeError:
  38. converted_args.append(TaggedValue(a, self.unit))
  39. converted_args = tuple([c.get_value() for c in converted_args])
  40. return PassThroughProxy.__call__(self, *converted_args)
  41. class ConvertReturnProxy(PassThroughProxy):
  42. def __init__(self, fn_name, obj):
  43. PassThroughProxy.__init__(self, fn_name, obj)
  44. self.unit = obj.unit
  45. def __call__(self, *args):
  46. ret = PassThroughProxy.__call__(self, *args)
  47. return (NotImplemented if ret is NotImplemented
  48. else TaggedValue(ret, self.unit))
  49. class ConvertAllProxy(PassThroughProxy):
  50. def __init__(self, fn_name, obj):
  51. PassThroughProxy.__init__(self, fn_name, obj)
  52. self.unit = obj.unit
  53. def __call__(self, *args):
  54. converted_args = []
  55. arg_units = [self.unit]
  56. for a in args:
  57. if hasattr(a, 'get_unit') and not hasattr(a, 'convert_to'):
  58. # if this arg has a unit type but no conversion ability,
  59. # this operation is prohibited
  60. return NotImplemented
  61. if hasattr(a, 'convert_to'):
  62. try:
  63. a = a.convert_to(self.unit)
  64. except:
  65. pass
  66. arg_units.append(a.get_unit())
  67. converted_args.append(a.get_value())
  68. else:
  69. converted_args.append(a)
  70. if hasattr(a, 'get_unit'):
  71. arg_units.append(a.get_unit())
  72. else:
  73. arg_units.append(None)
  74. converted_args = tuple(converted_args)
  75. ret = PassThroughProxy.__call__(self, *converted_args)
  76. if ret is NotImplemented:
  77. return NotImplemented
  78. ret_unit = unit_resolver(self.fn_name, arg_units)
  79. if ret_unit is NotImplemented:
  80. return NotImplemented
  81. return TaggedValue(ret, ret_unit)
  82. class TaggedValue(metaclass=TaggedValueMeta):
  83. _proxies = {'__add__': ConvertAllProxy,
  84. '__sub__': ConvertAllProxy,
  85. '__mul__': ConvertAllProxy,
  86. '__rmul__': ConvertAllProxy,
  87. '__cmp__': ConvertAllProxy,
  88. '__lt__': ConvertAllProxy,
  89. '__gt__': ConvertAllProxy,
  90. '__len__': PassThroughProxy}
  91. def __new__(cls, value, unit):
  92. # generate a new subclass for value
  93. value_class = type(value)
  94. try:
  95. subcls = type('TaggedValue_of_%s' % (value_class.__name__),
  96. tuple([cls, value_class]),
  97. {})
  98. if subcls not in units.registry:
  99. units.registry[subcls] = basicConverter
  100. return object.__new__(subcls)
  101. except TypeError:
  102. if cls not in units.registry:
  103. units.registry[cls] = basicConverter
  104. return object.__new__(cls)
  105. def __init__(self, value, unit):
  106. self.value = value
  107. self.unit = unit
  108. self.proxy_target = self.value
  109. def __getattribute__(self, name):
  110. if name.startswith('__'):
  111. return object.__getattribute__(self, name)
  112. variable = object.__getattribute__(self, 'value')
  113. if hasattr(variable, name) and name not in self.__class__.__dict__:
  114. return getattr(variable, name)
  115. return object.__getattribute__(self, name)
  116. def __array__(self, dtype=object):
  117. return np.asarray(self.value).astype(dtype)
  118. def __array_wrap__(self, array, context):
  119. return TaggedValue(array, self.unit)
  120. def __repr__(self):
  121. return 'TaggedValue({!r}, {!r})'.format(self.value, self.unit)
  122. def __str__(self):
  123. return str(self.value) + ' in ' + str(self.unit)
  124. def __len__(self):
  125. return len(self.value)
  126. def __iter__(self):
  127. # Return a generator expression rather than use `yield`, so that
  128. # TypeError is raised by iter(self) if appropriate when checking for
  129. # iterability.
  130. return (TaggedValue(inner, self.unit) for inner in self.value)
  131. def get_compressed_copy(self, mask):
  132. new_value = np.ma.masked_array(self.value, mask=mask).compressed()
  133. return TaggedValue(new_value, self.unit)
  134. def convert_to(self, unit):
  135. if unit == self.unit or not unit:
  136. return self
  137. new_value = self.unit.convert_value_to(self.value, unit)
  138. return TaggedValue(new_value, unit)
  139. def get_value(self):
  140. return self.value
  141. def get_unit(self):
  142. return self.unit
  143. class BasicUnit(object):
  144. def __init__(self, name, fullname=None):
  145. self.name = name
  146. if fullname is None:
  147. fullname = name
  148. self.fullname = fullname
  149. self.conversions = dict()
  150. def __repr__(self):
  151. return 'BasicUnit(%s)' % self.name
  152. def __str__(self):
  153. return self.fullname
  154. def __call__(self, value):
  155. return TaggedValue(value, self)
  156. def __mul__(self, rhs):
  157. value = rhs
  158. unit = self
  159. if hasattr(rhs, 'get_unit'):
  160. value = rhs.get_value()
  161. unit = rhs.get_unit()
  162. unit = unit_resolver('__mul__', (self, unit))
  163. if unit is NotImplemented:
  164. return NotImplemented
  165. return TaggedValue(value, unit)
  166. def __rmul__(self, lhs):
  167. return self*lhs
  168. def __array_wrap__(self, array, context):
  169. return TaggedValue(array, self)
  170. def __array__(self, t=None, context=None):
  171. ret = np.array([1])
  172. if t is not None:
  173. return ret.astype(t)
  174. else:
  175. return ret
  176. def add_conversion_factor(self, unit, factor):
  177. def convert(x):
  178. return x*factor
  179. self.conversions[unit] = convert
  180. def add_conversion_fn(self, unit, fn):
  181. self.conversions[unit] = fn
  182. def get_conversion_fn(self, unit):
  183. return self.conversions[unit]
  184. def convert_value_to(self, value, unit):
  185. conversion_fn = self.conversions[unit]
  186. ret = conversion_fn(value)
  187. return ret
  188. def get_unit(self):
  189. return self
  190. class UnitResolver(object):
  191. def addition_rule(self, units):
  192. for unit_1, unit_2 in zip(units[:-1], units[1:]):
  193. if unit_1 != unit_2:
  194. return NotImplemented
  195. return units[0]
  196. def multiplication_rule(self, units):
  197. non_null = [u for u in units if u]
  198. if len(non_null) > 1:
  199. return NotImplemented
  200. return non_null[0]
  201. op_dict = {
  202. '__mul__': multiplication_rule,
  203. '__rmul__': multiplication_rule,
  204. '__add__': addition_rule,
  205. '__radd__': addition_rule,
  206. '__sub__': addition_rule,
  207. '__rsub__': addition_rule}
  208. def __call__(self, operation, units):
  209. if operation not in self.op_dict:
  210. return NotImplemented
  211. return self.op_dict[operation](self, units)
  212. unit_resolver = UnitResolver()
  213. cm = BasicUnit('cm', 'centimeters')
  214. inch = BasicUnit('inch', 'inches')
  215. inch.add_conversion_factor(cm, 2.54)
  216. cm.add_conversion_factor(inch, 1/2.54)
  217. radians = BasicUnit('rad', 'radians')
  218. degrees = BasicUnit('deg', 'degrees')
  219. radians.add_conversion_factor(degrees, 180.0/np.pi)
  220. degrees.add_conversion_factor(radians, np.pi/180.0)
  221. secs = BasicUnit('s', 'seconds')
  222. hertz = BasicUnit('Hz', 'Hertz')
  223. minutes = BasicUnit('min', 'minutes')
  224. secs.add_conversion_fn(hertz, lambda x: 1./x)
  225. secs.add_conversion_factor(minutes, 1/60.0)
  226. # radians formatting
  227. def rad_fn(x, pos=None):
  228. if x >= 0:
  229. n = int((x / np.pi) * 2.0 + 0.25)
  230. else:
  231. n = int((x / np.pi) * 2.0 - 0.25)
  232. if n == 0:
  233. return '0'
  234. elif n == 1:
  235. return r'$\pi/2$'
  236. elif n == 2:
  237. return r'$\pi$'
  238. elif n == -1:
  239. return r'$-\pi/2$'
  240. elif n == -2:
  241. return r'$-\pi$'
  242. elif n % 2 == 0:
  243. return r'$%s\pi$' % (n//2,)
  244. else:
  245. return r'$%s\pi/2$' % (n,)
  246. class BasicUnitConverter(units.ConversionInterface):
  247. @staticmethod
  248. def axisinfo(unit, axis):
  249. 'return AxisInfo instance for x and unit'
  250. if unit == radians:
  251. return units.AxisInfo(
  252. majloc=ticker.MultipleLocator(base=np.pi/2),
  253. majfmt=ticker.FuncFormatter(rad_fn),
  254. label=unit.fullname,
  255. )
  256. elif unit == degrees:
  257. return units.AxisInfo(
  258. majloc=ticker.AutoLocator(),
  259. majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
  260. label=unit.fullname,
  261. )
  262. elif unit is not None:
  263. if hasattr(unit, 'fullname'):
  264. return units.AxisInfo(label=unit.fullname)
  265. elif hasattr(unit, 'unit'):
  266. return units.AxisInfo(label=unit.unit.fullname)
  267. return None
  268. @staticmethod
  269. def convert(val, unit, axis):
  270. if units.ConversionInterface.is_numlike(val):
  271. return val
  272. if iterable(val):
  273. return [thisval.convert_to(unit).get_value() for thisval in val]
  274. else:
  275. return val.convert_to(unit).get_value()
  276. @staticmethod
  277. def default_units(x, axis):
  278. 'return the default unit for x or None'
  279. if iterable(x):
  280. for thisx in x:
  281. return thisx.unit
  282. return x.unit
  283. def cos(x):
  284. if iterable(x):
  285. return [math.cos(val.convert_to(radians).get_value()) for val in x]
  286. else:
  287. return math.cos(x.convert_to(radians).get_value())
  288. basicConverter = BasicUnitConverter()
  289. units.registry[BasicUnit] = basicConverter
  290. units.registry[TaggedValue] = basicConverter

下载这个示例