-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.go
More file actions
128 lines (118 loc) · 3.71 KB
/
server.go
File metadata and controls
128 lines (118 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package httpproxy
import (
"errors"
"net"
"net/http"
"net/url"
"slices"
"sync"
)
var (
	// MaxCachedRoundTrippers is the cache size at which a Server trims its
	// per-dialer round tripper cache before adding a new entry; trimming
	// keeps only the MaxCachedRoundTrippers/2 most recently used entries.
	MaxCachedRoundTrippers = 100

	// DefaultContextDialer is the dialer used for proxy requests when the
	// Server has no DialerSelector configured.
	DefaultContextDialer ContextDialer = &net.Dialer{}

	// ErrUnauthorized is returned when the Server's CredentialsValidator
	// rejects the credentials supplied with a proxy request.
	ErrUnauthorized = errors.New("unauthorized")
)
// Server is an http.Handler for proxy traffic. ServeHTTP routes CONNECT
// requests and absolute-URL requests to the proxy code paths and everything
// else to Handler (404 Not Found when Handler is nil). Round trippers used
// for outbound requests are cached per ContextDialer.
type Server struct {
	// Logger is an optional logger to use.
	Logger Logger

	// Handler optionally serves requests that are not proxy requests.
	Handler http.Handler

	// DialerSelector optionally selects the ContextDialer for each proxy
	// request; when nil, DefaultContextDialer is used.
	DialerSelector DialerSelector

	// CredentialsValidator optionally validates the credentials sent with
	// proxy requests.
	CredentialsValidator CredentialsValidator

	// RoundTripperMaker optionally builds the round tripper for a dialer;
	// when nil, DefaultMakeRoundTripper is used.
	RoundTripperMaker RoundTripperMaker

	// mu guards counter and trippers.
	mu sync.Mutex

	// counter is incremented on every ensureTripper call and stamped on the
	// entry it returns, giving each cache entry a recency value.
	counter int64

	// trippers is the per-dialer round tripper cache; when it fills up, all
	// but the most recently used entries are evicted.
	trippers map[ContextDialer]*roundTripperCache
}
// ServeHTTP implements http.Handler. CONNECT requests go to srv.connect and
// requests with an absolute URL go to srv.proxy. Any other request is served
// by srv.Handler, or answered with 404 Not Found when Handler is nil.
func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method == http.MethodConnect:
		srv.connect(w, r)
	case r.URL.IsAbs():
		srv.proxy(w, r)
	case srv.Handler != nil:
		srv.Handler.ServeHTTP(w, r)
	default:
		w.WriteHeader(http.StatusNotFound)
	}
}
// DefaultMakeRoundTripper returns a clone of http.DefaultTransport whose
// DialContext is replaced by cd.DialContext, so every connection the
// returned transport opens is dialed through cd. It panics if
// http.DefaultTransport has been replaced by something other than an
// *http.Transport.
func DefaultMakeRoundTripper(cd ContextDialer) http.RoundTripper {
	base := http.DefaultTransport.(*http.Transport)
	transport := base.Clone()
	transport.DialContext = cd.DialContext
	return transport
}
// cleanTripperCacheLocked evicts all but the MaxCachedRoundTrippers/2 most
// recently used entries from srv.trippers. Recency is the counter value
// that ensureTripper stamps on an entry each time it is returned (higher
// means more recent). Idle connections held by evicted round trippers are
// closed so they are not left open until their idle timeout expires.
//
// The caller must hold srv.mu.
func (srv *Server) cleanTripperCacheLocked() {
	type cacheEntry struct {
		cd  ContextDialer
		rtc *roundTripperCache
	}
	entries := make([]cacheEntry, 0, len(srv.trippers))
	for cd, rtc := range srv.trippers {
		entries = append(entries, cacheEntry{cd: cd, rtc: rtc})
	}
	// Most recently used first. Compare rather than subtract: converting an
	// int64 difference to int can overflow (and flip sign) on 32-bit
	// platforms, which would mis-order the entries.
	slices.SortFunc(entries, func(a, b cacheEntry) int {
		switch {
		case a.rtc.counter > b.rtc.counter:
			return -1
		case a.rtc.counter < b.rtc.counter:
			return 1
		default:
			return 0
		}
	})
	keep := MaxCachedRoundTrippers / 2
	for i, e := range entries {
		if i < keep {
			continue
		}
		delete(srv.trippers, e.cd)
		// Same optional-interface check http.Client.CloseIdleConnections
		// performs; only idle connections are closed, in-flight requests
		// on this round tripper are unaffected.
		if ci, ok := e.rtc.RoundTripper.(interface{ CloseIdleConnections() }); ok {
			ci.CloseIdleConnections()
		}
	}
}
// ensureTripper returns the http.RoundTripper cached for cd, creating it on
// first use with srv.RoundTripperMaker (or DefaultMakeRoundTripper when no
// maker is configured). Before inserting into a cache that already holds
// at least MaxCachedRoundTrippers entries, cleanTripperCacheLocked trims it.
// Every call advances srv.counter and records the new value on the entry,
// which is the recency used for eviction. All cache access happens under
// srv.mu.
func (srv *Server) ensureTripper(cd ContextDialer) (rt http.RoundTripper) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	entry := srv.trippers[cd]
	if entry == nil {
		if srv.trippers == nil {
			srv.trippers = make(map[ContextDialer]*roundTripperCache)
		}
		if len(srv.trippers) >= MaxCachedRoundTrippers {
			srv.cleanTripperCacheLocked()
		}
		if maker := srv.RoundTripperMaker; maker != nil {
			rt = maker.MakeRoundTripper(cd)
		} else {
			rt = DefaultMakeRoundTripper(cd)
		}
		entry = &roundTripperCache{RoundTripper: rt}
		srv.trippers[cd] = entry
	}

	srv.counter++
	entry.counter = srv.counter
	return entry.RoundTripper
}
// getAddress returns the "host:port" dial address for u.
//
// If u carries an explicit port, u.Host is returned unchanged. Otherwise the
// default port is added for the http/ws (80) and https/wss (443) schemes;
// for any other scheme u.Host is returned as-is.
//
// The port is joined to u.Hostname() with net.JoinHostPort rather than
// appended to u.Host. That way a host with an empty port such as
// "example.com:" yields "example.com:80" instead of the undialable
// "example.com::80", and IPv6 literals keep their brackets ("[::1]:80").
func getAddress(u *url.URL) (address string) {
	if u.Port() != "" {
		return u.Host
	}
	var port string
	switch u.Scheme {
	case "http", "ws":
		port = "80"
	case "https", "wss":
		port = "443"
	default:
		return u.Host
	}
	return net.JoinHostPort(u.Hostname(), port)
}
// getDialer computes the dial address for r and selects the ContextDialer
// used to reach it.
//
// When srv.CredentialsValidator is set, the username and password returned
// by GetBasicAuth(r.Header) must be accepted for the address. An error from
// GetBasicAuth is returned as-is, and a rejection yields ErrUnauthorized;
// cd is nil in both cases. The username (empty when no validator is
// configured) is then passed to srv.DialerSelector, or DefaultContextDialer
// is returned when no selector is set. address is populated on every path,
// including errors.
func (srv *Server) getDialer(r *http.Request) (cd ContextDialer, address string, err error) {
	address = getAddress(r.URL)

	var username string
	if validator := srv.CredentialsValidator; validator != nil {
		var password string
		username, password, err = GetBasicAuth(r.Header)
		if err != nil {
			return nil, address, err
		}
		if !validator.ValidateCredentials(username, password, address) {
			return nil, address, ErrUnauthorized
		}
	}

	if srv.DialerSelector == nil {
		return DefaultContextDialer, address, nil
	}
	cd, err = srv.DialerSelector.SelectDialer(username, "tcp", address)
	return cd, address, err
}
// getRoundTripper returns the cached round tripper for the dialer that
// getDialer selects for r. If getDialer fails, it instead returns a
// fakeRoundTripper holding that error (presumably surfaced when the
// request is round-tripped; see fakeRoundTripper).
func (srv *Server) getRoundTripper(r *http.Request) (rt http.RoundTripper) {
	cd, _, err := srv.getDialer(r)
	if err != nil {
		return fakeRoundTripper{err: err}
	}
	return srv.ensureTripper(cd)
}