/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.io.network.impl;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.reef.io.network.exception.NetworkRuntimeException;
import org.apache.reef.io.network.proto.ReefNetworkServiceProtos.NSMessagePBuf;
import org.apache.reef.io.network.proto.ReefNetworkServiceProtos.NSRecordPBuf;
import org.apache.reef.wake.Identifier;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.Codec;

import java.io.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Network service message codec.
 *
 * <p>The wire format depends on the record codec supplied at construction time:
 * <ul>
 *   <li>If it is a {@link StreamingCodec}, a message is written with a {@link DataOutputStream}
 *       as: source id ({@code writeUTF}), destination id ({@code writeUTF}), record count
 *       ({@code writeInt}), followed by each record as written by
 *       {@code StreamingCodec.encodeToStream}.</li>
 *   <li>Otherwise, a message is serialized as an {@code NSMessagePBuf} protocol buffer whose
 *       {@code msgs} entries each hold the bytes produced by {@code Codec.encode} for one
 *       record.</li>
 * </ul>
 * The two formats are not interchangeable, so encoder and decoder must be built with the same
 * kind of record codec.
 *
 * @param <T> type of the records carried by a message
 */
public class NSMessageCodec<T> implements Codec<NSMessage<T>> {

  private final Codec<T> codec;
  private final IdentifierFactory factory;
  /** Selects the streaming wire format; computed once from the runtime type of {@link #codec}. */
  private final boolean isStreamingCodec;

  /**
   * Constructs a network service message codec.
   *
   * @param codec   the codec used for the individual records of a message; if it is a
   *                {@link StreamingCodec}, the streaming wire format is used
   * @param factory the identifier factory used to rebuild source/destination ids on decode
   */
  public NSMessageCodec(final Codec<T> codec, final IdentifierFactory factory) {
    this.codec = codec;
    this.factory = factory;
    this.isStreamingCodec = codec instanceof StreamingCodec;
  }

  /**
   * Encodes a network service message to bytes.
   *
   * @param obj a message
   * @return the encoded bytes
   * @throws UncheckedIOException if the streaming format fails to write the message
   */
  @Override
  public byte[] encode(final NSMessage<T> obj) {
    return isStreamingCodec ? encodeStreaming(obj) : encodeProtobuf(obj);
  }

  /**
   * Decodes a network service message from bytes.
   *
   * @param buf bytes produced by {@link #encode(NSMessage)} on an identically configured codec
   * @return the decoded message
   * @throws UncheckedIOException    if the streaming format input is truncated or malformed
   *                                 (including a negative record count)
   * @throws NetworkRuntimeException if the protocol buffer input cannot be parsed
   */
  @Override
  public NSMessage<T> decode(final byte[] buf) {
    return isStreamingCodec ? decodeStreaming(buf) : decodeProtobuf(buf);
  }

  /** Writes {@code obj} in the length-prefixed streaming format described on the class. */
  private byte[] encodeStreaming(final NSMessage<T> obj) {
    final StreamingCodec<T> streamingCodec = (StreamingCodec<T>) codec;
    try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
      try (DataOutputStream daos = new DataOutputStream(baos)) {
        daos.writeUTF(obj.getSrcId().toString());
        daos.writeUTF(obj.getDestId().toString());
        daos.writeInt(obj.getData().size());
        for (final T rec : obj.getData()) {
          streamingCodec.encodeToStream(rec, daos);
        }
      }
      // The data stream is closed (and thus flushed) above before the bytes are taken.
      return baos.toByteArray();
    } catch (final IOException e) {
      throw new UncheckedIOException("Failed to encode NSMessage", e);
    }
  }

  /** Writes {@code obj} as an {@code NSMessagePBuf}, one {@code NSRecordPBuf} per record. */
  private byte[] encodeProtobuf(final NSMessage<T> obj) {
    final NSMessagePBuf.Builder pbuf = NSMessagePBuf.newBuilder();
    pbuf.setSrcid(obj.getSrcId().toString());
    pbuf.setDestid(obj.getDestId().toString());
    for (final T rec : obj.getData()) {
      final NSRecordPBuf.Builder rbuf = NSRecordPBuf.newBuilder();
      rbuf.setData(ByteString.copyFrom(codec.encode(rec)));
      pbuf.addMsgs(rbuf);
    }
    return pbuf.build().toByteArray();
  }

  /** Reads a message written by {@link #encodeStreaming(NSMessage)}. */
  private NSMessage<T> decodeStreaming(final byte[] buf) {
    final StreamingCodec<T> streamingCodec = (StreamingCodec<T>) codec;
    try (DataInputStream dais = new DataInputStream(new ByteArrayInputStream(buf))) {
      final Identifier srcId = factory.getNewInstance(dais.readUTF());
      final Identifier destId = factory.getNewInstance(dais.readUTF());
      final int size = dais.readInt();
      if (size < 0) {
        throw new IOException("Invalid NSMessage record count: " + size);
      }
      // The count comes from the input bytes; cap the up-front allocation so a corrupt count
      // cannot trigger a huge allocation before any record has been read. The list still
      // grows as needed if the count is genuine.
      final List<T> list = new ArrayList<>(Math.min(size, buf.length));
      for (int i = 0; i < size; i++) {
        list.add(streamingCodec.decodeFromStream(dais));
      }
      return new NSMessage<>(srcId, destId, list);
    } catch (final IOException e) {
      throw new UncheckedIOException("Failed to decode NSMessage", e);
    }
  }

  /** Reads a message written by {@link #encodeProtobuf(NSMessage)}. */
  private NSMessage<T> decodeProtobuf(final byte[] buf) {
    final NSMessagePBuf pbuf;
    try {
      pbuf = NSMessagePBuf.parseFrom(buf);
    } catch (final InvalidProtocolBufferException e) {
      // The cause travels with the thrown exception; printing it here would duplicate reporting.
      throw new NetworkRuntimeException(e);
    }
    final List<T> list = new ArrayList<>(pbuf.getMsgsCount());
    for (final NSRecordPBuf rbuf : pbuf.getMsgsList()) {
      list.add(codec.decode(rbuf.getData().toByteArray()));
    }
    return new NSMessage<>(factory.getNewInstance(pbuf.getSrcid()),
        factory.getNewInstance(pbuf.getDestid()), list);
  }
}