原文链接:https://java-design-patterns.com/patterns/promise/
Promise 模式:一种异步编程模式,它允许我们可以先开始一个任务的执行,并得到一个用于获取该任务执行结果的凭据对象,而不必等待该任务执行完毕就可以执行其他操作,等到我们需要该任务的执行结果时,再调用凭据对象的相关方法来获取,这样可以避免不必要的等待,增加了系统的并发性。
在Promise模式中,客户端代码调用某个异步方法所得到的返回值仅是一个凭据对象(该对象被称为Promise,意为“承诺”),凭借该对象,客户端代码可以获取异步方法相应的真正任务的执行结果。
Promise 模式
Promise 模式的支持类 PromiseSupport:
public class PromiseSupport<V> implements Future<V> {
// 定义运行状态
protected static final int RUNNING = 1;
protected static final int FAILED = 2;
protected static final int COMPLETED = 3;
// 锁对象
protected final Object lock;
// 当前状态
protected volatile int state = RUNNING;
// 返回值
protected V value;
// 异常返回值
protected Exception exception;
public PromiseSupport() {
this.lock = new Object();
}
/**
* 执行成功,将方法返回值回写
* fulfill:履行
*
* @param value 返回值
*/
protected void fulfill(V value) {
this.value = value;
this.state = COMPLETED;
synchronized (lock) {
// 方法执行完成,唤醒其他阻塞线程
// 比如阻塞在get()方法上的线程
lock.notifyAll();
}
}
/**
* 执行失败,异常回写
*
* @param exception 所执行方法抛出的异常
*/
protected void fulfillExceptionally(Exception exception) {
this.exception = exception;
this.state = FAILED;
synchronized (lock) {
// 任务执行过程抛出异常,执行结束
// 唤醒阻塞在get()方法上的线程
lock.notifyAll();
}
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}
@Override
public boolean isCancelled() {
return false;
}
@Override
public boolean isDone() {
return state > RUNNING;
}
@Override
public V get() throws InterruptedException, ExecutionException {
synchronized (lock) {
// 任务未执行完
while (state == RUNNING) {
// 阻塞调用线程
lock.wait();
}
}
// 任务正常结束,将任务返回值返回
if (state == COMPLETED) {
return value;
}
// 任务异常结束,将异常抛出
throw new ExecutionException(exception);
}
@Override
public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
synchronized (lock) {
// 任务未执行完
while (state == RUNNING) {
try {
// 定时阻塞调用线程
lock.wait(unit.toMillis(timeout));
} catch (InterruptedException e) {
// 打印异常日志
System.out.println("Interrupted:" + e.getMessage());
// wait()被中断后会清楚线程的中断标识
// 重新设置当前线程的中断标志
Thread.currentThread().interrupt();
}
}
}
// 任务正常结束,将任务返回值返回
if (state == COMPLETED) {
return value;
}
// 任务异常结束,将异常抛出
throw new ExecutionException(exception);
}
}
Promise 支持:
public class Promise<V> extends PromiseSupport<V> {
// 要履行的动作
private Runnable fulfillmentAction;
// 动作执行过程中异常的处理
private Consumer<? super Throwable> exceptionHandler;
public Promise() {
}
@Override
protected void fulfill(V value) {
super.fulfill(value);
// 拦截器,进行后续处理
postFulfillment();
}
@Override
protected void fulfillExceptionally(Exception exception) {
super.fulfillExceptionally(exception);
// 针对异常的处理
handlerException();
// 执行后续处理
postFulfillment();
}
/**
* 处理任务执行过程中产生的异常
*/
private void handlerException() {
if (exception == null) {
return;
}
exceptionHandler.accept(exception);
}
/**
* 任务执行完毕后需要执行的动作
*/
private void postFulfillment() {
if (null == fulfillmentAction) {
return;
}
fulfillmentAction.run();
}
/**
* 异步任务执行
*
* @param task 待执行的任务
* @param executor 执行器
* @return
*/
public Promise<V> fulfillInAsync(final Callable<V> task, Executor executor) {
executor.execute(() -> {
try {
// 执行任务并将返回值回写
fulfill(task.call());
} catch (Exception e) {
// 执行任务产生异常,将异常回写
fulfillExceptionally(e);
}
});
return this;
}
/**
* 任务执行完后对返回值进行处理
*
* @param action
* @return
*/
public Promise<Void> thenAccept(Consumer<? super V> action) {
Promise<Void> dest = new Promise<>();
fulfillmentAction = new ConsumerAction(this, dest, action);
return dest;
}
/**
* 任务执行过程中异常的处理
*
* @param exceptionHandler
* @return
*/
public Promise<V> onError(Consumer<? super Throwable> exceptionHandler) {
this.exceptionHandler = exceptionHandler;
return this;
}
/**
* 将上一个Promise的处理结果传递给下一个Promise
*
* @param function
* @param <T>
* @return
*/
public <T> Promise<T> thenApply(Function<? super V, T> function) {
Promise<T> dest = new Promise<>();
fulfillmentAction = new TransferAction<T>(this, dest, function);
return dest;
}
private class ConsumerAction implements Runnable {
private final Promise<V> src;
private final Promise<Void> dest;
private final Consumer<? super V> action;
public ConsumerAction(Promise<V> src, Promise<Void> dest, Consumer<? super V> action) {
this.src = src;
this.dest = dest;
this.action = action;
}
@Override
public void run() {
try {
// 异步获取返回值
action.accept(src.get());
// 将空值回写
dest.fulfill(null);
} catch (Throwable t) {
// 异常
dest.fulfillExceptionally((Exception) t.getCause());
}
}
}
private class TransferAction<T> implements Runnable {
private final Promise<V> src;
private final Promise<T> dest;
private final Function<? super V, T> func;
public TransferAction(Promise<V> src, Promise<T> dest, Function<? super V, T> func) {
this.src = src;
this.dest = dest;
this.func = func;
}
@Override
public void run() {
try {
dest.fulfill(func.apply(src.get()));
} catch (Throwable t) {
dest.fulfillExceptionally((Exception) t.getCause());
}
}
}
}
应用
下载一个文件并计算其行数,计算的行数将被消耗并打印到控制台上:
下载工具类:Utility
public class Utility {
// 下载文件并返回文件路径
public static String downloadFile(String urlString) throws IOException {
System.out.println("正在下载url的文件" + urlString);
URL url = new URL(urlString);
File file = File.createTempFile("promise_pattern", null);
try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(url.openStream()));
FileWriter writer = new FileWriter(file);) {
String line;
while ((line = bufferedReader.readLine()) != null) {
writer.write(line);
writer.write("\n");
}
System.out.println("下载文件存储位置:" + file.getAbsolutePath());
return file.getAbsolutePath();
}
}
// 返回文件的行数
public static Integer countLines(String fileLocation) {
try(BufferedReader reader = new BufferedReader(new FileReader(fileLocation))) {
return Math.toIntExact(reader.lines().count());
} catch (IOException e) {
e.printStackTrace();
}
return 0;
}
// 计算每个字符出现的频率
// frequency:频率
public static Map<Character, Long> characterFrequency(String fileLocation) {
try (final BufferedReader reader = new BufferedReader(new FileReader(fileLocation))){
return reader.lines()
.flatMapToInt(String::chars)
.mapToObj(x->(char)x)
.collect(Collectors.groupingBy(Function.identity(),Collectors.counting()));
} catch (IOException e) {
e.printStackTrace();
}
return Collections.emptyMap();
}
public static Optional<Character> lowestFrequencyChar(Map<Character, Long> characterFrequency) {
return characterFrequency
.entrySet()
.stream()
.min(Comparator.comparingLong(Map.Entry::getValue))
.map(Map.Entry::getKey);
}
}
下载主类:App
public class App {
private static final String DEFAULT_URL =
"https://github.com/ZhSMM/java-design-patterns/blob/master/abstract-document/pom.xml";
private final ExecutorService executor;
private final CountDownLatch stopLatch;
public App() {
this.executor = Executors.newFixedThreadPool(2);
this.stopLatch = new CountDownLatch(2);
}
public static void main(String[] args) throws InterruptedException {
App app = new App();
try {
app.promiseUsage();
} finally {
app.stop();
}
}
private void promiseUsage() {
calculateLineCount();
calculateLowestFrequency();
}
private void calculateLowestFrequency() {
lowestFrequencyChar().thenAccept(character -> {
System.out.println("出现最少的字符:" + character);
taskCompleted();
});
}
private void calculateLineCount() {
countLines().thenAccept(lines -> {
System.out.println("文件行数:" + lines);
taskCompleted();
});
}
private Promise<Optional<Character>> lowestFrequencyChar() {
return characterFrequency().thenApply(Utility::lowestFrequencyChar);
}
private Promise<Map<Character, Long>> characterFrequency() {
return download(DEFAULT_URL).thenApply(Utility::characterFrequency);
}
private Promise<Integer> countLines() {
return download(DEFAULT_URL).thenApply(Utility::countLines);
}
// 异步下载
private Promise<String> download(String urlString) {
return new Promise<String>()
.fulfillInAsync(() -> Utility.downloadFile(urlString), executor)
.onError(throwable -> {
throwable.printStackTrace();
taskCompleted();
});
}
private void taskCompleted() {
// 相当于计数器建一
stopLatch.countDown();
}
// 关闭
private void stop() throws InterruptedException {
stopLatch.await();
executor.shutdownNow();
}
}