mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Take write lock for rate limit map (#1532)
* Take write lock for rate limit map * Fix potential race condition
This commit is contained in:
parent
4a7fb9c045
commit
640e8c50ec
1 changed files with 20 additions and 10 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
type rateLimits struct {
|
type rateLimits struct {
|
||||||
limits map[string]chan struct{}
|
limits map[string]chan struct{}
|
||||||
limitsMutex sync.RWMutex
|
limitsMutex sync.RWMutex
|
||||||
|
cleanMutex sync.RWMutex
|
||||||
enabled bool
|
enabled bool
|
||||||
requestThreshold int64
|
requestThreshold int64
|
||||||
cooloffDuration time.Duration
|
cooloffDuration time.Duration
|
||||||
|
@ -38,6 +39,7 @@ func (l *rateLimits) clean() {
|
||||||
// empty. If they are then we will close and delete them,
|
// empty. If they are then we will close and delete them,
|
||||||
// freeing up memory.
|
// freeing up memory.
|
||||||
time.Sleep(time.Second * 30)
|
time.Sleep(time.Second * 30)
|
||||||
|
l.cleanMutex.Lock()
|
||||||
l.limitsMutex.Lock()
|
l.limitsMutex.Lock()
|
||||||
for k, c := range l.limits {
|
for k, c := range l.limits {
|
||||||
if len(c) == 0 {
|
if len(c) == 0 {
|
||||||
|
@ -46,6 +48,7 @@ func (l *rateLimits) clean() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
l.limitsMutex.Unlock()
|
l.limitsMutex.Unlock()
|
||||||
|
l.cleanMutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,12 +58,12 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock the map long enough to check for rate limiting. We hold it
|
// Take a read lock out on the cleaner mutex. The cleaner expects to
|
||||||
// for longer here than we really need to but it makes sure that we
|
// be able to take a write lock, which isn't possible while there are
|
||||||
// also don't conflict with the cleaner goroutine which might clean
|
// readers, so this has the effect of blocking the cleaner goroutine
|
||||||
// up a channel after we have retrieved it otherwise.
|
// from doing its work until there are no requests in flight.
|
||||||
l.limitsMutex.RLock()
|
l.cleanMutex.RLock()
|
||||||
defer l.limitsMutex.RUnlock()
|
defer l.cleanMutex.RUnlock()
|
||||||
|
|
||||||
// First of all, work out if X-Forwarded-For was sent to us. If not
|
// First of all, work out if X-Forwarded-For was sent to us. If not
|
||||||
// then we'll just use the IP address of the caller.
|
// then we'll just use the IP address of the caller.
|
||||||
|
@ -69,12 +72,19 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
|
||||||
caller = forwardedFor
|
caller = forwardedFor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the caller's channel, if they have one. If they don't then
|
// Look up the caller's channel, if they have one.
|
||||||
// let's create one.
|
l.limitsMutex.RLock()
|
||||||
rateLimit, ok := l.limits[caller]
|
rateLimit, ok := l.limits[caller]
|
||||||
|
l.limitsMutex.RUnlock()
|
||||||
|
|
||||||
|
// If the caller doesn't have a channel, create one and write it
|
||||||
|
// back to the map.
|
||||||
if !ok {
|
if !ok {
|
||||||
l.limits[caller] = make(chan struct{}, l.requestThreshold)
|
rateLimit = make(chan struct{}, l.requestThreshold)
|
||||||
rateLimit = l.limits[caller]
|
|
||||||
|
l.limitsMutex.Lock()
|
||||||
|
l.limits[caller] = rateLimit
|
||||||
|
l.limitsMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the user has got free resource slots for this request.
|
// Check if the user has got free resource slots for this request.
|
||||||
|
|
Loading…
Reference in a new issue