在框架里,我们需要有更强大的 Context,除了可以控制超时之外,常用的功能比如获取请求、返回结果、实现标准库的 Context 接口,也都要有。
自己封装的 Context 最终需要提供四类功能函数:
- base:封装基本的函数功能,比如获取 http.Request 结构
- context:实现标准 Context 接口
- request:封装了 http.Request 的对外接口
- response:封装了 http.ResponseWriter 的对外接口
边界场景

到这里,我们的超时逻辑设置就结束且生效了。但是,这样的代码逻辑只能算是及格,为什么这么说呢?因为它并没有覆盖所有的场景。我们的代码逻辑要再严谨一些,把边界场景也考虑进来。这里有两种边界场景需要考虑:
异常事件、超时事件触发时,需要往 responseWriter 中写入信息,这个时候如果有其他 Goroutine 也要操作 responseWriter,会不会导致 responseWriter 中的信息出现乱序?
超时事件触发结束之后,已经往 responseWriter 中写入信息了,这个时候如果有其他 Goroutine 也要操作 responseWriter,会不会导致 responseWriter 中的信息重复写入?
package basecontext
import (
"bytes"
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"strconv"
"sync"
"time"
)
// Context bundles one HTTP request/response pair with the per-request state
// (timeout flag, write mutex) used by the handlers in this package.
type Context struct {
// request is the incoming HTTP request.
request *http.Request
// responseWriter is the writer the response is sent through.
responseWriter http.ResponseWriter
// ctx is set from the request's context by NewContext.
ctx context.Context
// handler is set to nil by NewContext; no other assignment is visible here.
handler ControllerHandler
// hasTimeout flags whether the request has timed out.
hasTimeout bool
// writerMux is the write-protection mutex.
writerMux *sync.Mutex
}
// NewContext builds a Context for the given request and response writer.
// The context field is taken from r.Context(), the timeout flag starts
// false, no handler is attached, and a fresh write mutex is allocated.
func NewContext(r *http.Request, w http.ResponseWriter) *Context {
	c := &Context{
		request:        r,
		responseWriter: w,
		writerMux:      new(sync.Mutex),
	}
	// handler and hasTimeout keep their zero values (nil / false).
	c.ctx = r.Context()
	return c
}
// #region base function
// WriterMux returns the mutex held in the context's writerMux field.
func (ctx *Context) WriterMux() *sync.Mutex {
return ctx.writerMux
}
// GetRequest returns the *http.Request stored in the context.
func (ctx *Context) GetRequest() *http.Request {
return ctx.request
}
// GetResponse returns the http.ResponseWriter stored in the context.
func (ctx *Context) GetResponse() http.ResponseWriter {
return ctx.responseWriter
}
// SetHasTimeout sets the timeout flag to true. The assignment itself takes
// no lock.
func (ctx *Context) SetHasTimeout() {
ctx.hasTimeout = true
}
// HasTimeout reports the current value of the timeout flag. The read itself
// takes no lock.
func (ctx *Context) HasTimeout() bool {
return ctx.hasTimeout
}
// #endregion
// BaseContext returns the context.Context carried by the stored request
// (it calls request.Context() rather than reading the ctx field).
func (ctx *Context) BaseContext() context.Context {
return ctx.request.Context()
}
// #region implement context.Context
// Deadline implements context.Context by delegating to BaseContext().
func (ctx *Context) Deadline() (deadline time.Time, ok bool) {
return ctx.BaseContext().Deadline()
}
// Done implements context.Context by delegating to BaseContext().
func (ctx *Context) Done() <-chan struct{} {
return ctx.BaseContext().Done()
}
// Err implements context.Context by delegating to BaseContext().
func (ctx *Context) Err() error {
return ctx.BaseContext().Err()
}
// Value implements context.Context by delegating to BaseContext().
func (ctx *Context) Value(key interface{}) interface{} {
return ctx.BaseContext().Value(key)
}
// #endregion
// #region query url
// QueryInt returns the last value of the URL query parameter key, parsed as
// an int. def is returned when the key is absent, has no values, or its last
// value is not a valid integer.
func (ctx *Context) QueryInt(key string, def int) int {
	vals, found := ctx.QueryAll()[key]
	if !found || len(vals) == 0 {
		return def
	}
	n, err := strconv.Atoi(vals[len(vals)-1])
	if err != nil {
		return def
	}
	return n
}
// QueryString returns the last value of the URL query parameter key, or def
// when the key is absent or has no values.
func (ctx *Context) QueryString(key string, def string) string {
	vals, found := ctx.QueryAll()[key]
	if !found || len(vals) == 0 {
		return def
	}
	return vals[len(vals)-1]
}
// QueryArray returns every value of the URL query parameter key, or def when
// the key is not present in the query string.
func (ctx *Context) QueryArray(key string, def []string) []string {
	vals, found := ctx.QueryAll()[key]
	if !found {
		return def
	}
	return vals
}
// QueryAll returns the parsed URL query parameters of the stored request.
// An empty map is returned when the context holds no request.
func (ctx *Context) QueryAll() map[string][]string {
	if ctx.request == nil {
		return map[string][]string{}
	}
	// url.Values has the underlying type map[string][]string, so it can be
	// returned directly.
	return ctx.request.URL.Query()
}
// #endregion
// #region form post
// FormInt returns the last value of the form field key, parsed as an int.
// def is returned when the field is absent, has no values, or its last value
// is not a valid integer.
func (ctx *Context) FormInt(key string, def int) int {
	vals, found := ctx.FormAll()[key]
	if !found || len(vals) == 0 {
		return def
	}
	n, err := strconv.Atoi(vals[len(vals)-1])
	if err != nil {
		return def
	}
	return n
}
// FormString returns the last value of the form field key, or def when the
// field is absent or has no values.
func (ctx *Context) FormString(key string, def string) string {
	vals, found := ctx.FormAll()[key]
	if !found || len(vals) == 0 {
		return def
	}
	return vals[len(vals)-1]
}
// FormArray returns every value of the form field key, or def when the field
// is not present in the form data.
//
// The lookup goes through FormAll so that it reads form data, consistent
// with FormInt and FormString; reading the URL query here would return
// query-string values instead of form values.
func (ctx *Context) FormArray(key string, def []string) []string {
	params := ctx.FormAll()
	if vals, ok := params[key]; ok {
		return vals
	}
	return def
}
// FormAll returns the parsed POST/PUT/PATCH form fields of the stored
// request. An empty map is returned when the context holds no request.
func (ctx *Context) FormAll() map[string][]string {
	if ctx.request == nil {
		return map[string][]string{}
	}
	// Request.PostForm is only populated once the form has been parsed, so
	// parse it here; ParseForm is idempotent, making repeated calls cheap.
	// The signature has no error result, so a parse failure is not surfaced
	// and whatever could be parsed is returned.
	_ = ctx.request.ParseForm()
	return ctx.request.PostForm
}
// #endregion
// #region application/json post
// BindJson reads the whole request body and unmarshals it as JSON into obj.
// The body is replaced with a fresh reader over the same bytes so it can be
// read again afterwards. It returns an error when the context holds no
// request, when reading the body fails, or when the JSON cannot be decoded.
func (ctx *Context) BindJson(obj interface{}) error {
	if ctx.request == nil {
		return errors.New("ctx.request empty")
	}

	raw, err := ioutil.ReadAll(ctx.request.Body)
	if err != nil {
		return err
	}
	// Restore the body before decoding so later readers still see the data.
	ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(raw))

	return json.Unmarshal(raw, obj)
}
// #endregion
// #region response
// Json serialises obj as JSON and writes it to the response with the given
// HTTP status and an application/json Content-Type.
//
// Nothing is written, and nil is returned, when the timeout flag is already
// set. If obj cannot be marshalled, a 500 status is sent and the marshal
// error is returned. A failure to write the body is returned as well.
func (ctx *Context) Json(status int, obj interface{}) error {
	if ctx.HasTimeout() {
		return nil
	}
	ctx.responseWriter.Header().Set("Content-Type", "application/json")

	// Marshal before committing the status line: only the first WriteHeader
	// call takes effect, so the 500 must be decided before status is sent.
	byt, err := json.Marshal(obj)
	if err != nil {
		ctx.responseWriter.WriteHeader(http.StatusInternalServerError)
		return err
	}

	ctx.responseWriter.WriteHeader(status)
	_, err = ctx.responseWriter.Write(byt)
	return err
}
// HTML is currently a stub: it ignores its arguments, writes nothing and
// always returns nil.
func (ctx *Context) HTML(status int, obj interface{}, template string) error {
return nil
}
// Text is currently a stub: it ignores its arguments, writes nothing and
// always returns nil.
func (ctx *Context) Text(status int, obj string) error {
return nil
}
package basecontext
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"time"
)
// ControllerHandler is the signature of a request handler: it receives the
// per-request Context and returns an error.
type ControllerHandler func(c *Context) error
// FooControllerHandler is a demo handler that runs its work in a separate
// goroutine and races it against a one-second timeout.
//
// Exactly one of three outcomes is handled:
//   - the worker panics: the panic value is logged and a 500 "panic" JSON
//     response is written;
//   - the worker finishes: it has already written a 200 "ok" JSON response;
//   - the timeout fires first: a 500 "time out" JSON response is written and
//     the context is flagged as timed out so later writes are suppressed.
//
// Every write to the response, including the worker's, happens while holding
// c.WriterMux(). That serialises the writes and makes the timeout flag
// visible to the worker before it attempts its own write.
//
// Note that the worker goroutine is not cancelled on timeout; it keeps
// running until its work completes.
func FooControllerHandler(c *Context) error {
	finish := make(chan struct{}, 1)
	panicChan := make(chan interface{}, 1)

	durationCtx, cancel := context.WithTimeout(c.BaseContext(), time.Second)
	defer cancel()

	go func() {
		defer func() {
			if p := recover(); p != nil {
				panicChan <- p
			}
		}()
		// Do real action
		time.Sleep(10 * time.Second)

		// Take the writer lock before touching the response so this write
		// cannot interleave with, or follow, the timeout response. The
		// deferred Unlock runs before the recover above, so a panic while
		// writing does not leave the mutex locked.
		c.WriterMux().Lock()
		defer c.WriterMux().Unlock()
		c.Json(200, "ok")
		// finish is buffered, so this send never blocks while the lock is held.
		finish <- struct{}{}
	}()

	select {
	case p := <-panicChan:
		c.WriterMux().Lock()
		defer c.WriterMux().Unlock()
		log.Println(p)
		return c.Json(500, "panic")
	case <-finish:
		fmt.Println("finish")
	case <-durationCtx.Done():
		c.WriterMux().Lock()
		defer c.WriterMux().Unlock()
		err := c.Json(500, "time out")
		// Set the flag while still holding the lock so the worker observes it
		// as soon as it acquires the mutex.
		c.SetHasTimeout()
		return err
	}
	return nil
}
// Foo is a plain net/http-style demo handler. It reads the form field "foo"
// (defaulting to "10" when empty), converts it to an int and writes a JSON
// object with the fixed fields errno=50001 and errmsg="inner error" plus the
// parsed number under "data", with status 200. If the field is not a valid
// integer, or the JSON cannot be produced, it responds with status 500 and no
// body.
func Foo(request *http.Request, response http.ResponseWriter) {
	response.Header().Set("Content-Type", "application/json")

	raw := request.PostFormValue("foo")
	if raw == "" {
		raw = "10"
	}
	value, err := strconv.Atoi(raw)
	if err != nil {
		response.WriteHeader(500)
		return
	}

	payload, err := json.Marshal(map[string]interface{}{
		"errno":  50001,
		"errmsg": "inner error",
		"data":   value,
	})
	if err != nil {
		response.WriteHeader(500)
		return
	}

	response.WriteHeader(200)
	response.Write(payload)
}
// func Foo2(ctx *framework.Context) error {
// obj := map[string]interface{}{
// "errno": 50001,
// "errmsg": "inner error",
// "data": nil,
// }
// fooInt := ctx.FormInt("foo", 10)
// obj["data"] = fooInt
// return ctx.Json(http.StatusOK, obj)
// }
// func Foo3(ctx *framework.Context) error {
// rdb := redis.NewClient(&redis.Options{
// Addr: "localhost:6379",
// Password: "", // no password set
// DB: 0, // use default DB
// })
// return rdb.Set(ctx, "key", "value", 0).Err()
// }
package basecontext
import (
"log"
"net/http"
)
// Core holds the registered handlers and serves HTTP requests through its
// ServeHTTP method.
type Core struct {
// router maps a URL key to the handler registered for it.
router map[string]ControllerHandler
}
// NewCore returns a Core whose route table is initialised and empty.
func NewCore() *Core {
	core := new(Core)
	core.router = make(map[string]ControllerHandler)
	return core
}
// Get registers handler under the key url in the route table, replacing any
// handler previously stored under the same key. No HTTP method is recorded.
func (c *Core) Get(url string, handler ControllerHandler) {
c.router[url] = handler
}
// ServeHTTP implements http.Handler. It wraps the request and response in a
// Context, looks up the handler and invokes it. When no handler is
// registered the request is left unanswered by this method. An error
// returned by the handler is logged.
func (c *Core) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	// This is a plain trace line; log.Println is used because log.Panicln
	// would panic here and abort every request before routing.
	log.Println("core.serveHTTP")
	ctx := NewContext(request, response)

	// A simple route selector; hard-coded to the test route "foo" for now.
	router := c.router["foo"]
	if router == nil {
		return
	}
	log.Println("core.router")

	if err := router(ctx); err != nil {
		log.Println("core.router error:", err)
	}
}
package basecontext
// RegisterRouter registers FooControllerHandler on core under the key "foo".
func RegisterRouter(core *Core) {
// core.Get("foo", framework.TimeoutHandler(FooControllerHandler, time.Second*1))
core.Get("foo", FooControllerHandler)
}
package example
import (
"github.com/programmerug/fibergorm/basecontext"
"net/http"
)
// test wires up a Core with the package's routes and serves it on :8888.
// It blocks in ListenAndServe; the error that call returns is discarded.
func test() {
	core := basecontext.NewCore()
	basecontext.RegisterRouter(core)

	server := http.Server{Addr: ":8888", Handler: core}
	server.ListenAndServe()
}