一,实现思路

1,配置阶段

配置web.xml DispatcherServlet
设定init-param contextConfigLocation=classpath:application.xml
设定url-pattern /*
配置Annotation @Controller
@Service
@Autowrited
@RequestMapping

2,初始化阶段

调用init方法 加载配置文件
IOC容器初始化 MAP
扫描相关的类 scan-package=””
创建实例化并保存至容器 通过反射机制将类实例化放入IOC容器
进行DI操作 扫描IOC容器的实例,给没有赋值的属性自动填充
初始化HandlerMapping 讲一个URL和一个Method进行一对一的映射

3,运行阶段

调用doGet/doPost web容器调用doget、dopost,获取req和resp对象
匹配HandlerMapping 从req对象获取输入的URL,找到其对应的method
反射调用method.invoker() 利用反射调用方法并返回结果
response.getWrite().write() 将返回结果输出到浏览器

二,自定义配置

1,配置 application.properties 文件

为了解析方便,用 application.properties 来代替 application.xml 文件,具体配置内容如下:

  1. scanPackage=com.yhd.spring01

2,配置web.xml文件

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xmlns="http://java.sun.com/xml/ns/j2ee" xmlns:javaee="http://java.sun.com/xml/ns/javaee"
         xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd"
         xsi:schemaLocation="http://java.sun.com/xml/ns/j2ee
http://java.sun.com/xml/ns/j2ee/web-app_2_4.xsd"
         version="2.4">
    <display-name>YHD Web Application</display-name>
    <servlet>
        <servlet-name>yhdmvc</servlet-name>
        <servlet-class>com.yhd.spring01.servlet.HdDispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>
    <servlet-mapping>
        <servlet-name>yhdmvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

3,自定义注解

@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface HdAutowired {
    String value() default "";
}
import java.lang.annotation.*;
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface HdController {
    String value() default "";
}
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface HdRequestMapping {
    String value() default "";
}
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface HdRequestParam {
    String value() default "";
}
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface HdService {
    String value() default "";
}

4,编写模拟业务

@HdService
public class DemoService implements IDemoService {
    @Override
    public String get(String name) {
        return "My name is " + name;
    }
}
@HdController
@HdRequestMapping("/demo")
public class DemoController {

    @HdAutowired
    private IDemoService demoService;

    @HdRequestMapping("/query")
    public void query(HttpServletRequest req, HttpServletResponse resp,
                      @HdRequestParam("name") String name){
        String result = demoService.get(name);
        try {
            resp.getWriter().write(result);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    @HdRequestMapping("/add")
    public void add(HttpServletRequest req, HttpServletResponse resp,
                    @HdRequestParam("a") Integer a, @HdRequestParam("b") Integer b){
        try {
            resp.getWriter().write(a + "+" + b + "=" + (a + b));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    @HdRequestMapping("/remove")
    public void remove(HttpServletRequest req,HttpServletResponse resp,
                       @HdRequestParam("id") Integer id){
    }
}

三,容器初始化

1.0版本

流程分析

1.首先在doGet方法里面调用doDispatcher方法,根据请求路径判断路径是否存在,如果不存在就返回404存在就从容器中拿到路径对应的方法,通过动态代理执行对应的方法

2.在类加载阶段,用流来加载配置文件,从配置文件读取配置的包扫描路径根据包扫描路径进行迭代遍历,利用反射创建所有类上标有controller注解的类加入到容器,并下钻到类中,将类中每一个方法的绝对访问路径和方法加入到容器,迭代遍历创建所有标有service注解的类,如果该类实现了接口,将该接口的全限定类型名和类实例对象也放入容器,达到根据接口注入的效果。

3.属性赋值,遍历容器中所有类,如果类中标有@autowried注解,将属性对应的值设置进去。

重要方法

1.clazz.isAnnotationPresent(HdController.class)
判断clazz上有没有HdController注解

2.field.set(mappings.get(clazz.getName()), mappings.get(beanName));
属性赋值:args1:给哪个属性设值,args2:设置的什么值

3.method.invoke(mappings.get(method.getDeclaringClass().getName()), new Object[]{req, resp, params.get(“name”)[0]});
通过动态代理执行方法,方法所在类名,方法参数

代码

/**
 * @author yhd
 * @createtime 2021/1/31 15:49
 * @description 模拟IOC容器的创建
 */
public class HdDispatcherServlet extends HttpServlet {
    //映射关系  访问路径-方法名   全限定类名-实例对象
    private Map<String, Object> mappings = new HashMap<>();


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    @SneakyThrows
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatch(req, resp);
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        //组装路径
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        //判断路径是否存在
        if (!this.mappings.containsKey(url)) {
            resp.getWriter().write("404 NotFound!");
            return;
        }
        //获取路径对应的方法参数,通过动态代理进行增强
        Method method = (Method) this.mappings.get(url);
        Map<String, String[]> params = req.getParameterMap();
        method.invoke(mappings.get(method.getDeclaringClass().getName()), new Object[]{req, resp, params.get("name")[0]});
    }

    @Override
    public void init(ServletConfig config) throws ServletException {
        InputStream is = null;
        try {
            //加载配置文件
            Properties configContext = new Properties();
            is = this.getClass().getClassLoader().getResourceAsStream(config.getInitParameter("contextConfigLocation"));
            configContext.load(is);
            //获取扫描路径
            String scanPackage = configContext.getProperty("scanPackage");
            doScanner(scanPackage);
            for (String className : mappings.keySet()) {
                if (!className.contains(".")) {
                    continue;
                }
                Class<?> clazz = Class.forName(className);
                //当前这个类上有没有controller注解
                if (clazz.isAnnotationPresent(HdController.class)) {
                    mappings.put(className, clazz.newInstance());
                    String baseUrl = "";
                    //判断有没有一级访问路径
                    if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                        HdRequestMapping requestMapping = clazz.getAnnotation(HdRequestMapping.class);
                        baseUrl = requestMapping.value();
                    }
                    Method[] methods = clazz.getMethods();
                    for (Method method : methods) {
                        if (!method.isAnnotationPresent(HdRequestMapping.class)) {
                            continue;
                        }
                        HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                        //拼装路径
                        String url = (baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                        //map放的是:controller里面一个方法的访问绝对路径,这个对应的方法
                        mappings.put(url, method);
                        System.out.println("Mapped " + url + "," + method);
                    }
                } else if (clazz.isAnnotationPresent(HdService.class)) {
                    HdService service = clazz.getAnnotation(HdService.class);
                    String beanName = service.value();
                    if ("".equals(beanName)) {
                        beanName = clazz.getName();
                    }
                    Object instance = clazz.newInstance();
                    //map里面放的是类名和实例对象
                    mappings.put(beanName, instance);
                    //将这个类实现的接口和实例对象放进去
                    for (Class<?> i : clazz.getInterfaces()) {
                        mappings.put(i.getName(), instance);
                    }
                } else {
                    continue;
                }
            }
            //属性注入
            for (Object object : mappings.values()) {
                if (object == null) {
                    continue;
                }
                Class clazz = object.getClass();
                if (clazz.isAnnotationPresent(HdController.class)) {
                    Field[] fields = clazz.getDeclaredFields();
                    for (Field field : fields) {
                        if (!field.isAnnotationPresent(HdAutowired.class)) {
                            continue;
                        }
                        HdAutowired autowired = field.getAnnotation(HdAutowired.class);
                        String beanName = autowired.value();
                        if ("".equals(beanName)) {
                            beanName = field.getType().getName();
                        }
                        field.setAccessible(true);
                        try {
                            field.set(mappings.get(clazz.getName()), mappings.get(beanName));
                        } catch (IllegalAccessException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
            System.out.print("Diy MVC Framework is init");
        } catch (Exception e) {

        }
    }

    private void doScanner(String scanPackage) {
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File classDir = new File(url.getFile());
        Arrays.stream(classDir.listFiles()).forEach(file -> {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else {
                if (!file.getName().endsWith(".class")) {
                    return;
                }
                String clazzName = (scanPackage + "." + file.getName().replace(".class", ""));
                mappings.put(clazzName, null);
            }
        });
    }
}

2.0版本

分析

1.0版本的所有代码都写在了一个方法里面,代码耦合度 十分高,不符合开发规范

思路

采用设计模式(工厂模式、单例模式、委派模式、策略模式),改造业务逻辑。

代码

/**
 * @author yhd
 * @createtime 2021/2/1 11:29
 */
public class HdDispatcherServlet2 extends HttpServlet {

    private Map<String, Object> ioc = new ConcurrentHashMap<>();

    private Map<String, Method> handlerMappings = new ConcurrentHashMap<>();

    private List<String> classNames = new CopyOnWriteArrayList<>();

    private Properties configContext = new Properties();

    private static final String CONFIG_LOCATION = "contextConfigLocation";

    @Override
    public void init(ServletConfig config) throws ServletException {
        //1.加载配置文件
        loadConfig(config.getInitParameter(CONFIG_LOCATION));
        //2.扫描所有的组件
        doScanPackages(configContext.getProperty("scanPackage"));
        //3.将组件加入到容器
        refersh();
        //4.属性设值
        population();
        //5.建立方法与路径的映射
        routingAndMapping();
    }

    /**
     * 建立方法与路径的映射
     */
    private void routingAndMapping() {
        classNames.forEach(className -> {
            Object instance = ioc.get(className);
            if (instance.getClass().isAnnotationPresent(HdController.class)) {
                String baseUrl = "";
                if (instance.getClass().isAnnotationPresent(HdRequestMapping.class)) {
                    baseUrl += instance.getClass().getAnnotation(HdRequestMapping.class).value().trim();
                }

                String finalBaseUrl = baseUrl;
                Arrays.asList(instance.getClass().getDeclaredMethods()).forEach(method -> {
                    if (method.isAnnotationPresent(HdRequestMapping.class)) {
                        String methodUrl = finalBaseUrl;
                        methodUrl += method.getAnnotation(HdRequestMapping.class).value().trim();
                        handlerMappings.put(methodUrl, method);
                    }
                });
            }
        });
    }

    /**
     * 属性设值
     */
    private void population() {
        Set<String> keySet = ioc.keySet();

        keySet.forEach(key -> {
            Field[] fields = ioc.get(key).getClass().getFields();
            Arrays.asList(fields).forEach(field -> {
                if (field.isAnnotationPresent(HdAutowired.class)) {
                    HdAutowired autowired = field.getAnnotation(HdAutowired.class);
                    String name = autowired.value().trim();
                    if ("".equals(autowired.value().trim())) {
                        name = field.getType().getName();
                    }
                    try {
                        field.setAccessible(true);
                        field.set(name, ioc.get(name));
                    } catch (IllegalAccessException e) {

                    }
                }
            });
        });

    }

    /**
     * 容器刷新
     * 组件加入到容器中
     */
    @SneakyThrows
    private void refersh() {
        if (classNames == null || classNames.isEmpty()) {
            throw new RuntimeException("组件扫描出现异常!");
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            if (clazz.isAnnotationPresent(HdController.class)) {
                //TODO 类名处理
                ioc.put(clazz.getSimpleName(), clazz.newInstance());
            } else if (clazz.isAnnotationPresent(HdService.class)) {
                Object instance = clazz.newInstance();
                ioc.put(clazz.getSimpleName(), instance);
                Class<?>[] interfaces = clazz.getInterfaces();
                for (Class<?> inter : interfaces) {
                    ioc.put(inter.getSimpleName(), clazz);
                }
            } else {
                continue;
            }
        }

    }

    /**
     * 组件扫描
     *
     * @param scanPackage
     */
    private void doScanPackages(String scanPackage) {
        URL url = getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));

        File files = new File(url.getFile());
        for (File file : files.listFiles()) {
            if (file.isDirectory()) {
                doScanPackages(scanPackage + "." + file.getName());
            } else {
                if (!file.getName().endsWith(".class")) {
                    continue;
                }
                String className = scanPackage + "." + file.getName().replace(".class", "");
                classNames.add(className);
            }

        }
    }

    /**
     * 加载配置文件
     *
     * @param initParameter
     */
    @SneakyThrows
    private void loadConfig(String initParameter) {
        InputStream is = getClass().getClassLoader().getResourceAsStream(initParameter);
        configContext.load(is);
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatcher(req, resp);
        } catch (Exception e) {
            throw new RuntimeException(" 500 server error!");
        }
    }

    @SneakyThrows
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) {

        String realPath = req.getRequestURI().replace(req.getContextPath(), "");

        Map<String, String[]> parameterMap = req.getParameterMap();

        if (!handlerMappings.containsKey(realPath)) {
            throw new RuntimeException("404 Not Found!");
        }

        Method method = handlerMappings.get(realPath);

        Class<?>[] parameterTypes = method.getParameterTypes();

        Object[] paramValues = new Object[parameterTypes.length];

        for (int i = 0; i < parameterTypes.length - 1; i++) {

            Class param = parameterTypes[i];

            if (param == HttpServletRequest.class) {
                paramValues[i] = req;
            }

            if (param == HttpServletResponse.class) {
                paramValues[i] = resp;
            }

            if (param == String.class) {
                HdRequestParam requestParam = parameterTypes[i].getAnnotation(HdRequestParam.class);
                String value = requestParam.value();
                String[] realParam = parameterMap.get(value);
                paramValues[i] = Arrays.toString(realParam)
                        .replaceAll("\\[|\\]", "")
                        .replaceAll("\\s", ",");
            }
        }
        method.invoke(method.getDeclaringClass().getSimpleName(), paramValues);
    }

    private Object convertParamType() {
        return null;
    }
}

3.0版本

分析

HandlerMapping还不能像SpringMVC一样支持正则,url参数还不支持强制类型转换,反射调用之前还需要重新获取bean的name。

改造 HandlerMapping,在真实的 Spring 源码中,HandlerMapping 其实是一个 List 而非 Map。List 中的元素是一个自定义的类型。

思路

使用内部类维护requestMapping和url之间的关系。

代码

public class HdDispatcherServlet3 extends HttpServlet {

    private Map<String, Object> ioc = new ConcurrentHashMap<>();

    private Map<String, Method> handlerMappings = new ConcurrentHashMap<>();

    private List<String> classNames = new CopyOnWriteArrayList<>();

    private Properties configContext = new Properties();

    private static final String CONFIG_LOCATION = "contextConfigLocation";

    private List<Handler> handlerMapping = new ArrayList<>();

    /**
     *
     */
    @Data
    private class Handler {
        //保存方法对应的实例
        private Object controller;
        //保存映射的方法
        private Method method;
        //正则匹配
        private Pattern pattern;
        //参数顺序
        private Map<String, Integer> paramIndexMapping = new ConcurrentHashMap<>();

        public Handler(Pattern pattern, Object controller, Method method) {
            this.controller = controller;
            this.method = method;
            this.pattern = pattern;
            paramIndexMapping = new HashMap<String, Integer>();
            putParamIndexMapping(method);
        }

        private void putParamIndexMapping(Method method) {
            //提取方法中加了注解的参数
            Annotation[][] pa = method.getParameterAnnotations();

            for (int i = 0; i < pa.length; i++) {
                for (Annotation a : pa[i]) {
                    if (a instanceof HdRequestParam) {
                        String paramName = ((HdRequestParam) a).value();
                        if (!"".equals(paramName.trim())) {
                            paramIndexMapping.put(paramName, i);
                        }
                    }
                }
            }

            //提取方法中的req和resp
            Class<?>[] parameterTypes = method.getParameterTypes();

            for (int i = 0; i < parameterTypes.length; i++) {
                Class<?> type = parameterTypes[i];

                if (type == HttpServletRequest.class ||
                        type == HttpServletResponse.class) {
                    paramIndexMapping.put(type.getName(), i);
                }
            }
        }

    }

    @Override
    public void init(ServletConfig config) throws ServletException {
        //1.加载配置文件
        loadConfig(config.getInitParameter(CONFIG_LOCATION));
        //2.扫描所有的组件
        doScanPackages(configContext.getProperty("scanPackage"));
        //3.将组件加入到容器
        refersh();
        //4.属性设值
        population();
        //5.建立方法与路径的映射
        routingAndMapping();
    }

    /**
     * 建立方法与路径的映射
     */
    private void routingAndMapping() {
        if (ioc.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(HdController.class)) {
                continue;
            }
            String url = "";
            if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                HdRequestMapping requestMapping = clazz.getAnnotation(HdRequestMapping.class);
                url = requestMapping.value();
            }
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(HdRequestMapping.class)) {
                    continue;
                }
                HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                String regex = ("/" + url + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(regex);
                handlerMapping.add(new Handler(pattern, entry.getValue(), method));
            }
        }
    }

    /**
     * 属性设值
     */
    private void population() {
        Set<String> keySet = ioc.keySet();

        keySet.forEach(key -> {
            Field[] fields = ioc.get(key).getClass().getFields();
            Arrays.asList(fields).forEach(field -> {
                if (field.isAnnotationPresent(HdAutowired.class)) {
                    HdAutowired autowired = field.getAnnotation(HdAutowired.class);
                    String name = autowired.value().trim();
                    if ("".equals(autowired.value().trim())) {
                        name = field.getType().getName();
                    }
                    try {
                        field.setAccessible(true);
                        field.set(name, ioc.get(name));
                    } catch (IllegalAccessException e) {

                    }
                }
            });
        });

    }

    /**
     * 容器刷新
     * 组件加入到容器中
     */
    @SneakyThrows
    private void refersh() {
        if (classNames == null || classNames.isEmpty()) {
            throw new RuntimeException("组件扫描出现异常!");
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            if (clazz.isAnnotationPresent(HdController.class)) {
                //TODO 类名处理
                ioc.put(clazz.getSimpleName(), clazz.newInstance());
            } else if (clazz.isAnnotationPresent(HdService.class)) {
                Object instance = clazz.newInstance();
                ioc.put(clazz.getSimpleName(), instance);
                Class<?>[] interfaces = clazz.getInterfaces();
                for (Class<?> inter : interfaces) {
                    ioc.put(inter.getSimpleName(), clazz);
                }
            } else {
                continue;
            }
        }

    }

    /**
     * 组件扫描
     *
     * @param scanPackage
     */
    private void doScanPackages(String scanPackage) {
        URL url = getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));

        File files = new File(url.getFile());
        for (File file : files.listFiles()) {
            if (file.isDirectory()) {
                doScanPackages(scanPackage + "." + file.getName());
            } else {
                if (!file.getName().endsWith(".class")) {
                    continue;
                }
                String className = scanPackage + "." + file.getName().replace(".class", "");
                classNames.add(className);
            }

        }
    }

    /**
     * 加载配置文件
     *
     * @param initParameter
     */
    @SneakyThrows
    private void loadConfig(String initParameter) {
        InputStream is = getClass().getClassLoader().getResourceAsStream(initParameter);
        configContext.load(is);
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatcher(req, resp);
        } catch (Exception e) {
            throw new RuntimeException(" 500 server error!");
        }
    }

    @SneakyThrows
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) {
        Handler handler = getHandler(req);

        if (handler == null) {
            throw new RuntimeException("404 Not Found!");
        }

        Class<?>[] parameterTypes = handler.getMethod().getParameterTypes();

        Object[] paramValues = new Object[parameterTypes.length];

        Map<String, String[]> params = req.getParameterMap();

        for (Map.Entry<String, String[]> param : params.entrySet()) {
            String value = Arrays.toString(param.getValue()).replaceAll("\\[|\\]", "")
                    .replaceAll("\\s", ",");
            if (!handler.getParamIndexMapping().containsKey(param.getKey())) {
                continue;
            }
            Integer index = handler.getParamIndexMapping().get(param.getKey());
            paramValues[index] = this.convert(parameterTypes[index], value);
        }

        if (handler.paramIndexMapping.containsKey(HttpServletRequest.class.getName())) {
            int reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
            paramValues[reqIndex] = req;
        }
        if (handler.paramIndexMapping.containsKey(HttpServletResponse.class.getName())) {
            int respIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
            paramValues[respIndex] = resp;
        }

        Object returnValue = handler.getMethod().invoke(handler.getController(), paramValues);
        if (returnValue == null || returnValue instanceof Void) {
            return;
        }
        resp.getWriter().write(returnValue.toString());
    }

    private Object convert(Class<?> parameterType, String value) {
        if (Integer.class == parameterType) {
            return Integer.parseInt(value);
        }
        return value;
    }

    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "")
                .replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            try {
                Matcher matcher = handler.pattern.matcher(url);
                //如果没有匹配上继续下一个匹配
                if (!matcher.matches()) {
                    continue;
                }
                return handler;
            } catch (Exception e) {
                throw e;
            }
        }
        return null;
    }
}