@@ -3,7 +3,6 @@ package spdy
33import (
44 "encoding/binary"
55 "errors"
6- "fmt"
76 "io"
87 "reflect"
98
@@ -72,43 +71,6 @@ func (c *channel) copyReceiveChannel(recv libchan.Receiver) (*channel, error) {
7271 return recvCopy .(* channel ), nil
7372}
7473
75- func (c * channel ) encodeChannel (iv reflect.Value ) ([]byte , error ) {
76- switch v := iv .Interface ().(type ) {
77- case * channel :
78- if v .stream == nil {
79- return nil , errors .New ("bad type" )
80- }
81- if v .session != c .session {
82- var rc * channel
83- var err error
84- if c .direction == inbound {
85- rc , err = c .copyReceiveChannel (v )
86- } else {
87- rc , err = c .copySendChannel (v )
88- }
89- if err != nil {
90- return nil , err
91- }
92- return rc .channelBytes ()
93- }
94-
95- return v .channelBytes ()
96- case libchan.Sender :
97- copyCh , err := c .copySendChannel (v )
98- if err != nil {
99- return nil , err
100- }
101- return copyCh .channelBytes ()
102- case libchan.Receiver :
103- copyCh , err := c .copyReceiveChannel (v )
104- if err != nil {
105- return nil , err
106- }
107- return copyCh .channelBytes ()
108- }
109- return nil , fmt .Errorf ("unsupported channel type: %T" , iv .Interface ())
110- }
111-
11274func (c * channel ) decodeChannel (v reflect.Value , b []byte ) error {
11375 var d direction
11476 if b [0 ] == 0x01 {
@@ -145,16 +107,16 @@ func (b *byteStream) streamBytes() ([]byte, error) {
145107 return buf [:written ], nil
146108}
147109
148- func (s * Transport ) encodeStream (iv reflect.Value ) ([]byte , error ) {
110+ func (c * channel ) encodeExtension (iv reflect.Value ) (int , []byte , error ) {
149111 switch v := iv .Interface ().(type ) {
150112 case * byteStream :
151113 if v .referenceID == 0 {
152- return nil , errors .New ("bad type" )
114+ return 0 , nil , errors .New ("bad type" )
153115 }
154- if v .session != s {
155- streamCopy , err := s .createByteStream ()
116+ if v .session != c . session {
117+ streamCopy , err := c . session .createByteStream ()
156118 if err != nil {
157- return nil , err
119+ return 0 , nil , err
158120 }
159121 go func (r io.Reader ) {
160122 io .Copy (streamCopy , r )
@@ -167,12 +129,13 @@ func (s *Transport) encodeStream(iv reflect.Value) ([]byte, error) {
167129 v = streamCopy .(* byteStream )
168130
169131 }
170- return v .streamBytes ()
132+ b , err := v .streamBytes ()
133+ return 2 , b , err
171134 case io.Reader :
172135 // Either ReadWriteCloser, ReadWriter, or ReadCloser
173- streamCopy , err := s .createByteStream ()
136+ streamCopy , err := c . session .createByteStream ()
174137 if err != nil {
175- return nil , err
138+ return 0 , nil , err
176139 }
177140 go func () {
178141 io .Copy (streamCopy , v )
@@ -188,11 +151,12 @@ func (s *Transport) encodeStream(iv reflect.Value) ([]byte, error) {
188151 io .Copy (w , streamCopy )
189152 }()
190153 }
191- return streamCopy .(* byteStream ).streamBytes ()
154+ b , err := streamCopy .(* byteStream ).streamBytes ()
155+ return 2 , b , err
192156 case io.Writer :
193- streamCopy , err := s .createByteStream ()
157+ streamCopy , err := c . session .createByteStream ()
194158 if err != nil {
195- return nil , err
159+ return 0 , nil , err
196160 }
197161 if wc , ok := v .(io.WriteCloser ); ok {
198162 go func () {
@@ -204,10 +168,44 @@ func (s *Transport) encodeStream(iv reflect.Value) ([]byte, error) {
204168 io .Copy (v , streamCopy )
205169 }()
206170 }
207-
208- return streamCopy .(* byteStream ).streamBytes ()
171+ b , err := streamCopy .(* byteStream ).streamBytes ()
172+ return 2 , b , err
173+ case * channel :
174+ if v .stream == nil {
175+ return 0 , nil , errors .New ("bad type" )
176+ }
177+ if v .session != c .session {
178+ var rc * channel
179+ var err error
180+ if c .direction == inbound {
181+ rc , err = c .copyReceiveChannel (v )
182+ } else {
183+ rc , err = c .copySendChannel (v )
184+ }
185+ if err != nil {
186+ return 0 , nil , err
187+ }
188+ b , err := rc .channelBytes ()
189+ return 1 , b , err
190+ }
191+ b , err := v .channelBytes ()
192+ return 1 , b , err
193+ case libchan.Sender :
194+ copyCh , err := c .copySendChannel (v )
195+ if err != nil {
196+ return 0 , nil , err
197+ }
198+ b , err := copyCh .channelBytes ()
199+ return 1 , b , err
200+ case libchan.Receiver :
201+ copyCh , err := c .copyReceiveChannel (v )
202+ if err != nil {
203+ return 0 , nil , err
204+ }
205+ b , err := copyCh .channelBytes ()
206+ return 1 , b , err
209207 }
210- return nil , fmt . Errorf ( "unsupported stream type: %T" , iv . Interface ())
208+ return 0 , nil , nil
211209}
212210
213211func (s * Transport ) decodeStream (v reflect.Value , b []byte ) error {
@@ -226,15 +224,8 @@ func (s *Transport) decodeStream(v reflect.Value, b []byte) error {
226224
227225func (c * channel ) initializeExtensions () * msgpack.Extensions {
228226 exts := msgpack .NewExtensions ()
229- chanInterfaces := []reflect.Type {
230- reflect .TypeOf (new (libchan.Sender )),
231- reflect .TypeOf (new (libchan.Receiver )),
232- }
233- streamInterfaces := []reflect.Type {
234- reflect .TypeOf (new (io.Reader )),
235- reflect .TypeOf (new (io.Writer )),
236- }
237- exts .AddExtension (1 , reflect .TypeOf (& channel {}), chanInterfaces , c .encodeChannel , c .decodeChannel )
238- exts .AddExtension (2 , reflect .TypeOf (& byteStream {}), streamInterfaces , c .session .encodeStream , c .session .decodeStream )
227+ exts .SetEncoder (c .encodeExtension )
228+ exts .AddDecoder (1 , reflect .TypeOf (& channel {}), c .decodeChannel )
229+ exts .AddDecoder (2 , reflect .TypeOf (& byteStream {}), c .session .decodeStream )
239230 return exts
240231}
0 commit comments