|
1 | | -//go:generate mockgen -package scan -destination=mock_request_test.go -source request.go |
| 1 | +//go:generate mockgen -package scan -destination=mock_request_test.go . PortGenerator,IPGenerator,RequestGenerator,IPContainer |
2 | 2 | //go:generate easyjson -output_filename request_easyjson.go request.go |
3 | 3 |
|
4 | 4 | package scan |
@@ -31,37 +31,67 @@ type Request struct { |
31 | 31 | Err error |
32 | 32 | } |
33 | 33 |
|
| 34 | +type PortGetter interface { |
| 35 | + GetPort() (uint16, error) |
| 36 | +} |
| 37 | + |
| 38 | +type WrapPort uint16 |
| 39 | + |
| 40 | +func (p WrapPort) GetPort() (uint16, error) { |
| 41 | + return uint16(p), nil |
| 42 | +} |
| 43 | + |
| 44 | +type portError struct { |
| 45 | + error |
| 46 | +} |
| 47 | + |
| 48 | +func (err *portError) GetPort() (uint16, error) { |
| 49 | + return 0, err |
| 50 | +} |
| 51 | + |
34 | 52 | type PortGenerator interface { |
35 | | - Ports(ctx context.Context, r *Range) (<-chan uint16, error) |
| 53 | + Ports(ctx context.Context, r *Range) (<-chan PortGetter, error) |
36 | 54 | } |
37 | 55 |
|
38 | 56 | func NewPortGenerator() PortGenerator { |
39 | 57 | return &portGenerator{} |
40 | 58 | } |
41 | 59 |
|
42 | | -// TODO randomizedPortGenerator |
43 | 60 | type portGenerator struct{} |
44 | 61 |
|
45 | | -func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error) { |
| 62 | +func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan PortGetter, error) { |
46 | 63 | if err := validatePorts(r.Ports); err != nil { |
47 | 64 | return nil, err |
48 | 65 | } |
49 | | - out := make(chan uint16, 100) |
| 66 | + out := make(chan PortGetter, 100) |
50 | 67 | go func() { |
51 | 68 | defer close(out) |
52 | 69 | for _, portRange := range r.Ports { |
53 | | - for port := int(portRange.StartPort); port <= int(portRange.EndPort); port++ { |
54 | | - select { |
55 | | - case <-ctx.Done(): |
56 | | - return |
57 | | - case out <- uint16(port): |
| 70 | + it, err := newRangeIterator(int64(portRange.EndPort) - int64(portRange.StartPort) + 1) |
| 71 | + if err != nil { |
| 72 | + writePort(ctx, out, &portError{err}) |
| 73 | + continue |
| 74 | + } |
| 75 | + basePort := int64(portRange.StartPort) - 1 |
| 76 | + for { |
| 77 | + writePort(ctx, out, WrapPort(basePort+it.Int().Int64())) |
| 78 | + if !it.Next() { |
| 79 | + break |
58 | 80 | } |
59 | 81 | } |
60 | 82 | } |
61 | 83 | }() |
62 | 84 | return out, nil |
63 | 85 | } |
64 | 86 |
|
| 87 | +func writePort(ctx context.Context, out chan<- PortGetter, port PortGetter) { |
| 88 | + select { |
| 89 | + case <-ctx.Done(): |
| 90 | + return |
| 91 | + case out <- port: |
| 92 | + } |
| 93 | +} |
| 94 | + |
65 | 95 | func validatePorts(ports []*PortRange) error { |
66 | 96 | if len(ports) == 0 { |
67 | 97 | return ErrPortRange |
@@ -153,7 +183,12 @@ func (rg *ipPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-ch |
153 | 183 | out := make(chan *Request, 100) |
154 | 184 | go func() { |
155 | 185 | defer close(out) |
156 | | - for port := range ports { |
| 186 | + for p := range ports { |
| 187 | + port, err := p.GetPort() |
| 188 | + if err != nil { |
| 189 | + writeRequest(ctx, out, &Request{Err: err}) |
| 190 | + continue |
| 191 | + } |
157 | 192 | for ipaddr := range ips { |
158 | 193 | dstip, err := ipaddr.GetIP() |
159 | 194 | writeRequest(ctx, out, &Request{ |
|
0 commit comments