Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ var (
type contextKey int

// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
const userContextKey contextKey = 0 // __local_user_context__
const (
userContextKey contextKey = iota // __local_user_context__
)

// DefaultCtx is the default implementation of the Ctx interface
// generation tool `go install github.com/vburenin/ifacemaker@f30b6f9bdbed4b5c4804ec9ba4a04a999525c202`
// https://github.com/vburenin/ifacemaker/blob/f30b6f9bdbed4b5c4804ec9ba4a04a999525c202/ifacemaker.go#L14-L31
//
//go:generate ifacemaker --file ctx.go --file req.go --file res.go --struct DefaultCtx --iface Ctx --pkg fiber --promoted --output ctx_interface_gen.go --not-exported true --iface-comment "Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on."
type DefaultCtx struct {
handlerCtx CustomCtx // Active custom context implementation, if any
DefaultReq // Default request api
DefaultRes // Default response api
app *App // Reference to *App
Expand Down Expand Up @@ -227,21 +230,45 @@ func (c *DefaultCtx) Next() error {
// Did we execute all route handlers?
if c.indexHandler < len(c.route.Handlers) {
// Continue route stack
return c.route.Handlers[c.indexHandler](c)
return c.route.Handlers[c.indexHandler](activeHandler(c))
}

// Continue handler stack
_, err := c.app.next(c)
return err
return continueHandlers(c)
}

// RestartRouting instead of going to the next handler. This may be useful after
// changing the request path. Note that handlers might be executed again.
func (c *DefaultCtx) RestartRouting() error {
var err error

c.indexRoute = -1
_, err = c.app.next(c)
return continueHandlers(c)
}

func (c *DefaultCtx) setHandlerCtx(ctx CustomCtx) {
if ctx == nil {
c.handlerCtx = nil
return
}
if defaultCtx, ok := ctx.(*DefaultCtx); ok && defaultCtx == c {
c.handlerCtx = nil
return
}
c.handlerCtx = ctx
}

func activeHandler(c *DefaultCtx) Ctx {
if c.handlerCtx != nil {
return c.handlerCtx
}
return c
}

func continueHandlers(c *DefaultCtx) error {
if c.handlerCtx != nil {
_, err := c.app.nextCustom(c.handlerCtx)
return err
}
_, err := c.app.next(c)
return err
}

Expand Down Expand Up @@ -587,6 +614,7 @@ func (c *DefaultCtx) release() {
c.redirect = nil
}
c.skipNonUseRoutes = false
c.handlerCtx = nil
c.DefaultReq.release()
c.DefaultRes.release()
}
Expand Down
5 changes: 5 additions & 0 deletions ctx_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) CustomCtx {
if !ok {
panic(errors.New("failed to type-assert to CustomCtx"))
}

if setter, ok := ctx.(interface{ setHandlerCtx(CustomCtx) }); ok {
setter.setHandlerCtx(ctx)
}

ctx.Reset(fctx)

return ctx
Expand Down
1 change: 1 addition & 0 deletions ctx_interface_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,36 @@ func Test_Ctx_CustomCtx(t *testing.T) {
require.Equal(t, int64(len(body)), resp.ContentLength)
}

func Test_Ctx_CustomCtx_WithMiddleware(t *testing.T) {
t.Parallel()

app := NewWithCustomCtx(func(app *App) CustomCtx {
return &customCtx{
DefaultCtx: *NewDefaultCtx(app),
}
})

app.Use(func(c Ctx) error {
_, ok := c.(*customCtx)
require.True(t, ok)
return c.Next()
})

app.Get("/", func(c Ctx) error {
custom, ok := c.(*customCtx)
require.True(t, ok)
return c.SendString(custom.Params(""))
})

resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
require.NoError(t, err, "app.Test(req)")
defer func() { require.NoError(t, resp.Body.Close()) }()

body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "io.ReadAll(resp.Body)")
require.Equal(t, "prefix_", string(body))
}

// go test -run Test_Ctx_CustomCtx
func Test_Ctx_CustomCtx_and_Method(t *testing.T) {
t.Parallel()
Expand Down
Loading