一,实现思路
1,配置阶段
配置web.xml | DispatcherServlet |
---|---|
设定init-param | contextConfigLocation=classpath:application.xml |
设定url-pattern | /* |
配置Annotation | @Controller @Service @Autowired @RequestMapping |
2,初始化阶段
调用init方法 | 加载配置文件 |
---|---|
IOC容器初始化 | MAP |
扫描相关的类 | scan-package="" |
创建实例化并保存至容器 | 通过反射机制将类实例化放入IOC容器 |
进行DI操作 | 扫描IOC容器的实例,给没有赋值的属性自动填充 |
初始化HandlerMapping | 将一个URL和一个Method进行一对一的映射 |
3,运行阶段
调用doGet/doPost | web容器调用doGet、doPost,获取req和resp对象 |
---|---|
匹配HandlerMapping | 从req对象获取输入的URL,找到其对应的method |
反射调用method.invoke() | 利用反射调用方法并返回结果 |
response.getWriter().write() | 将返回结果输出到浏览器 |
二,自定义配置
1,配置 application.properties 文件
为了解析方便,用 application.properties 来代替 application.xml 文件,具体配置内容如下:
# Root package that the dispatcher servlet scans (recursively) for component classes.
scanPackage=com.yhd.spring01
2,配置web.xml文件
<?xml version="1.0" encoding="UTF-8"?>
<!-- Deployment descriptor: registers the hand-written dispatcher servlet and routes every request to it. -->
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://java.sun.com/xml/ns/j2ee" xmlns:javaee="http://java.sun.com/xml/ns/javaee"
xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd"
xsi:schemaLocation="http://java.sun.com/xml/ns/j2ee
http://java.sun.com/xml/ns/j2ee/web-app_2_4.xsd"
version="2.4">
<display-name>YHD Web Application</display-name>
<!-- The custom front controller, registered under the name "yhdmvc". -->
<servlet>
<servlet-name>yhdmvc</servlet-name>
<servlet-class>com.yhd.spring01.servlet.HdDispatcherServlet</servlet-class>
<!-- Name of the properties file the servlet reads in init(); it is loaded through the class loader. -->
<init-param>
<param-name>contextConfigLocation</param-name>
<param-value>application.properties</param-value>
</init-param>
<!-- Initialise the servlet when the web application starts instead of on the first request. -->
<load-on-startup>1</load-on-startup>
</servlet>
<!-- Send all request paths to the "yhdmvc" servlet. -->
<servlet-mapping>
<servlet-name>yhdmvc</servlet-name>
<url-pattern>/*</url-pattern>
</servlet-mapping>
</web-app>
3,自定义注解
/**
 * Marks a field that the dispatcher servlet fills in from its container during start-up.
 *
 * <p>{@link #value()} optionally names the bean to inject; it defaults to the empty string.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface HdAutowired {

    /** Optional bean name; empty by default. */
    String value() default "";
}
import java.lang.annotation.*;
/**
 * Marks a class as a controller; the dispatcher servlet instantiates such classes and
 * registers their {@code @HdRequestMapping} methods as request handlers.
 *
 * <p>{@link #value()} is an optional name and defaults to the empty string.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface HdController {

    /** Optional name; empty by default. */
    String value() default "";
}
/**
 * Declares the URL path handled by a controller class (base path) or by one of its methods.
 *
 * <p>{@link #value()} holds the path fragment and defaults to the empty string.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({
        ElementType.TYPE,
        ElementType.METHOD
})
public @interface HdRequestMapping {

    /** Path fragment; empty by default. */
    String value() default "";
}
/**
 * Binds a handler-method parameter to the request parameter named by {@link #value()}.
 *
 * <p>The name defaults to the empty string.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface HdRequestParam {

    /** Request parameter name; empty by default. */
    String value() default "";
}
/**
 * Marks a class as a service; the dispatcher servlet instantiates such classes and makes
 * them available for injection.
 *
 * <p>{@link #value()} is an optional bean name and defaults to the empty string.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface HdService {

    /** Optional bean name; empty by default. */
    String value() default "";
}
4,编写模拟业务
/**
 * Sample service used by the demo controller; it simply formats a greeting.
 */
@HdService
public class DemoService implements IDemoService {

    /**
     * Builds the greeting for the given name.
     *
     * @param name the name to embed; a {@code null} is rendered as the text "null"
     * @return the text "My name is " followed by {@code name}
     */
    @Override
    public String get(String name) {
        final String prefix = "My name is ";
        return prefix.concat(String.valueOf(name));
    }
}
/**
 * Sample controller mounted under {@code /demo}; its handlers write their result
 * straight to the response.
 */
@HdController
@HdRequestMapping("/demo")
public class DemoController {

    /** Injected by the dispatcher servlet during start-up. */
    @HdAutowired
    private IDemoService demoService;

    /**
     * Handles {@code /demo/query}: writes the greeting produced by the demo service.
     *
     * @param req  current request (unused)
     * @param resp response the greeting is written to
     * @param name value of the {@code name} request parameter
     */
    @HdRequestMapping("/query")
    public void query(HttpServletRequest req,
                      HttpServletResponse resp,
                      @HdRequestParam("name") String name) {
        final String greeting = demoService.get(name);
        try {
            resp.getWriter().write(greeting);
        } catch (IOException writeFailure) {
            // Best effort: the failure is only printed, the request still completes.
            writeFailure.printStackTrace();
        }
    }

    /**
     * Handles {@code /demo/add}: writes {@code a+b=sum}.
     *
     * @param req  current request (unused)
     * @param resp response the result is written to
     * @param a    first operand, from the {@code a} request parameter
     * @param b    second operand, from the {@code b} request parameter
     */
    @HdRequestMapping("/add")
    public void add(HttpServletRequest req,
                    HttpServletResponse resp,
                    @HdRequestParam("a") Integer a,
                    @HdRequestParam("b") Integer b) {
        try {
            resp.getWriter().write(a + "+" + b + "=" + (a + b));
        } catch (IOException writeFailure) {
            // Best effort: the failure is only printed, the request still completes.
            writeFailure.printStackTrace();
        }
    }

    /**
     * Handles {@code /demo/remove}; intentionally does nothing yet.
     *
     * @param req  current request (unused)
     * @param resp current response (unused)
     * @param id   value of the {@code id} request parameter (unused)
     */
    @HdRequestMapping("/remove")
    public void remove(HttpServletRequest req,
                       HttpServletResponse resp,
                       @HdRequestParam("id") Integer id) {
        // no-op placeholder
    }
}
三,容器初始化
1.0版本
流程分析
1.首先在doGet/doPost方法里面调用doDispatch方法,根据请求路径判断路径是否存在:如果不存在就返回404;如果存在,就从容器中拿到路径对应的方法,通过反射执行对应的方法。
2.在类加载阶段,用流来加载配置文件,从配置文件读取配置的包扫描路径根据包扫描路径进行迭代遍历,利用反射创建所有类上标有controller注解的类加入到容器,并下钻到类中,将类中每一个方法的绝对访问路径和方法加入到容器,迭代遍历创建所有标有service注解的类,如果该类实现了接口,将该接口的全限定类型名和类实例对象也放入容器,达到根据接口注入的效果。
3.属性赋值:遍历容器中所有Controller实例,如果其字段上标有@HdAutowired注解,就把容器中对应的bean设置到该字段上。
重要方法
1.clazz.isAnnotationPresent(HdController.class)
判断clazz上有没有HdController注解
2.field.set(mappings.get(clazz.getName()), mappings.get(beanName));
属性赋值:参数1是要给哪个对象(实例)的该属性设值,参数2是要设置的值
3.method.invoke(mappings.get(method.getDeclaringClass().getName()), new Object[]{req, resp, params.get("name")[0]});
通过反射执行方法:参数1是方法所属类的实例对象,参数2是方法的实参列表
代码
/**
 * Version 1.0 of the hand-written mini MVC framework: one servlet that acts as IoC
 * container, dependency injector and request dispatcher at the same time.
 *
 * <p>Known 1.0 limitation (kept on purpose): every handler method is invoked with exactly
 * three arguments, {@code (HttpServletRequest, HttpServletResponse, String name)}, where
 * {@code name} is the first value of the {@code name} request parameter (or {@code null}).
 *
 * @author yhd
 * @createtime 2021/1/31 15:49
 * @description simulates the creation of an IoC container
 */
public class HdDispatcherServlet extends HttpServlet {

    /**
     * Single registry used for everything in 1.0:
     * request path -> handler {@link Method}, and
     * fully-qualified class / interface name (or custom bean name) -> bean instance.
     * Scanned classes that are neither controllers nor services stay mapped to {@code null}.
     */
    private Map<String, Object> mappings = new HashMap<>();

    /** Delegates to {@link #doPost} so that GET and POST are handled identically. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    /**
     * Dispatches the request to the mapped handler method.
     *
     * @throws ServletException if the handler cannot be invoked or throws; the original
     *                          failure is preserved as the cause
     * @throws IOException      if writing the response fails
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatch(req, resp);
        } catch (InvocationTargetException e) {
            // Unwrap so the container reports the handler's own exception.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new ServletException("Handler threw an exception for " + req.getRequestURI(), cause);
        } catch (IllegalAccessException | IllegalArgumentException e) {
            // IllegalArgumentException: the handler signature does not match the fixed 1.0 argument list.
            throw new ServletException("Cannot invoke handler for " + req.getRequestURI(), e);
        }
    }

    /**
     * Resolves the handler for the request path and invokes it reflectively.
     * Answers with status 404 and the body "404 NotFound!" when no handler is mapped.
     */
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        // Build the lookup path: strip the context path only as a prefix, then collapse repeated slashes.
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        if (contextPath != null && !contextPath.isEmpty() && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");

        // The registry also holds bean instances, so make sure the entry really is a handler method.
        Object handler = this.mappings.get(url);
        if (!(handler instanceof Method)) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 NotFound!");
            return;
        }

        Method method = (Method) handler;
        // getParameter returns the first value, or null when the parameter is absent (no NPE).
        String name = req.getParameter("name");
        Object controller = this.mappings.get(method.getDeclaringClass().getName());
        method.invoke(controller, new Object[]{req, resp, name});
    }

    /**
     * Boots the mini container: loads the configuration, scans the configured package,
     * instantiates controllers and services, registers handler methods and injects
     * {@code @HdAutowired} fields of controllers.
     *
     * @param config servlet configuration; its {@code contextConfigLocation} init-param names
     *               a properties file on the classpath (an optional {@code classpath:} prefix is ignored)
     * @throws ServletException if the configuration is missing or any start-up step fails
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Let GenericServlet remember the config so getServletConfig()/getServletContext() work.
        super.init(config);

        String location = config.getInitParameter("contextConfigLocation");
        if (location == null || location.trim().isEmpty()) {
            throw new ServletException("Missing init-param 'contextConfigLocation'");
        }
        location = location.trim();
        if (location.startsWith("classpath:")) {
            location = location.substring("classpath:".length());
        }

        try {
            // Load the configuration file; the stream is always closed.
            Properties configContext = new Properties();
            try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(location)) {
                if (is == null) {
                    throw new ServletException("Configuration file not found on classpath: " + location);
                }
                configContext.load(is);
            }

            // Read the scan root and collect all class names below it.
            String scanPackage = configContext.getProperty("scanPackage");
            if (scanPackage == null || scanPackage.trim().isEmpty()) {
                throw new ServletException("Property 'scanPackage' is not configured in " + location);
            }
            doScanner(scanPackage.trim());

            // Iterate over a snapshot: the loop adds new keys (URLs, interface names) to the
            // registry, which would otherwise trigger a ConcurrentModificationException.
            for (String className : new java.util.ArrayList<String>(mappings.keySet())) {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(HdController.class)) {
                    mappings.put(className, clazz.getDeclaredConstructor().newInstance());
                    String baseUrl = "";
                    // Optional class-level base path.
                    if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                        HdRequestMapping requestMapping = clazz.getAnnotation(HdRequestMapping.class);
                        baseUrl = requestMapping.value();
                    }
                    for (Method method : clazz.getMethods()) {
                        if (!method.isAnnotationPresent(HdRequestMapping.class)) {
                            continue;
                        }
                        HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                        // Always start with "/" so paths declared without a leading slash still match.
                        String url = ("/" + baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                        // Absolute request path -> handler method.
                        mappings.put(url, method);
                        System.out.println("Mapped " + url + "," + method);
                    }
                } else if (clazz.isAnnotationPresent(HdService.class)) {
                    HdService service = clazz.getAnnotation(HdService.class);
                    String beanName = service.value();
                    if ("".equals(beanName)) {
                        beanName = clazz.getName();
                    }
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    // Bean name (or class name) -> instance.
                    mappings.put(beanName, instance);
                    // Also register under every implemented interface to allow injection by interface type.
                    for (Class<?> i : clazz.getInterfaces()) {
                        mappings.put(i.getName(), instance);
                    }
                }
            }

            // Dependency injection into controllers.
            for (Object object : mappings.values()) {
                if (object == null) {
                    continue;
                }
                Class<?> clazz = object.getClass();
                if (!clazz.isAnnotationPresent(HdController.class)) {
                    continue;
                }
                for (Field field : clazz.getDeclaredFields()) {
                    if (!field.isAnnotationPresent(HdAutowired.class)) {
                        continue;
                    }
                    HdAutowired autowired = field.getAnnotation(HdAutowired.class);
                    String beanName = autowired.value();
                    if ("".equals(beanName)) {
                        beanName = field.getType().getName();
                    }
                    Object dependency = mappings.get(beanName);
                    if (dependency == null) {
                        // Fail at start-up instead of with a NullPointerException on the first request.
                        throw new ServletException("No bean '" + beanName + "' available for field "
                                + clazz.getName() + "." + field.getName());
                    }
                    field.setAccessible(true);
                    field.set(object, dependency);
                }
            }
            System.out.print("Diy MVC Framework is init");
        } catch (ReflectiveOperationException | IOException | RuntimeException e) {
            // Never swallow start-up failures: the container must know the servlet is unusable.
            throw new ServletException("Failed to initialise " + getClass().getSimpleName(), e);
        }
    }

    /**
     * Recursively records the fully-qualified name of every {@code .class} file found below
     * the given package as a key (with a {@code null} value) in the registry.
     *
     * @param scanPackage dotted package name to scan
     * @throws IllegalStateException if the package is not on the classpath or is not a plain directory
     */
    private void doScanner(String scanPackage) {
        // ClassLoader resource names must not start with "/".
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        if (url == null) {
            throw new IllegalStateException("Package not found on classpath: " + scanPackage);
        }
        if (!"file".equals(url.getProtocol())) {
            throw new IllegalStateException("Only directory-based packages can be scanned: " + url);
        }
        File classDir;
        try {
            // toURI decodes escaped characters such as %20 in the path.
            classDir = new File(url.toURI());
        } catch (java.net.URISyntaxException e) {
            classDir = new File(url.getFile());
        }
        File[] children = classDir.listFiles();
        if (children == null) {
            return;
        }
        for (File file : children) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String fileName = file.getName();
                String clazzName = scanPackage + "." + fileName.substring(0, fileName.length() - ".class".length());
                mappings.put(clazzName, null);
            }
        }
    }
}
2.0版本
分析
1.0版本的所有代码都写在了一个方法里面,代码耦合度十分高,不符合开发规范
思路
采用设计模式(工厂模式、单例模式、委派模式、策略模式),改造业务逻辑。
代码
/**
* @author yhd
* @createtime 2021/2/1 11:29
*/
public class HdDispatcherServlet2 extends HttpServlet {
private Map<String, Object> ioc = new ConcurrentHashMap<>();
private Map<String, Method> handlerMappings = new ConcurrentHashMap<>();
private List<String> classNames = new CopyOnWriteArrayList<>();
private Properties configContext = new Properties();
private static final String CONFIG_LOCATION = "contextConfigLocation";
@Override
public void init(ServletConfig config) throws ServletException {
//1.加载配置文件
loadConfig(config.getInitParameter(CONFIG_LOCATION));
//2.扫描所有的组件
doScanPackages(configContext.getProperty("scanPackage"));
//3.将组件加入到容器
refersh();
//4.属性设值
population();
//5.建立方法与路径的映射
routingAndMapping();
}
/**
* 建立方法与路径的映射
*/
private void routingAndMapping() {
classNames.forEach(className -> {
Object instance = ioc.get(className);
if (instance.getClass().isAnnotationPresent(HdController.class)) {
String baseUrl = "";
if (instance.getClass().isAnnotationPresent(HdRequestMapping.class)) {
baseUrl += instance.getClass().getAnnotation(HdRequestMapping.class).value().trim();
}
String finalBaseUrl = baseUrl;
Arrays.asList(instance.getClass().getDeclaredMethods()).forEach(method -> {
if (method.isAnnotationPresent(HdRequestMapping.class)) {
String methodUrl = finalBaseUrl;
methodUrl += method.getAnnotation(HdRequestMapping.class).value().trim();
handlerMappings.put(methodUrl, method);
}
});
}
});
}
/**
* 属性设值
*/
private void population() {
Set<String> keySet = ioc.keySet();
keySet.forEach(key -> {
Field[] fields = ioc.get(key).getClass().getFields();
Arrays.asList(fields).forEach(field -> {
if (field.isAnnotationPresent(HdAutowired.class)) {
HdAutowired autowired = field.getAnnotation(HdAutowired.class);
String name = autowired.value().trim();
if ("".equals(autowired.value().trim())) {
name = field.getType().getName();
}
try {
field.setAccessible(true);
field.set(name, ioc.get(name));
} catch (IllegalAccessException e) {
}
}
});
});
}
/**
* 容器刷新
* 组件加入到容器中
*/
@SneakyThrows
private void refersh() {
if (classNames == null || classNames.isEmpty()) {
throw new RuntimeException("组件扫描出现异常!");
}
for (String className : classNames) {
Class<?> clazz = Class.forName(className);
if (clazz.isAnnotationPresent(HdController.class)) {
//TODO 类名处理
ioc.put(clazz.getSimpleName(), clazz.newInstance());
} else if (clazz.isAnnotationPresent(HdService.class)) {
Object instance = clazz.newInstance();
ioc.put(clazz.getSimpleName(), instance);
Class<?>[] interfaces = clazz.getInterfaces();
for (Class<?> inter : interfaces) {
ioc.put(inter.getSimpleName(), clazz);
}
} else {
continue;
}
}
}
/**
* 组件扫描
*
* @param scanPackage
*/
private void doScanPackages(String scanPackage) {
URL url = getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
File files = new File(url.getFile());
for (File file : files.listFiles()) {
if (file.isDirectory()) {
doScanPackages(scanPackage + "." + file.getName());
} else {
if (!file.getName().endsWith(".class")) {
continue;
}
String className = scanPackage + "." + file.getName().replace(".class", "");
classNames.add(className);
}
}
}
/**
* 加载配置文件
*
* @param initParameter
*/
@SneakyThrows
private void loadConfig(String initParameter) {
InputStream is = getClass().getClassLoader().getResourceAsStream(initParameter);
configContext.load(is);
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
try {
doDispatcher(req, resp);
} catch (Exception e) {
throw new RuntimeException(" 500 server error!");
}
}
@SneakyThrows
private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) {
String realPath = req.getRequestURI().replace(req.getContextPath(), "");
Map<String, String[]> parameterMap = req.getParameterMap();
if (!handlerMappings.containsKey(realPath)) {
throw new RuntimeException("404 Not Found!");
}
Method method = handlerMappings.get(realPath);
Class<?>[] parameterTypes = method.getParameterTypes();
Object[] paramValues = new Object[parameterTypes.length];
for (int i = 0; i < parameterTypes.length - 1; i++) {
Class param = parameterTypes[i];
if (param == HttpServletRequest.class) {
paramValues[i] = req;
}
if (param == HttpServletResponse.class) {
paramValues[i] = resp;
}
if (param == String.class) {
HdRequestParam requestParam = parameterTypes[i].getAnnotation(HdRequestParam.class);
String value = requestParam.value();
String[] realParam = parameterMap.get(value);
paramValues[i] = Arrays.toString(realParam)
.replaceAll("\\[|\\]", "")
.replaceAll("\\s", ",");
}
}
method.invoke(method.getDeclaringClass().getSimpleName(), paramValues);
}
private Object convertParamType() {
return null;
}
}
3.0版本
分析
HandlerMapping还不能像SpringMVC一样支持正则,url参数还不支持强制类型转换,反射调用之前还需要重新获取bean的name。
改造 HandlerMapping,在真实的 Spring 源码中,HandlerMapping 其实是一个 List 而非 Map。List 中的元素是一个自定义的类型。
思路
使用内部类维护requestMapping和url之间的关系。
代码
/**
 * Version 3.0 of the hand-written mini MVC dispatcher: handlers are kept in a list of
 * {@link Handler} objects that match request paths by regular expression, and request
 * parameters are converted to the handler's parameter types.
 *
 * <p>Beans are registered under their fully-qualified class name, under the fully-qualified
 * name of every interface they implement and, for services, under the optional name given
 * in {@code @HdService}.
 */
public class HdDispatcherServlet3 extends HttpServlet {

    /** Bean registry: bean key (see class comment) -> bean instance. */
    private final Map<String, Object> ioc = new ConcurrentHashMap<>();

    /** Fully-qualified names of all classes found by the package scan. */
    private final List<String> classNames = new CopyOnWriteArrayList<>();

    /** Content of the properties file named by the servlet init-param. */
    private final Properties configContext = new Properties();

    /** Name of the servlet init-param that points to the configuration file. */
    private static final String CONFIG_LOCATION = "contextConfigLocation";

    /** All registered handlers; filled once during {@link #init} and only read afterwards. */
    private final List<Handler> handlerMapping = new ArrayList<>();

    /**
     * One request mapping: the URL pattern, the controller instance, the handler method and
     * the position of each bindable parameter.
     */
    private static final class Handler {
        /** Controller instance the method is invoked on. */
        private final Object controller;
        /** Handler method. */
        private final Method method;
        /** Regular expression the whole request path must match. */
        private final Pattern pattern;
        /**
         * Parameter positions, keyed by request-parameter name, plus the positions of the
         * request/response parameters keyed by their class names.
         */
        private final Map<String, Integer> paramIndexMapping = new HashMap<>();

        Handler(Pattern pattern, Object controller, Method method) {
            this.controller = controller;
            this.method = method;
            this.pattern = pattern;
            putParamIndexMapping(method);
        }

        Object getController() {
            return controller;
        }

        Method getMethod() {
            return method;
        }

        Pattern getPattern() {
            return pattern;
        }

        Map<String, Integer> getParamIndexMapping() {
            return paramIndexMapping;
        }

        /** Records where each annotated parameter and the request/response objects go. */
        private void putParamIndexMapping(Method method) {
            // Parameters annotated with @HdRequestParam.
            Annotation[][] pa = method.getParameterAnnotations();
            for (int i = 0; i < pa.length; i++) {
                for (Annotation a : pa[i]) {
                    if (a instanceof HdRequestParam) {
                        String paramName = ((HdRequestParam) a).value();
                        if (!"".equals(paramName.trim())) {
                            paramIndexMapping.put(paramName, i);
                        }
                    }
                }
            }
            // The request and response parameters.
            Class<?>[] parameterTypes = method.getParameterTypes();
            for (int i = 0; i < parameterTypes.length; i++) {
                Class<?> type = parameterTypes[i];
                if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
                    paramIndexMapping.put(type.getName(), i);
                }
            }
        }
    }

    /**
     * Runs the five start-up steps in order.
     *
     * @throws ServletException if any step fails; the original failure is kept as the cause
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Let GenericServlet remember the config so getServletConfig()/getServletContext() work.
        super.init(config);
        try {
            // 1. load the configuration file
            loadConfig(config.getInitParameter(CONFIG_LOCATION));
            // 2. scan all component classes
            doScanPackages(configContext.getProperty("scanPackage"));
            // 3. instantiate components and put them into the container
            refersh();
            // 4. inject dependencies
            population();
            // 5. map request paths to handler methods
            routingAndMapping();
        } catch (RuntimeException e) {
            throw new ServletException("Failed to initialise " + getClass().getSimpleName(), e);
        }
    }

    /**
     * Creates a {@link Handler} for every {@code @HdRequestMapping} method of every controller.
     * The class-level and method-level paths are joined and compiled as a regular expression.
     */
    private void routingAndMapping() {
        if (ioc.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(HdController.class)) {
                continue;
            }
            String url = "";
            if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                HdRequestMapping requestMapping = clazz.getAnnotation(HdRequestMapping.class);
                url = requestMapping.value();
            }
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(HdRequestMapping.class)) {
                    continue;
                }
                HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                // Separate both parts with "/" so paths declared without slashes do not run together.
                String regex = ("/" + url + "/" + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(regex);
                handlerMapping.add(new Handler(pattern, entry.getValue(), method));
            }
        }
    }

    /**
     * Injects every {@code @HdAutowired} field (of any visibility) of every bean.
     * The dependency is looked up by the annotation value or, when that is empty,
     * by the fully-qualified name of the field type.
     *
     * @throws IllegalStateException if a dependency is missing or cannot be assigned
     */
    private void population() {
        for (Object bean : ioc.values()) {
            // getDeclaredFields also returns private fields, unlike getFields.
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(HdAutowired.class)) {
                    continue;
                }
                String name = field.getAnnotation(HdAutowired.class).value().trim();
                if ("".equals(name)) {
                    name = field.getType().getName();
                }
                Object dependency = ioc.get(name);
                if (dependency == null) {
                    throw new IllegalStateException("No bean '" + name + "' available for field "
                            + bean.getClass().getName() + "." + field.getName());
                }
                try {
                    field.setAccessible(true);
                    // The target is the bean that owns the field, not the bean name.
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("Cannot inject field "
                            + bean.getClass().getName() + "." + field.getName(), e);
                }
            }
        }
    }

    /**
     * Refreshes the container: instantiates every scanned controller and service and
     * registers the instances.
     *
     * @throws RuntimeException      if the scan produced no classes
     * @throws IllegalStateException if a component cannot be loaded or instantiated
     */
    private void refersh() {
        if (classNames == null || classNames.isEmpty()) {
            throw new RuntimeException("组件扫描出现异常!");
        }
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(HdController.class)) {
                    ioc.put(clazz.getName(), clazz.getDeclaredConstructor().newInstance());
                } else if (clazz.isAnnotationPresent(HdService.class)) {
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    ioc.put(clazz.getName(), instance);
                    String beanName = clazz.getAnnotation(HdService.class).value().trim();
                    if (!"".equals(beanName)) {
                        ioc.put(beanName, instance);
                    }
                    // Register the instance (not the Class) under each interface for injection by type.
                    // A later implementation of the same interface replaces an earlier one.
                    for (Class<?> inter : clazz.getInterfaces()) {
                        ioc.put(inter.getName(), instance);
                    }
                }
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot instantiate component " + className, e);
            }
        }
    }

    /**
     * Recursively collects the fully-qualified names of all classes below a package.
     *
     * @param scanPackage dotted package name to scan
     * @throws IllegalStateException if the package is not configured, not on the classpath
     *                               or not a plain directory
     */
    private void doScanPackages(String scanPackage) {
        if (scanPackage == null || scanPackage.trim().isEmpty()) {
            throw new IllegalStateException("Property 'scanPackage' is not configured");
        }
        String packageName = scanPackage.trim();
        // ClassLoader resource names must not start with "/".
        URL url = getClass().getClassLoader().getResource(packageName.replace('.', '/'));
        if (url == null) {
            throw new IllegalStateException("Package not found on classpath: " + packageName);
        }
        if (!"file".equals(url.getProtocol())) {
            throw new IllegalStateException("Only directory-based packages can be scanned: " + url);
        }
        File directory;
        try {
            // toURI decodes escaped characters such as %20 in the path.
            directory = new File(url.toURI());
        } catch (java.net.URISyntaxException e) {
            directory = new File(url.getFile());
        }
        File[] entries = directory.listFiles();
        if (entries == null) {
            return;
        }
        for (File file : entries) {
            if (file.isDirectory()) {
                doScanPackages(packageName + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String fileName = file.getName();
                classNames.add(packageName + "." + fileName.substring(0, fileName.length() - ".class".length()));
            }
        }
    }

    /**
     * Loads the properties file from the classpath into {@link #configContext}.
     *
     * @param initParameter classpath location of the file; an optional {@code classpath:} prefix is ignored
     * @throws IllegalStateException if the location is missing, not found or unreadable
     */
    private void loadConfig(String initParameter) {
        if (initParameter == null || initParameter.trim().isEmpty()) {
            throw new IllegalStateException("Missing init-param '" + CONFIG_LOCATION + "'");
        }
        String location = initParameter.trim();
        if (location.startsWith("classpath:")) {
            location = location.substring("classpath:".length());
        }
        try (InputStream is = getClass().getClassLoader().getResourceAsStream(location)) {
            if (is == null) {
                throw new IllegalStateException("Configuration file not found on classpath: " + location);
            }
            configContext.load(is);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot read configuration file: " + location, e);
        }
    }

    /** Delegates to {@link #doPost} so that GET and POST are handled identically. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    /**
     * Dispatches the request.
     *
     * @throws ServletException if the handler cannot be invoked or throws; the cause is preserved
     * @throws IOException      if writing the response fails
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatcher(req, resp);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new ServletException(" 500 server error!", e);
        }
    }

    /**
     * Finds the handler for the request, binds and converts its arguments, invokes it and
     * writes a non-null return value to the response. Sends 404 when no handler matches and
     * 400 when a numeric parameter cannot be parsed.
     */
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp)
            throws IOException, ReflectiveOperationException {
        Handler handler = getHandler(req);
        if (handler == null) {
            resp.sendError(HttpServletResponse.SC_NOT_FOUND, "404 Not Found!");
            return;
        }
        Map<String, Integer> indexes = handler.getParamIndexMapping();
        Class<?>[] parameterTypes = handler.getMethod().getParameterTypes();
        Object[] paramValues = new Object[parameterTypes.length];

        Map<String, String[]> params = req.getParameterMap();
        for (Map.Entry<String, String[]> param : params.entrySet()) {
            Integer index = indexes.get(param.getKey());
            if (index == null) {
                continue;
            }
            // Repeated parameters are joined with commas; the values themselves are left untouched.
            String value = String.join(",", param.getValue());
            try {
                paramValues[index] = this.convert(parameterTypes[index], value);
            } catch (NumberFormatException e) {
                resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Parameter '" + param.getKey()
                        + "' cannot be converted to " + parameterTypes[index].getSimpleName());
                return;
            }
        }
        // Request/response are assigned last so a same-named request parameter cannot replace them.
        Integer reqIndex = indexes.get(HttpServletRequest.class.getName());
        if (reqIndex != null) {
            paramValues[reqIndex] = req;
        }
        Integer respIndex = indexes.get(HttpServletResponse.class.getName());
        if (respIndex != null) {
            paramValues[respIndex] = resp;
        }

        Object returnValue = handler.getMethod().invoke(handler.getController(), paramValues);
        // void handlers return null from invoke.
        if (returnValue != null) {
            resp.getWriter().write(returnValue.toString());
        }
    }

    /**
     * Converts a raw request-parameter value to the handler's parameter type.
     * Supports {@code int}, {@code long}, {@code double}, {@code boolean} and their wrapper
     * types; any other type receives the string unchanged.
     *
     * @throws NumberFormatException if a numeric value cannot be parsed
     */
    private Object convert(Class<?> parameterType, String value) {
        if (Integer.class == parameterType || int.class == parameterType) {
            return Integer.valueOf(value.trim());
        }
        if (Long.class == parameterType || long.class == parameterType) {
            return Long.valueOf(value.trim());
        }
        if (Double.class == parameterType || double.class == parameterType) {
            return Double.valueOf(value.trim());
        }
        if (Boolean.class == parameterType || boolean.class == parameterType) {
            return Boolean.valueOf(value.trim());
        }
        return value;
    }

    /**
     * Returns the first handler whose pattern matches the whole request path
     * (context path removed, repeated slashes collapsed), or {@code null} if none matches.
     */
    private Handler getHandler(HttpServletRequest req) {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        // Strip the context path only as a prefix.
        if (contextPath != null && !contextPath.isEmpty() && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.getPattern().matcher(url);
            if (matcher.matches()) {
                return handler;
            }
        }
        return null;
    }
}