Skip to content
This repository was archived by the owner on Jul 18, 2025. It is now read-only.

Commit 3e33f1e

Browse files
committed
Split stream into separate type
Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
1 parent 9622304 commit 3e33f1e

4 files changed

Lines changed: 330 additions & 346 deletions

File tree

spdy/encode.go

Lines changed: 164 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@ import (
55
"errors"
66
"io"
77
"reflect"
8+
"time"
89

910
"github.com/dmcgowan/msgpack"
1011
"github.com/docker/libchan"
1112
)
1213

14+
const (
15+
duplexStreamCode = 1
16+
inboundStreamCode = 2
17+
outboundStreamCode = 3
18+
inboundChannelCode = 4
19+
outboundChannelCode = 5
20+
timeCode = 6
21+
)
22+
1323
func decodeReferenceID(b []byte) (referenceID uint64, err error) {
1424
if len(b) == 8 {
1525
referenceID = binary.BigEndian.Uint64(b)
@@ -32,21 +42,14 @@ func encodeReferenceID(b []byte, referenceID uint64) (n int) {
3242
return
3343
}
3444

35-
func (c *channel) channelBytes() ([]byte, error) {
36-
buf := make([]byte, 9)
37-
if c.direction == inbound {
38-
buf[0] = 0x02 // Reverse direction
39-
} else if c.direction == outbound {
40-
buf[0] = 0x01 // Reverse direction
41-
} else {
42-
return nil, errors.New("invalid direction")
43-
}
44-
written := encodeReferenceID(buf[1:], c.referenceID)
45-
return buf[:(written + 1)], nil
45+
func (s *stream) channelBytes() ([]byte, error) {
46+
buf := make([]byte, 8)
47+
written := encodeReferenceID(buf, s.referenceID)
48+
return buf[:written], nil
4649
}
4750

48-
func (c *channel) copySendChannel(send libchan.Sender) (*channel, error) {
49-
recv, sendCopy, err := c.CreateNestedReceiver()
51+
func (s *stream) copySendChannel(send libchan.Sender) (*nopSender, error) {
52+
recv, sendCopy, err := s.CreateNestedReceiver()
5053
if err != nil {
5154
return nil, err
5255
}
@@ -55,11 +58,11 @@ func (c *channel) copySendChannel(send libchan.Sender) (*channel, error) {
5558
libchan.Copy(send, recv)
5659
send.Close()
5760
}()
58-
return sendCopy.(*channel), nil
61+
return sendCopy.(*nopSender), nil
5962
}
6063

61-
func (c *channel) copyReceiveChannel(recv libchan.Receiver) (*channel, error) {
62-
send, recvCopy, err := c.CreateNestedSender()
64+
func (s *stream) copyReceiveChannel(recv libchan.Receiver) (*nopReceiver, error) {
65+
send, recvCopy, err := s.CreateNestedSender()
6366
if err != nil {
6467
return nil, err
6568
}
@@ -68,53 +71,134 @@ func (c *channel) copyReceiveChannel(recv libchan.Receiver) (*channel, error) {
6871
libchan.Copy(send, recv)
6972
send.Close()
7073
}()
71-
return recvCopy.(*channel), nil
74+
return recvCopy.(*nopReceiver), nil
7275
}
7376

74-
func (c *channel) decodeChannel(v reflect.Value, b []byte) error {
75-
var d direction
76-
if b[0] == 0x01 {
77-
d = inbound
78-
} else if b[0] == 0x02 {
79-
d = outbound
80-
} else {
81-
return errors.New("unexpected direction")
77+
func (s *stream) decodeStream(b []byte) (*stream, error) {
78+
referenceID, err := decodeReferenceID(b)
79+
if err != nil {
80+
return nil, err
81+
}
82+
83+
gs := s.session.getStream(referenceID)
84+
if gs == nil {
85+
return nil, errors.New("stream does not exist")
8286
}
83-
referenceID, err := decodeReferenceID(b[1:])
87+
88+
return gs, nil
89+
}
90+
91+
func (s *stream) decodeReceiver(v reflect.Value, b []byte) error {
92+
bs, err := s.decodeStream(b)
8493
if err != nil {
8594
return err
8695
}
8796

88-
ch := c.session.getChannel(referenceID)
89-
if ch == nil {
90-
return errors.New("channel does not exist")
91-
}
92-
// TODO lock channel while check and setting
93-
if ch.direction != 0 && ch.direction != d {
94-
return ErrWrongDirection
97+
v.Set(reflect.ValueOf(&receiver{stream: bs}))
98+
99+
return nil
100+
}
101+
102+
func (s *stream) decodeSender(v reflect.Value, b []byte) error {
103+
bs, err := s.decodeStream(b)
104+
if err != nil {
105+
return err
95106
}
96-
ch.direction = d
97107

98-
v.Set(reflect.ValueOf(ch))
108+
v.Set(reflect.ValueOf(&sender{stream: bs}))
99109

100110
return nil
101111
}
102112

103-
func (b *byteStream) streamBytes() ([]byte, error) {
113+
func (s *stream) streamBytes() ([]byte, error) {
104114
var buf [8]byte
105-
written := encodeReferenceID(buf[:], b.referenceID)
115+
written := encodeReferenceID(buf[:], s.referenceID)
106116

107117
return buf[:written], nil
108118
}
109119

110-
func (c *channel) encodeExtension(iv reflect.Value) (int, []byte, error) {
120+
func (s *stream) decodeWStream(v reflect.Value, b []byte) error {
121+
bs, err := s.decodeStream(b)
122+
if err != nil {
123+
return err
124+
}
125+
126+
v.Set(reflect.ValueOf(bs))
127+
128+
return nil
129+
}
130+
131+
func (s *stream) decodeRStream(v reflect.Value, b []byte) error {
132+
bs, err := s.decodeStream(b)
133+
if err != nil {
134+
return err
135+
}
136+
137+
v.Set(reflect.ValueOf(bs))
138+
139+
return nil
140+
}
141+
142+
func encodeTime(t *time.Time) ([]byte, error) {
143+
var b [12]byte
144+
binary.BigEndian.PutUint64(b[0:8], uint64(t.Unix()))
145+
binary.BigEndian.PutUint32(b[8:12], uint32(t.Nanosecond()))
146+
return b[:], nil
147+
}
148+
149+
func decodeTime(v reflect.Value, b []byte) error {
150+
if len(b) != 12 {
151+
return errors.New("Invalid length")
152+
}
153+
t := time.Unix(int64(binary.BigEndian.Uint64(b[0:8])), int64(binary.BigEndian.Uint32(b[8:12])))
154+
155+
if v.Kind() == reflect.Ptr {
156+
v.Set(reflect.ValueOf(&t))
157+
} else {
158+
v.Set(reflect.ValueOf(t))
159+
}
160+
161+
return nil
162+
}
163+
164+
func (s *stream) encodeExtended(iv reflect.Value) (i int, b []byte, e error) {
111165
switch v := iv.Interface().(type) {
112-
case *byteStream:
166+
case *nopSender:
167+
if v.stream == nil {
168+
return 0, nil, errors.New("bad type")
169+
}
170+
if v.stream.session != s.session {
171+
rc, err := s.copySendChannel(v)
172+
if err != nil {
173+
return 0, nil, err
174+
}
175+
b, err := rc.stream.channelBytes()
176+
return inboundChannelCode, b, err
177+
}
178+
179+
b, err := v.stream.channelBytes()
180+
return inboundChannelCode, b, err
181+
case *nopReceiver:
182+
if v.stream == nil {
183+
return 0, nil, errors.New("bad type")
184+
}
185+
if v.stream.session != s.session {
186+
rc, err := s.copyReceiveChannel(v)
187+
if err != nil {
188+
return 0, nil, err
189+
}
190+
b, err := rc.stream.channelBytes()
191+
return outboundChannelCode, b, err
192+
}
193+
194+
b, err := v.stream.channelBytes()
195+
return outboundChannelCode, b, err
196+
case *stream:
113197
if v.referenceID == 0 {
114198
return 0, nil, errors.New("bad type")
115199
}
116-
if v.session != c.session {
117-
streamCopy, err := c.createByteStream()
200+
if v.session != s.session {
201+
streamCopy, err := s.createByteStream()
118202
if err != nil {
119203
return 0, nil, err
120204
}
@@ -126,35 +210,53 @@ func (c *channel) encodeExtension(iv reflect.Value) (int, []byte, error) {
126210
io.Copy(w, streamCopy)
127211
w.Close()
128212
}(v)
129-
v = streamCopy.(*byteStream)
213+
v = streamCopy
130214

131215
}
132-
b, err := v.streamBytes()
133-
return 2, b, err
216+
b, err := v.channelBytes()
217+
return duplexStreamCode, b, err
218+
case libchan.Sender:
219+
copyCh, err := s.copySendChannel(v)
220+
if err != nil {
221+
return 0, nil, err
222+
}
223+
b, err := copyCh.stream.channelBytes()
224+
return inboundChannelCode, b, err
225+
case libchan.Receiver:
226+
copyCh, err := s.copyReceiveChannel(v)
227+
if err != nil {
228+
return 0, nil, err
229+
}
230+
b, err := copyCh.stream.channelBytes()
231+
return outboundChannelCode, b, err
232+
134233
case io.Reader:
135234
// Either ReadWriteCloser, ReadWriter, or ReadCloser
136-
streamCopy, err := c.createByteStream()
235+
streamCopy, err := s.createByteStream()
137236
if err != nil {
138237
return 0, nil, err
139238
}
140239
go func() {
141240
io.Copy(streamCopy, v)
142241
streamCopy.Close()
143242
}()
243+
code := outboundStreamCode
144244
if wc, ok := v.(io.WriteCloser); ok {
145245
go func() {
146246
io.Copy(wc, streamCopy)
147247
wc.Close()
148248
}()
249+
code = duplexStreamCode
149250
} else if w, ok := v.(io.Writer); ok {
150251
go func() {
151252
io.Copy(w, streamCopy)
152253
}()
254+
code = duplexStreamCode
153255
}
154-
b, err := streamCopy.(*byteStream).streamBytes()
155-
return 2, b, err
256+
b, err := streamCopy.streamBytes()
257+
return code, b, err
156258
case io.Writer:
157-
streamCopy, err := c.createByteStream()
259+
streamCopy, err := s.createByteStream()
158260
if err != nil {
159261
return 0, nil, err
160262
}
@@ -168,64 +270,24 @@ func (c *channel) encodeExtension(iv reflect.Value) (int, []byte, error) {
168270
io.Copy(v, streamCopy)
169271
}()
170272
}
171-
b, err := streamCopy.(*byteStream).streamBytes()
172-
return 2, b, err
173-
case *channel:
174-
if v.stream == nil {
175-
return 0, nil, errors.New("bad type")
176-
}
177-
if v.session != c.session {
178-
var rc *channel
179-
var err error
180-
if c.direction == inbound {
181-
rc, err = c.copyReceiveChannel(v)
182-
} else {
183-
rc, err = c.copySendChannel(v)
184-
}
185-
if err != nil {
186-
return 0, nil, err
187-
}
188-
b, err := rc.channelBytes()
189-
return 1, b, err
190-
}
191-
b, err := v.channelBytes()
192-
return 1, b, err
193-
case libchan.Sender:
194-
copyCh, err := c.copySendChannel(v)
195-
if err != nil {
196-
return 0, nil, err
197-
}
198-
b, err := copyCh.channelBytes()
199-
return 1, b, err
200-
case libchan.Receiver:
201-
copyCh, err := c.copyReceiveChannel(v)
202-
if err != nil {
203-
return 0, nil, err
204-
}
205-
b, err := copyCh.channelBytes()
206-
return 1, b, err
207-
}
208-
return 0, nil, nil
209-
}
210-
211-
func (c *channel) decodeStream(v reflect.Value, b []byte) error {
212-
referenceID, err := decodeReferenceID(b)
213-
if err != nil {
214-
return err
215-
}
216273

217-
bs := c.session.getByteStream(referenceID)
218-
if bs != nil {
219-
v.Set(reflect.ValueOf(bs))
274+
b, err := streamCopy.streamBytes()
275+
return inboundStreamCode, b, err
276+
case *time.Time:
277+
b, err := encodeTime(v)
278+
return timeCode, b, err
220279
}
221-
222-
return nil
280+
return 0, nil, nil
223281
}
224282

225-
func (c *channel) initializeExtensions() *msgpack.Extensions {
283+
func (s *stream) initializeExtensions() *msgpack.Extensions {
226284
exts := msgpack.NewExtensions()
227-
exts.SetEncoder(c.encodeExtension)
228-
exts.AddDecoder(1, reflect.TypeOf(&channel{}), c.decodeChannel)
229-
exts.AddDecoder(2, reflect.TypeOf(&byteStream{}), c.decodeStream)
285+
exts.SetEncoder(s.encodeExtended)
286+
exts.AddDecoder(duplexStreamCode, reflect.TypeOf(&stream{}), s.decodeWStream)
287+
exts.AddDecoder(inboundStreamCode, reflect.TypeOf(&stream{}), s.decodeWStream)
288+
exts.AddDecoder(outboundStreamCode, reflect.TypeOf(&stream{}), s.decodeRStream)
289+
exts.AddDecoder(inboundChannelCode, reflect.TypeOf(&sender{}), s.decodeSender)
290+
exts.AddDecoder(outboundChannelCode, reflect.TypeOf(&receiver{}), s.decodeReceiver)
291+
exts.AddDecoder(timeCode, reflect.TypeOf(&time.Time{}), decodeTime)
230292
return exts
231293
}

0 commit comments

Comments
 (0)