Java class name: com.alibaba.alink.operator.batch.clustering.LdaPredictBatchOp
Python class name: LdaPredictBatchOp

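Description

LDA (Latent Dirichlet Allocation) is a topic model: an unsupervised machine learning technique for identifying the latent topics in a large document collection or corpus. It uses the bag-of-words approach, which treats each document as a word-frequency vector and so turns text into numerical data that is easy to model. Bag of words ignores the order of words, which simplifies the problem and also leaves room for improving the model. Each document is represented as a probability distribution over topics, and each topic as a probability distribution over words.
LDA reports the topics of every document in the collection as a probability distribution. Because it is unsupervised, training needs no hand-labeled data, only the document collection and the number of topics k.
The LDA functionality comprises LDA training, LDA prediction (batch and stream), and pipeline support.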
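
The word-frequency vector mentioned above can be illustrated with a few lines of plain Python. This is a minimal sketch, not Alink's implementation: the helper name `bag_of_words` is hypothetical, and sorting the vocabulary alphabetically is an assumption made only to get a stable column order.

```python
from collections import Counter


def bag_of_words(docs):
    """Turn whitespace-tokenized documents into term-count vectors."""
    tokenized = [doc.split() for doc in docs]
    # One column per distinct word, sorted so the column order is stable.
    vocabulary = sorted({word for words in tokenized for word in words})
    vectors = []
    for words in tokenized:
        counts = Counter(words)  # missing words count as 0
        vectors.append([counts[word] for word in vocabulary])
    return vocabulary, vectors


docs = [
    "a b b c c c c c c e e f f f g h k k k",
    "a b b b d e e e h h k",
]
vocabulary, vectors = bag_of_words(docs)
print(vocabulary)   # ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'k']
print(vectors[0])   # [1, 2, 6, 0, 2, 3, 1, 1, 3]
print(vectors[1])   # [1, 3, 0, 1, 3, 0, 0, 2, 1]
```

Because only counts are kept, reordering the words of a document leaves its vector unchanged, which is the sense in which bag of words ignores word order.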

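Parameters

| Name | Display name | Description | Type | Required? | Value range | Default |
| --- | --- | --- | --- | --- | --- | --- |
| predictionCol | Prediction result column | Name of the prediction result column | String | ✓ | | |
| selectedCol | Selected column | Name of the column to compute on | String | ✓ | Column type must be one of [DENSE_VECTOR, SPARSE_VECTOR, STRING, VECTOR] | |
| modelFilePath | Model file path | Path of the model file | String | | | null |
| predictionDetailCol | Prediction detail column | Name of the prediction detail column | String | | | |
| reservedCols | Reserved columns | Columns reserved by the algorithm | String[] | | | null |
| numThreads | Thread count | Number of threads the component uses | Integer | | | 1 |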

Code example

Python code

    from pyalink.alink import *
    import pandas as pd

    useLocalEnv(1)

    # Each row is one document: a space-separated string of words.
    df = pd.DataFrame([
        ["a b b c c c c c c e e f f f g h k k k"],
        ["a b b b d e e e h h k"],
        ["a b b b b c f f f f g g g g g g g g g i j j"],
        ["a a b d d d g g g g g i i j j j k k k k k k k k k"],
        ["a a a b c d d d d d d d d d e e e g g j k k k"],
        ["a a a a b b d d d e e e e f f f f f g h i j j j j"],
        ["a a b d d d g g g g g i i j j k k k k k k k k k"],
        ["a b c d d d d d d d d d e e f g g j k k k"],
        ["a a a a b b b b d d d e e e e f f g h h h"],
        ["a a b b b b b b b b c c e e e g g i i j j j j j j j k k"],
        ["a b c d d d d d d d d d f f g g j j j k k k"],
        ["a a a a b e e e e f f f f f g h h h j"],
    ])

    # The same data feeds a batch source and a stream source.
    inOp = BatchOperator.fromDataframe(df, schemaStr="doc string")
    inOp2 = StreamOperator.fromDataframe(df, schemaStr="doc string")

    # Train a 6-topic LDA model with the online method.
    ldaTrain = LdaTrainBatchOp()\
        .setSelectedCol("doc")\
        .setTopicNum(6)\
        .setMethod("online")\
        .setSubsamplingRate(1.0)\
        .setOptimizeDocConcentration(True)\
        .setNumIter(20)

    # Batch prediction: the result is written to the "pred" column.
    ldaPredict = LdaPredictBatchOp()\
        .setPredictionCol("pred")\
        .setSelectedCol("doc")

    model = ldaTrain.linkFrom(inOp)
    ldaPredict.linkFrom(model, inOp)

    model.lazyPrint(10)
    ldaPredict.print()

    # Stream prediction reuses the model trained in batch mode.
    ldaPredictS = LdaPredictStreamOp(model)\
        .setPredictionCol("pred")\
        .setSelectedCol("doc")\
        .linkFrom(inOp2)

    ldaPredictS.print()
    StreamOperator.execute()

Java code

    import org.apache.flink.types.Row;

    import com.alibaba.alink.operator.batch.BatchOperator;
    import com.alibaba.alink.operator.batch.clustering.LdaPredictBatchOp;
    import com.alibaba.alink.operator.batch.clustering.LdaTrainBatchOp;
    import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
    import com.alibaba.alink.operator.stream.StreamOperator;
    import com.alibaba.alink.operator.stream.clustering.LdaPredictStreamOp;
    import com.alibaba.alink.operator.stream.source.MemSourceStreamOp;
    import org.junit.Test;

    import java.util.Arrays;
    import java.util.List;

    public class LdaPredictBatchOpTest {
        @Test
        public void testLdaPredictBatchOp() throws Exception {
            // Each row is one document: a space-separated string of words.
            List<Row> df = Arrays.asList(
                Row.of("a b b c c c c c c e e f f f g h k k k"),
                Row.of("a b b b d e e e h h k"),
                Row.of("a b b b b c f f f f g g g g g g g g g i j j"),
                Row.of("a a b d d d g g g g g i i j j j k k k k k k k k k"),
                Row.of("a a a b c d d d d d d d d d e e e g g j k k k"),
                Row.of("a a a a b b d d d e e e e f f f f f g h i j j j j"),
                Row.of("a a b d d d g g g g g i i j j k k k k k k k k k"),
                Row.of("a b c d d d d d d d d d e e f g g j k k k"),
                Row.of("a a a a b b b b d d d e e e e f f g h h h"),
                Row.of("a a b b b b b b b b c c e e e g g i i j j j j j j j k k"),
                Row.of("a b c d d d d d d d d d f f g g j j j k k k"),
                Row.of("a a a a b e e e e f f f f f g h h h j")
            );

            // The same data feeds a batch source and a stream source.
            BatchOperator<?> inOp = new MemSourceBatchOp(df, "doc string");
            StreamOperator<?> inOp2 = new MemSourceStreamOp(df, "doc string");

            // Train a 6-topic LDA model with the online method.
            BatchOperator<?> ldaTrain = new LdaTrainBatchOp()
                .setSelectedCol("doc")
                .setTopicNum(6)
                .setMethod("online")
                .setSubsamplingRate(1.0)
                .setOptimizeDocConcentration(true)
                .setNumIter(20);

            // Batch prediction: the result is written to the "pred" column.
            BatchOperator<?> ldaPredict = new LdaPredictBatchOp()
                .setPredictionCol("pred")
                .setSelectedCol("doc");

            BatchOperator<?> model = ldaTrain.linkFrom(inOp);
            ldaPredict.linkFrom(model, inOp);

            model.lazyPrint(10);
            ldaPredict.print();

            // Stream prediction reuses the model trained in batch mode.
            StreamOperator<?> ldaPredictS = new LdaPredictStreamOp(model)
                .setPredictionCol("pred")
                .setSelectedCol("doc")
                .linkFrom(inOp2);

            ldaPredictS.print();
            StreamOperator.execute();
        }
    }

Results

Model output

| model_id | model_info |
| --- | --- |
| 0 | {"logPerplexity":"3.7090449161397796","betaArray":"[0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666]","logLikelihood":"-964.3516781963427","method":"\"online\"","alphaArray":"[0.13821318741806757,0.14883947846014303,0.11751772860080838,0.11649338902896737,0.1503735753641805,0.12383960905322638]","topicNum":"6","vocabularySize":"11"} |
| 1048576 | {"m":6,"n":11,"data":[6125.275647735944,5541.830400832857,5277.404107556518,5575.307666756267,5738.822977932333,5664.141524765102,5183.8663148472615,6286.886714218059,5159.4834022615505,5965.45851687814,5785.616901302167,5558.164928383525,5290.881194601821,5849.766053667748,5595.238710003511,5709.172846472106,5367.427910628795,6967.997740551021,5688.8764262580735,4955.8174077887725,4940.593716098454,5435.785995518678,6359.043301395186,4992.933732368455,5164.467086144761,6624.6072909374125,6911.005911971013,6239.327690548231,5908.580210537792,6090.679944041717,4491.439930702308,5785.921888708801,4648.954813378507,5714.129075228494,6200.167117921488,5223.186458407328,5560.911614536643,5141.113565996373,6043.809469077941,7092.299303765094,6408.739229185271,5851.449695701356,4518.178684615466,5946.483529384942,5633.526524470202,5538.4345859137275,5983.901197676244,5587.210556929512,6050.024468817716,4965.114090486532,4634.277477990217,5692.989466800378,5462.485467579785,4841.301836486494,5117.962076960599,4980.381226902301,5186.706443620538,6608.121037167229,5926.302505211329,6106.240714316094,5474.117007346719,4977.005342253029,5871.2842682743185,4842.798396244806,4810.0086663355705,5468.469136036559]} |
| 2097152 | {"f0":"d","f1":0.36772478012531734,"f2":0} |
| 3145728 | {"f0":"k","f1":0.36772478012531734,"f2":1} |
| 4194304 | {"f0":"f","f1":0.4855078157817008,"f2":7} |
| 5242880 | {"f0":"c","f1":0.6190392084062235,"f2":8} |
| 6291456 | {"f0":"h","f1":0.7731898882334817,"f2":9} |
| 7340032 | {"f0":"i","f1":0.7731898882334817,"f2":10} |
| 8388608 | {"f0":"g","f1":0.08004270767353636,"f2":2} |
| 9437184 | {"f0":"b","f1":0.0,"f2":3} |
| 10485760 | {"f0":"a","f1":0.0,"f2":4} |
| 11534336 | {"f0":"e","f1":0.36772478012531734,"f2":5} |
| 12582912 | {"f0":"j","f1":0.26236426446749106,"f2":6} |
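
The numbers in the model output can be cross-checked against the 12 example documents. The sketch below is plain Python and makes two observations about this particular output; both are inferred from the printed values and are assumptions, not documented Alink behavior: `logPerplexity` equals `-logLikelihood` divided by the total token count, and the `f1` value of each vocabulary row equals a smoothed inverse document frequency, `ln((N + 1) / (df + 1))`. The variable names are illustrative.

```python
import math

docs = [
    "a b b c c c c c c e e f f f g h k k k",
    "a b b b d e e e h h k",
    "a b b b b c f f f f g g g g g g g g g i j j",
    "a a b d d d g g g g g i i j j j k k k k k k k k k",
    "a a a b c d d d d d d d d d e e e g g j k k k",
    "a a a a b b d d d e e e e f f f f f g h i j j j j",
    "a a b d d d g g g g g i i j j k k k k k k k k k",
    "a b c d d d d d d d d d e e f g g j k k k",
    "a a a a b b b b d d d e e e e f f g h h h",
    "a a b b b b b b b b c c e e e g g i i j j j j j j j k k",
    "a b c d d d d d d d d d f f g g j j j k k k",
    "a a a a b e e e e f f f f f g h h h j",
]

# Values copied from the model output above.
log_likelihood = -964.3516781963427
log_perplexity = 3.7090449161397796
model_f1 = {
    "d": 0.36772478012531734, "k": 0.36772478012531734,
    "f": 0.4855078157817008, "c": 0.6190392084062235,
    "h": 0.7731898882334817, "i": 0.7731898882334817,
    "g": 0.08004270767353636, "b": 0.0, "a": 0.0,
    "e": 0.36772478012531734, "j": 0.26236426446749106,
}

tokenized = [doc.split() for doc in docs]
token_count = sum(len(words) for words in tokenized)

# Observation 1 (assumed relationship): per-token negative log-likelihood.
implied_log_perplexity = -log_likelihood / token_count


def smoothed_idf(word):
    """Observation 2 (assumed formula): ln((N + 1) / (df + 1))."""
    doc_freq = sum(1 for words in tokenized if word in words)
    return math.log((len(docs) + 1) / (doc_freq + 1))


perplexity_matches = math.isclose(
    implied_log_perplexity, log_perplexity, rel_tol=1e-9)
idf_matches = {
    word: math.isclose(smoothed_idf(word), f1, rel_tol=1e-9, abs_tol=1e-12)
    for word, f1 in model_f1.items()
}

print(token_count)                 # 260
print(perplexity_matches)          # True
print(all(idf_matches.values()))   # True
```

Words that occur in every document (`a` and `b`) get a weight of 0 under this formula, which agrees with their `f1` values in the table.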

Prediction output

| id | libsvm | pred |
| --- | --- | --- |
| 0 | a b b c c c c c c e e f f f g h k k k | 0 |
| 1 | a b b b d e e e h h k | 4 |
| 2 | a b b b b c f f f f g g g g g g g g g i j j | 5 |
| 3 | a a b d d d g g g g g i i j j j k k k k k k k k k | 1 |
| 4 | a a a b c d d d d d d d d d e e e g g j k k k | 1 |
| 5 | a a a a b b d d d e e e e f f f f f g h i j j j j | 2 |
| 6 | a a b d d d g g g g g i i j j k k k k k k k k k | 1 |
| 7 | a b c d d d d d d d d d e e f g g j k k k | 0 |
| 8 | a a a a b b b b d d d e e e e f f g h h h | 4 |
| 9 | a a b b b b b b b b c c e e e g g i i j j j j j j j k k | 4 |
| 10 | a b c d d d d d d d d d f f g g j j j k k k | 0 |
| 11 | a a a a b e e e e f f f f f g h h h j | 1 |
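
One way to read the prediction output is to group the documents by predicted topic. The sketch below is plain Python over the `(id, pred)` pairs copied from the table; it assumes that `pred` is the index of the topic assigned to each document, with indices running from 0 to `topicNum - 1`. The names `predictions` and `docs_by_topic` are illustrative.

```python
# (id, pred) pairs copied from the prediction output above.
predictions = [
    (0, 0), (1, 4), (2, 5), (3, 1), (4, 1), (5, 2),
    (6, 1), (7, 0), (8, 4), (9, 4), (10, 0), (11, 1),
]
topic_num = 6  # matches setTopicNum(6) in the training step

# Start with an empty list per topic so unused topics stay visible.
docs_by_topic = {topic: [] for topic in range(topic_num)}
for doc_id, topic in predictions:
    docs_by_topic[topic].append(doc_id)

for topic, doc_ids in docs_by_topic.items():
    print(topic, doc_ids)
# 0 [0, 7, 10]
# 1 [3, 4, 6, 11]
# 2 [5]
# 3 []
# 4 [1, 8, 9]
# 5 [2]
```

In this run no document is assigned to topic 3, and documents 3, 4, 6 and 11 share topic 1.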