多线程批量执行任务
package com.gitee.kooder.utils;
import java.util.List;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.function.Consumer;
/**
 * Splits a list of tasks into chunks and processes the chunks in parallel on the
 * common {@link java.util.concurrent.ForkJoinPool}.
 *
 * <p>The list is halved recursively until each piece holds at most {@code threshold}
 * elements; every such piece is handed to the supplied action. The action may therefore
 * be invoked concurrently from several threads and must be thread-safe.
 *
 * @param <T> element type of the task list
 * @author Winter Lau<javayou@gmail.com>
 */
public final class BatchTaskRunner<T> extends RecursiveAction {

    private final List<T> taskList;
    private final int threshold;
    private final Consumer<List<T>> action;

    /**
     * @param taskList  tasks handled by this (sub)task
     * @param threshold maximum number of tasks processed by a single action call
     * @param action    callback invoked with each chunk
     */
    private BatchTaskRunner(List<T> taskList, int threshold, Consumer<List<T>> action) {
        this.taskList = taskList;
        // A threshold below 1 can never be reached by a one-element list: splitting it yields
        // the same one-element list again, so compute() would recurse until StackOverflowError.
        this.threshold = Math.max(threshold, 1);
        this.action = action;
    }

    /**
     * Processes {@code taskList} in parallel, in chunks of at most {@code threshold} elements,
     * and blocks until every chunk has been processed. An empty list results in a single
     * call of {@code action} with that empty list.
     *
     * @param taskList  the tasks; chunks are {@link List#subList} views of it, so it must not be
     *                  structurally modified while this method runs
     * @param threshold maximum chunk size; values below 1 are treated as 1
     * @param action    callback invoked (possibly concurrently) with each chunk; an unchecked
     *                  exception thrown by it is rethrown from this method
     * @param <T>       element type
     */
    public static <T> void execute(List<T> taskList, int threshold, Consumer<List<T>> action) {
        new BatchTaskRunner<>(taskList, threshold, action).invoke();
    }

    @Override
    protected void compute() {
        if (taskList.size() <= threshold) {
            action.accept(taskList);
        } else {
            splitFromMiddle(taskList);
        }
    }

    /**
     * Splits the list into two halves (the left half receives the extra element of an odd
     * size) and processes both halves in parallel.
     *
     * @param list the list to split; always larger than {@code threshold} here
     */
    private void splitFromMiddle(List<T> list) {
        int size = list.size();
        int middle = size - size / 2; // ceil(size / 2) without floating point
        BatchTaskRunner<T> left = newInstance(list.subList(0, middle));
        BatchTaskRunner<T> right = newInstance(list.subList(middle, size));
        ForkJoinTask.invokeAll(left, right);
    }

    /** Creates a sub-task sharing this task's threshold and action. */
    private BatchTaskRunner<T> newInstance(List<T> subList) {
        return new BatchTaskRunner<>(subList, threshold, action);
    }
}
/**
 * Imports every {@code .json} file located directly in {@code path} (no recursion into
 * sub-directories) and writes it to the index of the given type.
 *
 * <p>The files are processed in parallel via {@link BatchTaskRunner}; all workers share one
 * {@code IndexWriter} and one {@code TaxonomyWriter}, both closed when this method returns.
 *
 * @param type         index type whose writers are used
 * @param action       action passed through to {@code importJsonFile} for every file
 * @param path         directory to scan
 * @param thread_count requested parallelism, clamped to {@code [1, MAX_THREAD_COUNT]}
 * @return number of files for which {@code importJsonFile} completed
 * @throws IOException if the directory cannot be listed or a writer cannot be opened/closed
 */
private static int importJsonInPath(String type, String action, Path path, int thread_count) throws IOException {
    final AtomicInteger imported = new AtomicInteger(0);
    final int threads = Math.min(MAX_THREAD_COUNT, Math.max(thread_count, 1));
    try (
        IndexWriter writer = StorageFactory.getIndexWriter(type);
        TaxonomyWriter taxonomyWriter = StorageFactory.getTaxonomyWriter(type);
        Stream<Path> entries = Files.list(path)
    ) {
        List<Path> jsonFiles = entries
                .filter(p -> p.toString().endsWith(".json"))
                .filter(p -> !Files.isDirectory(p))
                .collect(Collectors.toList());
        // Size the chunks so that roughly `threads` batches run in parallel.
        int filesPerBatch = Math.max(jsonFiles.size() / threads, 1);
        BatchTaskRunner.execute(jsonFiles, filesPerBatch, batch -> {
            for (Path jsonFile : batch) {
                importJsonFile(type, action, jsonFile, writer, taxonomyWriter);
                imported.incrementAndGet();
            }
        });
    }
    return imported.get();
}
redis队列
package com.gitee.kooder.queue;
import java.util.Collection;
import java.util.List;
/**
 * A queue of index tasks, as handed out by a {@link QueueProvider}.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public interface Queue extends AutoCloseable {

    /**
     * Returns the unique name of this queue.
     *
     * @return the queue name
     */
    String type();

    /**
     * Appends tasks to this queue.
     *
     * @param tasks the tasks to enqueue
     */
    void push(Collection<QueueTask> tasks);

    /**
     * Removes tasks from this queue and returns them.
     *
     * @param count the number of tasks requested
     * @return the tasks taken from the queue
     */
    List<QueueTask> pop(int count);
}
package com.gitee.kooder.queue;
import com.gitee.kooder.core.KooderConfig;
import org.apache.commons.lang3.StringUtils;
import java.util.Properties;
/**
 * Creates the application-wide {@link QueueProvider} from the queue configuration.
 *
 * <p>The {@code provider} property selects the implementation, ignoring case and surrounding
 * whitespace: {@code redis} creates a {@link RedisQueueProvider}, {@code embed} an
 * {@link EmbedQueueProvider}. Any other value, or a missing one, leaves the provider
 * {@code null}.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public class QueueFactory {

    static QueueProvider provider = createProvider(KooderConfig.getQueueProperties());

    /** Instantiates the provider named by {@code props}, or returns {@code null} if none matches. */
    private static QueueProvider createProvider(Properties props) {
        String name = StringUtils.trim(props.getProperty("provider"));
        if ("redis".equalsIgnoreCase(name)) {
            return new RedisQueueProvider(props);
        }
        return "embed".equalsIgnoreCase(name) ? new EmbedQueueProvider(props) : null;
    }

    /**
     * Returns the provider created when this class was initialised.
     *
     * @return the configured provider, or {@code null} if the {@code provider} setting was not
     *         recognised
     */
    public static final QueueProvider getProvider() {
        return provider;
    }
}
package com.gitee.kooder.queue;
import com.gitee.kooder.core.Constants;
import java.util.Arrays;
import java.util.List;
/**
 * A queue backend that hands out one {@link Queue} per task type.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public interface QueueProvider extends AutoCloseable {

    /** Task types served by default: repository, issue and code. */
    List<String> TYPES = Arrays.asList(Constants.TYPE_REPOSITORY, Constants.TYPE_ISSUE, Constants.TYPE_CODE);

    /**
     * Returns the unique identifier of this provider.
     *
     * @return the provider name
     */
    String name();

    /**
     * Lists every task type this provider supports.
     *
     * @return {@link #TYPES} unless overridden
     */
    default List<String> getAllTypes() {
        return TYPES;
    }

    /**
     * Returns the queue that holds tasks of the given type.
     *
     * @param type the task type
     * @return the queue for {@code type}
     */
    Queue queue(String type);
}
package com.gitee.kooder.queue;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.sync.RedisCommands;
import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Properties;
/**
 * {@link QueueProvider} backed by Redis lists.
 *
 * <p>Each task type has its own list, stored under the key {@code <type>@<redis.key>}. Tasks
 * are appended with RPUSH and taken from the head with LPOP, so every queue is FIFO. A
 * connection is opened for each push/pop call and closed when the call returns.
 *
 * <p>Properties read: {@code redis.host} (default {@code 127.0.0.1}), {@code redis.port}
 * (default {@code 6379}), {@code redis.database} (default {@code 1}) and {@code redis.key}
 * (default {@code gsearch-queue}). The credentials are read from {@code redis.username} and
 * {@code redis.password}, falling back to the unprefixed {@code username} and
 * {@code password} keys.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public class RedisQueueProvider implements QueueProvider {

    private final static Logger log = LoggerFactory.getLogger(RedisQueueProvider.class);

    /** Prefix of the line carrying the server version in the {@code INFO server} reply. */
    private final static String VERSION_FIELD = "redis_version:";

    private final String host;
    private final int port;
    private final int database;
    private final String baseKey;
    private final String username;
    private final String password;
    private final RedisClient client;

    /**
     * Creates the Redis client and checks connectivity by reading the server version.
     *
     * @param props queue configuration (see the class documentation for the keys)
     * @throws RuntimeException whatever the Redis client throws when the server cannot be
     *                          reached; the client is shut down before the exception propagates
     */
    public RedisQueueProvider(Properties props) {
        this.host = props.getProperty("redis.host", "127.0.0.1");
        this.port = NumberUtils.toInt(props.getProperty("redis.port"), 6379);
        this.database = NumberUtils.toInt(props.getProperty("redis.database"), 1);
        this.baseKey = props.getProperty("redis.key", "gsearch-queue");
        // All other settings use the "redis." prefix; accept it for the credentials as well
        // while still honouring the original unprefixed keys.
        this.username = props.getProperty("redis.username", props.getProperty("username"));
        this.password = props.getProperty("redis.password", props.getProperty("password"));

        RedisURI uri = RedisURI.create(host, port);
        uri.setDatabase(this.database);
        if (password != null)
            uri.setPassword(password.toCharArray());
        if (username != null)
            uri.setUsername(username);
        this.client = RedisClient.create(uri);

        String version;
        try {
            version = getRedisVersion();
        } catch (RuntimeException e) {
            // Release the client's resources instead of leaking them when the check fails.
            client.shutdown();
            throw e;
        }
        log.info("Connected to Redis {} at {}:{}", version, this.host, this.port);
    }

    /**
     * Reads the server version from the {@code server} section of the INFO command.
     * INFO takes a section name; {@code redis_version} is a field inside that section.
     *
     * @return the version string, or {@code "unknown"} if the reply has no version line
     */
    private String getRedisVersion() {
        try (StatefulRedisConnection<String, String> connection = client.connect()) {
            String info = connection.sync().info("server");
            if (info != null) {
                for (String line : info.split("\\r?\\n")) {
                    if (line.startsWith(VERSION_FIELD))
                        return line.substring(VERSION_FIELD.length()).trim();
                }
            }
            return "unknown";
        }
    }

    @Override
    public String name() {
        return "redis";
    }

    /**
     * Returns a lightweight view on the Redis list for {@code type}. Closing the view is a
     * no-op; the shared client is released by {@link #close()}.
     */
    @Override
    public Queue queue(String type) {
        return new Queue() {

            private final String key = type + '@' + baseKey;

            @Override
            public String type() {
                return type;
            }

            @Override
            public void push(Collection<QueueTask> tasks) {
                // RPUSH without any value is rejected by Redis, so skip empty batches.
                if (tasks.isEmpty())
                    return;
                String[] values = tasks.stream().map(QueueTask::json).toArray(String[]::new);
                try (StatefulRedisConnection<String, String> connection = client.connect()) {
                    connection.sync().rpush(key, values);
                }
            }

            @Override
            public List<QueueTask> pop(int count) {
                List<QueueTask> tasks = new ArrayList<>();
                // Nothing requested: don't LPOP an element the caller did not ask for.
                if (count <= 0)
                    return tasks;
                try (StatefulRedisConnection<String, String> connection = client.connect()) {
                    RedisCommands<String, String> cmd = connection.sync();
                    while (tasks.size() < count) {
                        String json = cmd.lpop(key);
                        if (json == null)
                            break; // list drained
                        QueueTask task = QueueTask.parse(json);
                        if (task != null)
                            tasks.add(task);
                        else
                            log.warn("Dropped unparseable task from '{}': {}", key, json);
                    }
                }
                return tasks;
            }

            @Override
            public void close() {}
        };
    }

    /** Shuts down the Redis client shared by all queues of this provider. */
    @Override
    public void close() {
        client.shutdown();
    }
}
package com.gitee.kooder.queue;
import com.gitee.kooder.core.KooderConfig;
import org.apache.commons.lang3.math.NumberUtils;
import org.infobip.lib.popout.FileQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Embedded {@link QueueProvider} that keeps one disk-backed {@link FileQueue} per task type,
 * so no third-party service is needed.
 *
 * <p>Properties read: {@code embed.path}, the root folder under which one sub-folder per task
 * type is created, and {@code embed.batch_size} (default {@code 10000}), which is passed to
 * the file queues.
 *
 * <p>The provider owns the file queues. {@link #close()} closes all of them, while closing a
 * {@link Queue} returned by {@link #queue(String)} is a no-op. This matches the Redis
 * provider and lets callers use the views in try-with-resources safely.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public class EmbedQueueProvider implements QueueProvider {

    private final static Logger log = LoggerFactory.getLogger(EmbedQueueProvider.class);

    private final Map<String, FileQueue<QueueTask>> fileQueues = new ConcurrentHashMap<>();

    /**
     * Creates (or restores from disk) one file queue per type returned by
     * {@link #getAllTypes()}.
     *
     * @param props queue configuration (see the class documentation for the keys)
     */
    public EmbedQueueProvider(Properties props) {
        int batch_size = NumberUtils.toInt(props.getProperty("embed.batch_size", "10000"), 10000);
        Path path = checkoutPath(KooderConfig.getPath(props.getProperty("embed.path")));
        for (String type : getAllTypes()) {
            Path typePath = checkoutPath(path.resolve(type));
            fileQueues.put(type, FileQueue.<QueueTask>batched().name(type)
                    .folder(typePath)
                    .restoreFromDisk(true)
                    .batchSize(batch_size)
                    .build());
        }
    }

    /**
     * Makes sure {@code path} is a directory, creating it (and any parents) if needed.
     * A creation failure is logged, not thrown.
     *
     * @param path the directory to check
     * @return {@code path}, unchanged
     */
    private static Path checkoutPath(Path path) {
        if (!Files.exists(path) || !Files.isDirectory(path)) {
            log.warn("Path '{}' for queue storage not exists, created it!", path);
            try {
                Files.createDirectories(path);
            } catch (IOException e) {
                log.error("Failed to create directory '{}'", path, e);
            }
        }
        return path;
    }

    /**
     * Returns the file queue of {@code type}.
     *
     * @throws IllegalArgumentException if no file queue was created for {@code type}
     */
    private FileQueue<QueueTask> fileQueue(String type) {
        FileQueue<QueueTask> queue = fileQueues.get(type);
        if (queue == null)
            throw new IllegalArgumentException("Unsupported queue type: " + type);
        return queue;
    }

    /**
     * Returns the unique name of this provider.
     *
     * @return {@code "embed"}
     */
    @Override
    public String name() {
        return "embed";
    }

    /**
     * Returns a view on the file queue of {@code type}. The view is resolved lazily: an
     * unsupported type only fails when tasks are actually pushed or popped.
     *
     * @param type task type
     * @return the queue view
     */
    @Override
    public Queue queue(String type) {
        return new Queue() {
            @Override
            public String type() {
                return type;
            }

            @Override
            public void push(Collection<QueueTask> tasks) {
                fileQueue(type).addAll(tasks);
            }

            @Override
            public List<QueueTask> pop(int count) {
                List<QueueTask> tasks = new ArrayList<>();
                if (count <= 0)
                    return tasks;
                FileQueue<QueueTask> source = fileQueue(type);
                QueueTask task;
                while (tasks.size() < count && (task = source.poll()) != null)
                    tasks.add(task);
                return tasks;
            }

            @Override
            public void close() {
                // The file queue is shared by every view and owned by the provider; closing it
                // here would break all later queue(type) calls. EmbedQueueProvider.close()
                // releases it.
            }
        };
    }

    /** Closes every file queue owned by this provider. */
    @Override
    public void close() {
        fileQueues.values().forEach(q -> q.close());
    }
}
package com.gitee.kooder.queue;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.gitee.kooder.models.CodeRepository;
import com.gitee.kooder.core.Constants;
import com.gitee.kooder.index.IndexManager;
import com.gitee.kooder.models.Issue;
import com.gitee.kooder.models.Repository;
import com.gitee.kooder.models.Searchable;
import com.gitee.kooder.utils.JsonUtils;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.lucene.facet.taxonomy.TaxonomyWriter;
import org.apache.lucene.index.IndexWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A unit of indexing work: an action ({@link #ACTION_ADD}, {@link #ACTION_UPDATE} or
 * {@link #ACTION_DELETE}) applied to a list of {@link Searchable} objects of one type.
 *
 * <p>Tasks travel through a {@link Queue} as JSON: {@link #json()} serializes a task and
 * {@link #parse(String)} restores it.
 *
 * @author Winter Lau<javayou@gmail.com>
 */
public class QueueTask implements Serializable {

    private final static Logger log = LoggerFactory.getLogger(QueueTask.class);

    /** Every task type accepted by {@link #isAvailType(String)}. */
    public final static List<String> types = Arrays.asList(
            Constants.TYPE_CODE,
            Constants.TYPE_REPOSITORY,
            Constants.TYPE_ISSUE,
            Constants.TYPE_PR,
            Constants.TYPE_COMMIT,
            Constants.TYPE_WIKI,
            Constants.TYPE_USER
    );

    public final static String ACTION_ADD = "add";       // add
    public final static String ACTION_UPDATE = "update"; // update
    public final static String ACTION_DELETE = "delete"; // delete

    private String type;    // object type (one of the Constants.TYPE_* values)
    private String action;  // action: add, update or delete
    private List<Searchable> objects = new ArrayList<>(); // objects the action applies to

    public QueueTask() {}

    /**
     * Wraps {@code obj} in a single task and pushes it to the queue of {@code type} on the
     * provider returned by {@link QueueFactory#getProvider()}.
     *
     * @param type   task type
     * @param action task action
     * @param obj    objects the action applies to
     */
    public static void push(String type, String action, Searchable... obj) {
        QueueTask task = new QueueTask();
        task.type = type;
        task.action = action;
        task.objects.addAll(Arrays.asList(obj));
        QueueFactory.getProvider().queue(type).push(Arrays.asList(task));
    }

    /** Pushes an {@link #ACTION_ADD} task; see {@link #push(String, String, Searchable...)}. */
    public static void add(String type, Searchable... obj) {
        push(type, ACTION_ADD, obj);
    }

    /** Pushes an {@link #ACTION_UPDATE} task; see {@link #push(String, String, Searchable...)}. */
    public static void update(String type, Searchable... obj) {
        push(type, ACTION_UPDATE, obj);
    }

    /** Pushes an {@link #ACTION_DELETE} task; see {@link #push(String, String, Searchable...)}. */
    public static void delete(String type, Searchable... obj) {
        push(type, ACTION_DELETE, obj);
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    /**
     * Checks whether {@code p_type}, once lower-cased, is one of {@link #types}.
     *
     * @param p_type candidate type, may be {@code null}
     * @return {@code true} if the type is known
     */
    public final static boolean isAvailType(String p_type) {
        // Locale.ROOT keeps e.g. "ISSUE" from lower-casing to "ıssue" under a Turkish default locale.
        return (p_type != null) && types.contains(p_type.toLowerCase(java.util.Locale.ROOT));
    }

    /**
     * Checks whether {@code p_action} is add, update or delete, ignoring case.
     *
     * @param p_action candidate action, may be {@code null}
     * @return {@code true} if the action is known
     */
    public final static boolean isAvailAction(String p_action) {
        return ACTION_ADD.equalsIgnoreCase(p_action) || ACTION_DELETE.equalsIgnoreCase(p_action) || ACTION_UPDATE.equalsIgnoreCase(p_action);
    }

    public boolean isCodeTask() {
        return Constants.TYPE_CODE.equals(type);
    }

    public String getAction() {
        return action;
    }

    public void setAction(String action) {
        this.action = action;
    }

    public List<Searchable> getObjects() {
        return objects;
    }

    /**
     * JSON setter for the {@code objects} property. It converts every JSON object into the
     * model class matching this task's {@code type} and fills it with
     * {@code BeanUtils.populate}. The {@code type} must already be set, which presumably
     * means "type" has to precede "objects" in the JSON.
     *
     * @param values the raw JSON objects; {@code null} is treated as "no objects"
     * @throws IllegalStateException    if objects are present but {@code type} is not set
     * @throws IllegalArgumentException if {@code type} has no model class
     * @throws Exception                anything thrown by {@code BeanUtils.populate}
     */
    @JsonProperty("objects")
    public void readObjects(Map<String, Object>[] values) throws Exception {
        if (values == null)
            return;
        for (Map<String, Object> value : values) {
            Searchable obj = newSearchable(type);
            BeanUtils.populate(obj, value);
            objects.add(obj);
        }
    }

    /**
     * Creates an empty model object for {@code type}. Unknown types are rejected instead of
     * silently producing {@code null} entries in {@link #objects}.
     */
    private static Searchable newSearchable(String type) {
        if (type == null)
            throw new IllegalStateException("Task type must be set before its objects");
        switch (type) {
            case Constants.TYPE_CODE:
                return new CodeRepository();
            case Constants.TYPE_REPOSITORY:
                return new Repository();
            case Constants.TYPE_ISSUE:
                return new Issue();
            default:
                throw new IllegalArgumentException("Illegal task type: " + type);
        }
    }

    public void addObject(Searchable obj) {
        objects.add(obj);
    }

    /**
     * Replaces {@link #objects} with the list decoded from {@code json}, using the model class
     * matching this task's {@code type}.
     *
     * @param json JSON array of objects
     * @throws IllegalStateException    if {@code type} is not set
     * @throws IllegalArgumentException if {@code type} has no model class
     */
    @JsonIgnore
    @SuppressWarnings({"unchecked", "rawtypes"}) // the element type of the TypeReference depends on `type`
    public void setJsonObjects(String json) {
        if (type == null)
            throw new IllegalStateException("Task type must be set before its objects");
        TypeReference typeRefer;
        switch (type) {
            case Constants.TYPE_CODE:
                typeRefer = new TypeReference<List<CodeRepository>>() {};
                break;
            case Constants.TYPE_REPOSITORY:
                typeRefer = new TypeReference<List<Repository>>() {};
                break;
            case Constants.TYPE_ISSUE:
                typeRefer = new TypeReference<List<Issue>>() {};
                break;
            default:
                throw new IllegalArgumentException("Illegal task type: " + type);
        }
        this.objects = (List<Searchable>) JsonUtils.readValue(json, typeRefer);
    }

    /**
     * Writes this task to the index through {@code IndexManager.write(this)}.
     *
     * @return the value returned by {@code IndexManager.write}
     * @throws IOException if writing fails
     */
    public int write() throws IOException {
        return IndexManager.write(this);
    }

    /**
     * Writes this task with writers supplied by the caller, so that several threads can share
     * one {@link IndexWriter} / {@link TaxonomyWriter}.
     *
     * @param i_writer index writer
     * @param t_writer taxonomy writer
     * @return the value returned by {@code IndexManager.write}
     * @throws IOException if writing fails
     */
    public int write(IndexWriter i_writer, TaxonomyWriter t_writer) throws IOException {
        return IndexManager.write(this, i_writer, t_writer);
    }

    /**
     * Serializes this task to JSON.
     *
     * @return the JSON text
     */
    public String json() {
        return JsonUtils.toJson(this);
    }

    /**
     * Deserializes a task from JSON via {@code JsonUtils.readValue}.
     *
     * @param json the JSON text
     * @return the task as returned by {@code JsonUtils.readValue}
     */
    public static QueueTask parse(String json) {
        return JsonUtils.readValue(json, QueueTask.class);
    }

    @Override
    public String toString() {
        return "QueueTask{" +
                "type='" + type + '\'' +
                ", action='" + action + '\'' +
                ", objects=" + objects +
                '}';
    }

    /** Manual check: parses a sample code task and prints it. */
    public static void main(String[] args) {
        String json = "{\"type\":\"code\",\"action\":\"add\",\"objects\":[{\"id\":379,\"doc_id\":0,\"doc_score\":0.0,\"enterprise\":10,\"scm\":\"git\",\"vender\":\"gitea\",\"name\":\"xxxxx\",\"url\":\"http://git.xxxxxx.com:3000/xxxx/xxxxx\",\"timestamp\":0,\"document\":{\"fields\":[{\"char_sequence_value\":\"379\"},{\"char_sequence_value\":\"gitea\"},{\"char_sequence_value\":\"10\"},{\"char_sequence_value\":\"http://git.xxxxx.com:3000/xxxx/xxxxx\"},{\"char_sequence_value\":\"xxxxx\"},{\"char_sequence_value\":\"git\"},{\"char_sequence_value\":\"1620462113883\"}]},\"relative_path\":\"000/000/000/xxxxx_379\",\"id_as_string\":\"379\"}],\"code_task\":true}";
        QueueTask task = parse(json);
        System.out.println(task);
    }
}
/**
 * Bundles the given objects into one task and enqueues it on the provider's queue for
 * {@code type}.
 */
public static void push(String type, String action, Searchable... obj) {
    QueueTask task = new QueueTask();
    task.type = type;
    task.action = action;
    for (Searchable item : obj) {
        task.objects.add(item);
    }
    List<QueueTask> batch = Arrays.asList(task);
    QueueFactory.getProvider().queue(type).push(batch);
}
// Pop the next batch of pending tasks of this type; batch_fetch_count is the requested batch size.
List<QueueTask> tasks = provider.queue(type).pop(batch_fetch_count);