Skip to content
This repository was archived by the owner on Jul 18, 2025. It is now read-only.

Commit 1342ccf

Browse files
committed
Add ByteStreamWrapper to transparently copying ReadWriteClosers
Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
1 parent 51d8e0a commit 1342ccf

5 files changed

Lines changed: 181 additions & 0 deletions

File tree

inmem.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,33 @@ func (s *streamSession) decodeStream(v reflect.Value, b []byte) error {
156156
return nil
157157
}
158158

159+
func (s *streamSession) encodeWrapper(v reflect.Value) ([]byte, error) {
160+
wrapper := v.Interface().(ByteStreamWrapper)
161+
bs, err := s.newByteStream()
162+
if err != nil {
163+
return nil, err
164+
}
165+
166+
go func() {
167+
io.Copy(bs, wrapper)
168+
bs.Close()
169+
}()
170+
171+
go func() {
172+
io.Copy(wrapper, bs)
173+
wrapper.Close()
174+
}()
175+
176+
return s.encodeStream(reflect.ValueOf(bs).Elem())
177+
}
178+
179+
func (s *streamSession) decodeWrapper(v reflect.Value, b []byte) error {
180+
bs := &byteStream{}
181+
s.decodeStream(reflect.ValueOf(bs).Elem(), b)
182+
v.FieldByName("ReadWriteCloser").Set(reflect.ValueOf(bs))
183+
return nil
184+
}
185+
159186
func getMsgPackHandler(session *streamSession) *codec.MsgpackHandle {
160187
mh := &codec.MsgpackHandle{WriteExt: true}
161188
mh.RawToString = true
@@ -175,6 +202,11 @@ func getMsgPackHandler(session *streamSession) *codec.MsgpackHandle {
175202
panic(err)
176203
}
177204

205+
err = mh.AddExt(reflect.TypeOf(ByteStreamWrapper{}), 4, session.encodeWrapper, session.decodeWrapper)
206+
if err != nil {
207+
panic(err)
208+
}
209+
178210
return mh
179211
}
180212

inmem_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,44 @@ func TestComplexMessage(t *testing.T) {
303303
SpawnPipeTestRoutines(t, client, server)
304304
}
305305

306+
func TestInmemWrappedSend(t *testing.T) {
307+
tmp, err := ioutil.TempFile("", "libchan-test-")
308+
if err != nil {
309+
t.Fatal(err)
310+
}
311+
defer os.RemoveAll(tmp.Name())
312+
fmt.Fprintf(tmp, "hello through a wrapper\n")
313+
tmp.Sync()
314+
tmp.Seek(0, 0)
315+
316+
client := func(t *testing.T, w Sender) {
317+
message := &InMemMessage{Data: "path=" + tmp.Name(), Stream: ByteStreamWrapper{tmp}}
318+
err = w.Send(message)
319+
if err != nil {
320+
t.Fatal(err)
321+
}
322+
}
323+
server := func(t *testing.T, r Receiver) {
324+
msg := &InMemMessage{}
325+
err := r.Receive(msg)
326+
if err != nil {
327+
t.Fatal(err)
328+
}
329+
if msg.Data != "path="+tmp.Name() {
330+
t.Fatalf("%#v", msg)
331+
}
332+
txt, err := ioutil.ReadAll(msg.Stream)
333+
if err != nil {
334+
t.Fatal(err)
335+
}
336+
if string(txt) != "hello through a wrapper\n" {
337+
t.Fatalf("%s\n", txt)
338+
}
339+
340+
}
341+
SpawnPipeTestRoutines(t, client, server)
342+
}
343+
306344
type SendTestRoutine func(*testing.T, Sender)
307345
type ReceiveTestRoutine func(*testing.T, Receiver)
308346

libchan.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ type Receiver interface {
6060
// automatically closed through receiving an EOF.
6161
Close() error
6262
}
63+
64+
// ByteStreamWrapper is a wrapper around a ReadWriteCloser
65+
// to cue the transport to copy to a transport byte stream.
66+
// Note: ReadWriteClosers created through calling the
67+
// CreateByteStream method on a Sender do not need
68+
// to wrap the ByteStream.
69+
type ByteStreamWrapper struct {
70+
io.ReadWriteCloser
71+
}

spdy/encode.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package spdy
33
import (
44
"encoding/binary"
55
"errors"
6+
"io"
67
"net"
78
"reflect"
89
"time"
910

1011
"github.com/dmcgowan/go/codec"
12+
"github.com/docker/libchan"
1113
)
1214

1315
func (s *SpdyTransport) encodeChannel(v reflect.Value) ([]byte, error) {
@@ -87,6 +89,33 @@ func (s *SpdyTransport) decodeStream(v reflect.Value, b []byte) error {
8789
return nil
8890
}
8991

92+
func (s *SpdyTransport) encodeWrapper(v reflect.Value) ([]byte, error) {
93+
wrapper := v.Interface().(libchan.ByteStreamWrapper)
94+
bs, err := s.createByteStream()
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
go func() {
100+
io.Copy(bs, wrapper)
101+
bs.Close()
102+
}()
103+
104+
go func() {
105+
io.Copy(wrapper, bs)
106+
wrapper.Close()
107+
}()
108+
109+
return s.encodeStream(reflect.ValueOf(bs).Elem())
110+
}
111+
112+
func (s *SpdyTransport) decodeWrapper(v reflect.Value, b []byte) error {
113+
bs := &byteStream{}
114+
s.decodeStream(reflect.ValueOf(bs).Elem(), b)
115+
v.FieldByName("ReadWriteCloser").Set(reflect.ValueOf(bs))
116+
return nil
117+
}
118+
90119
func (s *SpdyTransport) waitConn(network, local, remote string, timeout time.Duration) (net.Conn, error) {
91120
timeoutChan := time.After(timeout)
92121
connChan := make(chan net.Conn)
@@ -218,6 +247,11 @@ func (s *SpdyTransport) initializeHandler() *codec.MsgpackHandle {
218247
panic(err)
219248
}
220249

250+
err = mh.AddExt(reflect.TypeOf(libchan.ByteStreamWrapper{}), 0x03, s.encodeWrapper, s.decodeWrapper)
251+
if err != nil {
252+
panic(err)
253+
}
254+
221255
// Register networks
222256
s.networks["tcp"] = 0x04
223257
s.netConns[0x04] = make(map[string]net.Conn)

spdy/session_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,74 @@ func TestByteStream(t *testing.T) {
292292
SpawnClientServerTest(t, "localhost:12944", ClientSendWrapper(client), ServerReceiveWrapper(server))
293293
}
294294

295+
type WrappedMessage struct {
296+
Message string
297+
Wrapped io.ReadWriteCloser
298+
}
299+
300+
func TestWrappedByteStreams(t *testing.T) {
301+
serverSend := "G'day client ☺"
302+
clientReply := "Hello Server, ☢ FYI your stream was transparently copied ☠"
303+
client := func(t *testing.T, sender libchan.Sender, s *SpdyTransport) {
304+
// Create pipe
305+
p1, p2 := net.Pipe()
306+
307+
m1 := &WrappedMessage{
308+
Message: "wrapped",
309+
Wrapped: libchan.ByteStreamWrapper{p2},
310+
}
311+
312+
sendErr := sender.Send(m1)
313+
if sendErr != nil {
314+
t.Fatalf("Error sending channel: %s", sendErr)
315+
}
316+
317+
// read
318+
readBytes := make([]byte, 30)
319+
n, readErr := p1.Read(readBytes)
320+
if readErr != nil {
321+
t.Fatalf("Error reading from byte stream: %s", readErr)
322+
}
323+
if expected := serverSend; string(readBytes[:n]) != expected {
324+
t.Fatalf("Unexpected read value:\n\tExpected: %q\n\tActual: %q", expected, string(readBytes[:n]))
325+
}
326+
327+
// write
328+
_, writeErr := p1.Write([]byte(clientReply))
329+
if writeErr != nil {
330+
t.Fatalf("Error writing to byte stream: %s", writeErr)
331+
}
332+
333+
}
334+
server := func(t *testing.T, receiver libchan.Receiver, s *SpdyTransport) {
335+
m1 := &WrappedMessage{}
336+
recvErr := receiver.Receive(m1)
337+
if recvErr != nil {
338+
t.Fatalf("Error receiving message: %s", recvErr)
339+
}
340+
341+
if expected := "wrapped"; m1.Message != expected {
342+
t.Fatalf("Unexpected message\n\tExpected: %s\n\tActual: %s", expected, m1.Message)
343+
}
344+
345+
_, writeErr := m1.Wrapped.Write([]byte(serverSend))
346+
if writeErr != nil {
347+
t.Fatalf("Error writing to byte stream: %s", writeErr)
348+
}
349+
350+
readBytes := make([]byte, 80)
351+
n, readErr := m1.Wrapped.Read(readBytes)
352+
if readErr != nil {
353+
t.Fatalf("Error reading from byte stream: %s", readErr)
354+
}
355+
if expected := clientReply; string(readBytes[:n]) != expected {
356+
t.Fatalf("Unexpected read value:\n\tExpected: %q\n\tActual: %q", expected, string(readBytes[:n]))
357+
}
358+
359+
}
360+
SpawnClientServerTest(t, "localhost:12943", ClientSendWrapper(client), ServerReceiveWrapper(server))
361+
}
362+
295363
func ClientSendWrapper(f func(t *testing.T, c libchan.Sender, s *SpdyTransport)) ClientRoutine {
296364
return func(t *testing.T, server string) {
297365
conn, connErr := net.Dial("tcp", server)

0 commit comments

Comments
 (0)