@@ -5,11 +5,21 @@ import (
55 "errors"
66 "io"
77 "reflect"
8+ "time"
89
910 "github.com/dmcgowan/msgpack"
1011 "github.com/docker/libchan"
1112)
1213
14+ const (
15+ duplexStreamCode = 1
16+ inboundStreamCode = 2
17+ outboundStreamCode = 3
18+ inboundChannelCode = 4
19+ outboundChannelCode = 5
20+ timeCode = 6
21+ )
22+
1323func decodeReferenceID (b []byte ) (referenceID uint64 , err error ) {
1424 if len (b ) == 8 {
1525 referenceID = binary .BigEndian .Uint64 (b )
@@ -32,21 +42,14 @@ func encodeReferenceID(b []byte, referenceID uint64) (n int) {
3242 return
3343}
3444
35- func (c * channel ) channelBytes () ([]byte , error ) {
36- buf := make ([]byte , 9 )
37- if c .direction == inbound {
38- buf [0 ] = 0x02 // Reverse direction
39- } else if c .direction == outbound {
40- buf [0 ] = 0x01 // Reverse direction
41- } else {
42- return nil , errors .New ("invalid direction" )
43- }
44- written := encodeReferenceID (buf [1 :], c .referenceID )
45- return buf [:(written + 1 )], nil
45+ func (s * stream ) channelBytes () ([]byte , error ) {
46+ buf := make ([]byte , 8 )
47+ written := encodeReferenceID (buf , s .referenceID )
48+ return buf [:written ], nil
4649}
4750
48- func (c * channel ) copySendChannel (send libchan.Sender ) (* channel , error ) {
49- recv , sendCopy , err := c .CreateNestedReceiver ()
51+ func (s * stream ) copySendChannel (send libchan.Sender ) (* nopSender , error ) {
52+ recv , sendCopy , err := s .CreateNestedReceiver ()
5053 if err != nil {
5154 return nil , err
5255 }
@@ -55,11 +58,11 @@ func (c *channel) copySendChannel(send libchan.Sender) (*channel, error) {
5558 libchan .Copy (send , recv )
5659 send .Close ()
5760 }()
58- return sendCopy .(* channel ), nil
61+ return sendCopy .(* nopSender ), nil
5962}
6063
61- func (c * channel ) copyReceiveChannel (recv libchan.Receiver ) (* channel , error ) {
62- send , recvCopy , err := c .CreateNestedSender ()
64+ func (s * stream ) copyReceiveChannel (recv libchan.Receiver ) (* nopReceiver , error ) {
65+ send , recvCopy , err := s .CreateNestedSender ()
6366 if err != nil {
6467 return nil , err
6568 }
@@ -68,53 +71,134 @@ func (c *channel) copyReceiveChannel(recv libchan.Receiver) (*channel, error) {
6871 libchan .Copy (send , recv )
6972 send .Close ()
7073 }()
71- return recvCopy .(* channel ), nil
74+ return recvCopy .(* nopReceiver ), nil
7275}
7376
74- func (c * channel ) decodeChannel (v reflect.Value , b []byte ) error {
75- var d direction
76- if b [0 ] == 0x01 {
77- d = inbound
78- } else if b [0 ] == 0x02 {
79- d = outbound
80- } else {
81- return errors .New ("unexpected direction" )
77+ func (s * stream ) decodeStream (b []byte ) (* stream , error ) {
78+ referenceID , err := decodeReferenceID (b )
79+ if err != nil {
80+ return nil , err
81+ }
82+
83+ gs := s .session .getStream (referenceID )
84+ if gs == nil {
85+ return nil , errors .New ("stream does not exist" )
8286 }
83- referenceID , err := decodeReferenceID (b [1 :])
87+
88+ return gs , nil
89+ }
90+
91+ func (s * stream ) decodeReceiver (v reflect.Value , b []byte ) error {
92+ bs , err := s .decodeStream (b )
8493 if err != nil {
8594 return err
8695 }
8796
88- ch := c .session .getChannel (referenceID )
89- if ch == nil {
90- return errors .New ("channel does not exist" )
91- }
92- // TODO lock channel while check and setting
93- if ch .direction != 0 && ch .direction != d {
94- return ErrWrongDirection
97+ v .Set (reflect .ValueOf (& receiver {stream : bs }))
98+
99+ return nil
100+ }
101+
102+ func (s * stream ) decodeSender (v reflect.Value , b []byte ) error {
103+ bs , err := s .decodeStream (b )
104+ if err != nil {
105+ return err
95106 }
96- ch .direction = d
97107
98- v .Set (reflect .ValueOf (ch ))
108+ v .Set (reflect .ValueOf (& sender { stream : bs } ))
99109
100110 return nil
101111}
102112
103- func (b * byteStream ) streamBytes () ([]byte , error ) {
113+ func (s * stream ) streamBytes () ([]byte , error ) {
104114 var buf [8 ]byte
105- written := encodeReferenceID (buf [:], b .referenceID )
115+ written := encodeReferenceID (buf [:], s .referenceID )
106116
107117 return buf [:written ], nil
108118}
109119
110- func (c * channel ) encodeExtension (iv reflect.Value ) (int , []byte , error ) {
120+ func (s * stream ) decodeWStream (v reflect.Value , b []byte ) error {
121+ bs , err := s .decodeStream (b )
122+ if err != nil {
123+ return err
124+ }
125+
126+ v .Set (reflect .ValueOf (bs ))
127+
128+ return nil
129+ }
130+
131+ func (s * stream ) decodeRStream (v reflect.Value , b []byte ) error {
132+ bs , err := s .decodeStream (b )
133+ if err != nil {
134+ return err
135+ }
136+
137+ v .Set (reflect .ValueOf (bs ))
138+
139+ return nil
140+ }
141+
142+ func encodeTime (t * time.Time ) ([]byte , error ) {
143+ var b [12 ]byte
144+ binary .BigEndian .PutUint64 (b [0 :8 ], uint64 (t .Unix ()))
145+ binary .BigEndian .PutUint32 (b [8 :12 ], uint32 (t .Nanosecond ()))
146+ return b [:], nil
147+ }
148+
149+ func decodeTime (v reflect.Value , b []byte ) error {
150+ if len (b ) != 12 {
151+ return errors .New ("Invalid length" )
152+ }
153+ t := time .Unix (int64 (binary .BigEndian .Uint64 (b [0 :8 ])), int64 (binary .BigEndian .Uint32 (b [8 :12 ])))
154+
155+ if v .Kind () == reflect .Ptr {
156+ v .Set (reflect .ValueOf (& t ))
157+ } else {
158+ v .Set (reflect .ValueOf (t ))
159+ }
160+
161+ return nil
162+ }
163+
164+ func (s * stream ) encodeExtended (iv reflect.Value ) (i int , b []byte , e error ) {
111165 switch v := iv .Interface ().(type ) {
112- case * byteStream :
166+ case * nopSender :
167+ if v .stream == nil {
168+ return 0 , nil , errors .New ("bad type" )
169+ }
170+ if v .stream .session != s .session {
171+ rc , err := s .copySendChannel (v )
172+ if err != nil {
173+ return 0 , nil , err
174+ }
175+ b , err := rc .stream .channelBytes ()
176+ return inboundChannelCode , b , err
177+ }
178+
179+ b , err := v .stream .channelBytes ()
180+ return inboundChannelCode , b , err
181+ case * nopReceiver :
182+ if v .stream == nil {
183+ return 0 , nil , errors .New ("bad type" )
184+ }
185+ if v .stream .session != s .session {
186+ rc , err := s .copyReceiveChannel (v )
187+ if err != nil {
188+ return 0 , nil , err
189+ }
190+ b , err := rc .stream .channelBytes ()
191+ return outboundChannelCode , b , err
192+ }
193+
194+ b , err := v .stream .channelBytes ()
195+ return outboundChannelCode , b , err
196+ case * stream :
113197 if v .referenceID == 0 {
114198 return 0 , nil , errors .New ("bad type" )
115199 }
116- if v .session != c .session {
117- streamCopy , err := c .createByteStream ()
200+ if v .session != s .session {
201+ streamCopy , err := s .createByteStream ()
118202 if err != nil {
119203 return 0 , nil , err
120204 }
@@ -126,35 +210,53 @@ func (c *channel) encodeExtension(iv reflect.Value) (int, []byte, error) {
126210 io .Copy (w , streamCopy )
127211 w .Close ()
128212 }(v )
129- v = streamCopy .( * byteStream )
213+ v = streamCopy
130214
131215 }
132- b , err := v .streamBytes ()
133- return 2 , b , err
216+ b , err := v .channelBytes ()
217+ return duplexStreamCode , b , err
218+ case libchan.Sender :
219+ copyCh , err := s .copySendChannel (v )
220+ if err != nil {
221+ return 0 , nil , err
222+ }
223+ b , err := copyCh .stream .channelBytes ()
224+ return inboundChannelCode , b , err
225+ case libchan.Receiver :
226+ copyCh , err := s .copyReceiveChannel (v )
227+ if err != nil {
228+ return 0 , nil , err
229+ }
230+ b , err := copyCh .stream .channelBytes ()
231+ return outboundChannelCode , b , err
232+
134233 case io.Reader :
135234 // Either ReadWriteCloser, ReadWriter, or ReadCloser
136- streamCopy , err := c .createByteStream ()
235+ streamCopy , err := s .createByteStream ()
137236 if err != nil {
138237 return 0 , nil , err
139238 }
140239 go func () {
141240 io .Copy (streamCopy , v )
142241 streamCopy .Close ()
143242 }()
243+ code := outboundStreamCode
144244 if wc , ok := v .(io.WriteCloser ); ok {
145245 go func () {
146246 io .Copy (wc , streamCopy )
147247 wc .Close ()
148248 }()
249+ code = duplexStreamCode
149250 } else if w , ok := v .(io.Writer ); ok {
150251 go func () {
151252 io .Copy (w , streamCopy )
152253 }()
254+ code = duplexStreamCode
153255 }
154- b , err := streamCopy .( * byteStream ). streamBytes ()
155- return 2 , b , err
256+ b , err := streamCopy .streamBytes ()
257+ return code , b , err
156258 case io.Writer :
157- streamCopy , err := c .createByteStream ()
259+ streamCopy , err := s .createByteStream ()
158260 if err != nil {
159261 return 0 , nil , err
160262 }
@@ -168,64 +270,24 @@ func (c *channel) encodeExtension(iv reflect.Value) (int, []byte, error) {
168270 io .Copy (v , streamCopy )
169271 }()
170272 }
171- b , err := streamCopy .(* byteStream ).streamBytes ()
172- return 2 , b , err
173- case * channel :
174- if v .stream == nil {
175- return 0 , nil , errors .New ("bad type" )
176- }
177- if v .session != c .session {
178- var rc * channel
179- var err error
180- if c .direction == inbound {
181- rc , err = c .copyReceiveChannel (v )
182- } else {
183- rc , err = c .copySendChannel (v )
184- }
185- if err != nil {
186- return 0 , nil , err
187- }
188- b , err := rc .channelBytes ()
189- return 1 , b , err
190- }
191- b , err := v .channelBytes ()
192- return 1 , b , err
193- case libchan.Sender :
194- copyCh , err := c .copySendChannel (v )
195- if err != nil {
196- return 0 , nil , err
197- }
198- b , err := copyCh .channelBytes ()
199- return 1 , b , err
200- case libchan.Receiver :
201- copyCh , err := c .copyReceiveChannel (v )
202- if err != nil {
203- return 0 , nil , err
204- }
205- b , err := copyCh .channelBytes ()
206- return 1 , b , err
207- }
208- return 0 , nil , nil
209- }
210-
211- func (c * channel ) decodeStream (v reflect.Value , b []byte ) error {
212- referenceID , err := decodeReferenceID (b )
213- if err != nil {
214- return err
215- }
216273
217- bs := c .session .getByteStream (referenceID )
218- if bs != nil {
219- v .Set (reflect .ValueOf (bs ))
274+ b , err := streamCopy .streamBytes ()
275+ return inboundStreamCode , b , err
276+ case * time.Time :
277+ b , err := encodeTime (v )
278+ return timeCode , b , err
220279 }
221-
222- return nil
280+ return 0 , nil , nil
223281}
224282
225- func (c * channel ) initializeExtensions () * msgpack.Extensions {
283+ func (s * stream ) initializeExtensions () * msgpack.Extensions {
226284 exts := msgpack .NewExtensions ()
227- exts .SetEncoder (c .encodeExtension )
228- exts .AddDecoder (1 , reflect .TypeOf (& channel {}), c .decodeChannel )
229- exts .AddDecoder (2 , reflect .TypeOf (& byteStream {}), c .decodeStream )
285+ exts .SetEncoder (s .encodeExtended )
286+ exts .AddDecoder (duplexStreamCode , reflect .TypeOf (& stream {}), s .decodeWStream )
287+ exts .AddDecoder (inboundStreamCode , reflect .TypeOf (& stream {}), s .decodeWStream )
288+ exts .AddDecoder (outboundStreamCode , reflect .TypeOf (& stream {}), s .decodeRStream )
289+ exts .AddDecoder (inboundChannelCode , reflect .TypeOf (& sender {}), s .decodeSender )
290+ exts .AddDecoder (outboundChannelCode , reflect .TypeOf (& receiver {}), s .decodeReceiver )
291+ exts .AddDecoder (timeCode , reflect .TypeOf (& time.Time {}), decodeTime )
230292 return exts
231293}
0 commit comments