Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,18 @@ var (
type contextKey int

// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx
const userContextKey contextKey = 0 // __local_user_context__
const (
userContextKey contextKey = 0 // __local_user_context__
handlerContextKey contextKey = 1 // __local_handler_context__
)

// DefaultCtx is the default implementation of the Ctx interface
// generation tool `go install github.com/vburenin/ifacemaker@f30b6f9bdbed4b5c4804ec9ba4a04a999525c202`
// https://github.com/vburenin/ifacemaker/blob/f30b6f9bdbed4b5c4804ec9ba4a04a999525c202/ifacemaker.go#L14-L31
//
//go:generate ifacemaker --file ctx.go --file req.go --file res.go --struct DefaultCtx --iface Ctx --pkg fiber --promoted --output ctx_interface_gen.go --not-exported true --iface-comment "Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on."
type DefaultCtx struct {
customCtx CustomCtx // Active custom context implementation, if any
DefaultReq // Default request api
DefaultRes // Default response api
app *App // Reference to *App
Expand Down Expand Up @@ -227,10 +231,19 @@ func (c *DefaultCtx) Next() error {
// Did we execute all route handlers?
if c.indexHandler < len(c.route.Handlers) {
// Continue route stack
return c.route.Handlers[c.indexHandler](c)
handler := Ctx(c)
if c.customCtx != nil {
handler = c.customCtx
}
return c.route.Handlers[c.indexHandler](handler)
}

// Continue handler stack
if c.customCtx != nil {
_, err := c.app.nextCustom(c.customCtx)
return err
}

_, err := c.app.next(c)
return err
}
Expand All @@ -241,6 +254,11 @@ func (c *DefaultCtx) RestartRouting() error {
var err error

c.indexRoute = -1
if c.customCtx != nil {
_, err = c.app.nextCustom(c.customCtx)
return err
}

_, err = c.app.next(c)
return err
Comment thread
ReneWerner87 marked this conversation as resolved.
Outdated
}
Expand Down Expand Up @@ -569,7 +587,13 @@ func (c *DefaultCtx) Reset(fctx *fasthttp.RequestCtx) {

c.DefaultReq.c = c
c.DefaultRes.c = c
if ctx, ok := fctx.UserValue(handlerContextKey).(CustomCtx); ok {
c.customCtx = ctx
} else {
c.customCtx = nil
}
Comment thread
ReneWerner87 marked this conversation as resolved.
Outdated
c.fasthttp.SetUserValue(userContextKey, nil)
c.fasthttp.SetUserValue(handlerContextKey, nil)
}

// Release is a method to reset context fields when to use ReleaseCtx()
Expand All @@ -587,6 +611,7 @@ func (c *DefaultCtx) release() {
c.redirect = nil
}
c.skipNonUseRoutes = false
c.customCtx = nil
c.DefaultReq.release()
c.DefaultRes.release()
}
Expand Down
7 changes: 7 additions & 0 deletions ctx_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) CustomCtx {
if !ok {
panic(errors.New("failed to type-assert to CustomCtx"))
}

if _, isDefault := ctx.(*DefaultCtx); isDefault {
fctx.SetUserValue(handlerContextKey, nil)
} else {
fctx.SetUserValue(handlerContextKey, ctx)
}

ctx.Reset(fctx)

return ctx
Expand Down
30 changes: 30 additions & 0 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,36 @@ func Test_Ctx_CustomCtx(t *testing.T) {
require.Equal(t, int64(len(body)), resp.ContentLength)
}

func Test_Ctx_CustomCtx_WithMiddleware(t *testing.T) {
t.Parallel()

app := NewWithCustomCtx(func(app *App) CustomCtx {
return &customCtx{
DefaultCtx: *NewDefaultCtx(app),
}
})

app.Use(func(c Ctx) error {
_, ok := c.(*customCtx)
require.True(t, ok)
return c.Next()
})

app.Get("/", func(c Ctx) error {
custom, ok := c.(*customCtx)
require.True(t, ok)
return c.SendString(custom.Params(""))
})

resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil))
require.NoError(t, err, "app.Test(req)")
defer func() { require.NoError(t, resp.Body.Close()) }()

body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "io.ReadAll(resp.Body)")
require.Equal(t, "prefix_", string(body))
}

// go test -run Test_Ctx_CustomCtx
func Test_Ctx_CustomCtx_and_Method(t *testing.T) {
t.Parallel()
Expand Down
Loading