multiprocessing.Pool.apply_async用法

apply_async(func[, args[, kwds[, callback]]])

A variant of the apply() method which returns a result object.

If callback is specified then it should be a callable which accepts a single argument. When the result becomes ready callback is applied to it (unless the call failed). callback should complete immediately since otherwise the thread which handles the results will get blocked.

说明:

  • callback函数只接受一个参数(即func的返回值);如果还需要向callback传入其他固定参数(例如输出文件对象),可以用functools.partial预先绑定这些参数构造偏函数
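  • callback函数应该能立即完成,因为它运行在主进程中负责处理结果的那个线程里;callback耗时过长会阻塞该线程,后续任务结果的回调都要排队等待
  • 如果func执行时抛出异常,callback不会被调用,异常只能通过apply_async返回的AsyncResult对象的get()方法获取,否则该异常会被静默丢弃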
  1. #!/usr/bin/env python
  2. #-*- encoding: utf8 -*-
  3. import os
  4. import sys
  5. import json
  6. import time
  7. import socket
  8. import logging
  9. import commands
  10. from functools import partial
  11. from multiprocessing import Pool
  12. reload(sys)
  13. sys.setdefaultencoding('utf-8')
  14. logging.basicConfig(
  15. format=
  16. '[%(asctime)s %(funcName)s %(levelname)s %(processName)s] %(message)s',
  17. datefmt='%Y%m%d %H:%M:%S',
  18. level=logging.DEBUG)
  19. logger = logging.getLogger(__name__)
  20. def get_sample_context(infile):
  21. logger.info('infile: {}'.format(infile))
  22. if infile.endswith('.jl'):
  23. with open(infile) as f:
  24. for line in f:
  25. context = json.loads(line.strip())
  26. yield context
  27. elif infile.endswith('.json'):
  28. with open(infile) as f:
  29. data = json.load(f)
  30. for context in data:
  31. yield context
  32. def check_path(path, context, key=None):
  33. cmd = 'ls -Ld {}'.format(path)
  34. logger.debug('checking path: {}'.format(path))
  35. res = commands.getoutput(cmd).strip()
  36. if 'ls:' in res or not res:
  37. logger.debug('path not exists: {}'.format(path))
  38. # 原始数据本地不存在时检查云上路径
  39. if key == 'raw_path':
  40. logger.debug('check rawdata on oss ...')
  41. logger.info('{projectid} {novoid} {lane}'.format(**context))
  42. return '.'
  43. logger.info('\033[32mfind path: {}\033[0m'.format(res))
  44. return ','.join(res.split('\n'))
  45. def check_context(context, title, num):
  46. logger.debug('\033[36mdealing with line: {}\033[0m'.format(num))
  47. # 先检查项目路径
  48. if check_path(context['projpath'], context) == '.':
  49. return None
  50. data_path = {}
  51. data_path[
  52. 'raw_path'] = '{projpath}/RawData/{samplename}/{samplename}*fq.gz'.format(
  53. **context)
  54. data_path[
  55. 'clean_path'] = '{projpath}/QC/{samplename}/{samplename}*.clean.fq.gz'.format(
  56. **context)
  57. data_path[
  58. 'bam_path'] = '{projpath}/Mapping/{samplename}.{samplename}/{samplename}*final.bam'.format(
  59. **context)
  60. data_path[
  61. 'gvcf_path'] = '{projpath}/Mutation/{samplename}.*/{samplename}*flt.vcf.gz'.format(
  62. **context)
  63. for key in ('raw_path', 'clean_path', 'bam_path', 'gvcf_path'):
  64. path = data_path[key]
  65. data_path[key] = check_path(path, context, key=key)
  66. if set(data_path.values()) != set('.'):
  67. line = '\t'.join('{%s}' % t for t in title)
  68. line = line.format(**dict(context, **data_path))
  69. return line
  70. return None
  71. def write_output(out, data):
  72. if data:
  73. logger.info('write 1 line')
  74. out.write(data + '\n')
  75. def main():
  76. start_time = time.time()
  77. proc = Pool(args['jobs'])
  78. if args['jobs'] > 1:
  79. logger.info('run {jobs} jobs in parallel'.format(**args))
  80. hostname = socket.gethostname()
  81. title = '''
  82. familyid samplename sex diseasename disnorm seqsty
  83. raw_path clean_path bam_path gvcf_path
  84. '''.split()
  85. sample_context = get_sample_context(args['infile'])
  86. with open(args['outfile'], 'w') as out:
  87. out.write('\t'.join(title) + '\n')
  88. for num, context in enumerate(sample_context):
  89. projpath = context.get('projpath')
  90. diseasename = context.get('diseasename', '').strip()
  91. familyid = context.get('familyid')
  92. seqsty = context.get('seqsty')
  93. context['diseasename'] = diseasename
  94. if not all([projpath, diseasename, familyid, seqsty]):
  95. continue
  96. if ('NJPROJ' in projpath) and ('nj' not in hostname):
  97. continue
  98. elif ('NJPROJ' not in projpath) and ('nj' in hostname):
  99. continue
  100. if args['jobs'] > 1:
  101. proc.apply_async(check_context,
  102. args=(context, title, num),
  103. callback=partial(write_output, out))
  104. else:
  105. data = check_context(context, title, num)
  106. write_output(out, data)
  107. if args['jobs'] > 1:
  108. proc.close()
  109. proc.join()
  110. logger.info('time used: {:.1f}s'.format(time.time() - start_time))
  111. if __name__ == '__main__':
  112. import argparse
  113. parser = argparse.ArgumentParser(
  114. prog='sample_stat',
  115. description=__doc__,
  116. formatter_class=argparse.RawTextHelpFormatter)
  117. parser.add_argument('infile',
  118. help='the samplelist.json or samplelist.jl file',
  119. nargs='?')
  120. parser.add_argument('-j',
  121. '--jobs',
  122. help='run n jobs in parallel[%(default)s]',
  123. type=int,
  124. default=4)
  125. parser.add_argument('-o',
  126. '--outfile',
  127. help='the output filename[%(default)s]',
  128. default='out.xls')
  129. args = vars(parser.parse_args())
  130. main()