This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.impl;
020
021import com.google.protobuf.ByteString;
022import com.google.protobuf.InvalidProtocolBufferException;
023import org.apache.reef.io.network.exception.NetworkRuntimeException;
024import org.apache.reef.io.network.proto.ReefNetworkServiceProtos.NSMessagePBuf;
025import org.apache.reef.io.network.proto.ReefNetworkServiceProtos.NSRecordPBuf;
026import org.apache.reef.wake.Identifier;
027import org.apache.reef.wake.IdentifierFactory;
028import org.apache.reef.wake.remote.Codec;
029
030import java.io.*;
031import java.util.ArrayList;
032import java.util.List;
033
034/**
035 * Network service message codec.
036 *
037 * @param <T> type
038 */
039public class NSMessageCodec<T> implements Codec<NSMessage<T>> {
040
041  private final Codec<T> codec;
042  private final IdentifierFactory factory;
043  private final boolean isStreamingCodec;
044
045  /**
046   * Constructs a network service message codec.
047   *
048   * @param codec   a codec
049   * @param factory an identifier factory
050   */
051  public NSMessageCodec(final Codec<T> codec, final IdentifierFactory factory) {
052    this.codec = codec;
053    this.factory = factory;
054    this.isStreamingCodec = codec instanceof StreamingCodec;
055  }
056
057  /**
058   * Encodes a network service message to bytes.
059   *
060   * @param obj a message
061   * @return bytes
062   */
063  @Override
064  public byte[] encode(final NSMessage<T> obj) {
065    if (isStreamingCodec) {
066      final StreamingCodec<T> streamingCodec = (StreamingCodec<T>) codec;
067      try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
068        try (DataOutputStream daos = new DataOutputStream(baos)) {
069          daos.writeUTF(obj.getSrcId().toString());
070          daos.writeUTF(obj.getDestId().toString());
071          daos.writeInt(obj.getData().size());
072          for (final T rec : obj.getData()) {
073            streamingCodec.encodeToStream(rec, daos);
074          }
075        }
076        return baos.toByteArray();
077      } catch (final IOException e) {
078        throw new RuntimeException("IOException", e);
079      }
080    } else {
081      final NSMessagePBuf.Builder pbuf = NSMessagePBuf.newBuilder();
082      pbuf.setSrcid(obj.getSrcId().toString());
083      pbuf.setDestid(obj.getDestId().toString());
084      for (final T rec : obj.getData()) {
085        final NSRecordPBuf.Builder rbuf = NSRecordPBuf.newBuilder();
086        rbuf.setData(ByteString.copyFrom(codec.encode(rec)));
087        pbuf.addMsgs(rbuf);
088      }
089      return pbuf.build().toByteArray();
090    }
091  }
092
093  /**
094   * Decodes a network service message from bytes.
095   *
096   * @param buf bytes
097   * @return a message
098   */
099  @Override
100  public NSMessage<T> decode(final byte[] buf) {
101    if (isStreamingCodec) {
102      final StreamingCodec<T> streamingCodec = (StreamingCodec<T>) codec;
103      try (ByteArrayInputStream bais = new ByteArrayInputStream(buf)) {
104        try (DataInputStream dais = new DataInputStream(bais)) {
105          final Identifier srcId = factory.getNewInstance(dais.readUTF());
106          final Identifier destId = factory.getNewInstance(dais.readUTF());
107          final int size = dais.readInt();
108          final List<T> list = new ArrayList<>(size);
109          for (int i = 0; i < size; i++) {
110            list.add(streamingCodec.decodeFromStream(dais));
111          }
112          return new NSMessage<>(srcId, destId, list);
113        }
114      } catch (final IOException e) {
115        throw new RuntimeException("IOException", e);
116      }
117    } else {
118      final NSMessagePBuf pbuf;
119      try {
120        pbuf = NSMessagePBuf.parseFrom(buf);
121      } catch (final InvalidProtocolBufferException e) {
122        e.printStackTrace();
123        throw new NetworkRuntimeException(e);
124      }
125      final List<T> list = new ArrayList<>();
126      for (final NSRecordPBuf rbuf : pbuf.getMsgsList()) {
127        list.add(codec.decode(rbuf.getData().toByteArray()));
128      }
129      return new NSMessage<>(factory.getNewInstance(pbuf.getSrcid()), factory.getNewInstance(pbuf.getDestid()), list);
130    }
131  }
132
133
134}