001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.vortex.common; 020 021import org.apache.avro.io.*; 022import org.apache.avro.specific.SpecificDatumReader; 023import org.apache.avro.specific.SpecificDatumWriter; 024import org.apache.commons.lang.SerializationUtils; 025import org.apache.reef.annotations.Unstable; 026import org.apache.reef.annotations.audience.DriverSide; 027import org.apache.reef.annotations.audience.Private; 028import org.apache.reef.vortex.api.VortexAggregateFunction; 029import org.apache.reef.vortex.api.VortexAggregatePolicy; 030import org.apache.reef.vortex.api.VortexFunction; 031import org.apache.reef.vortex.common.avro.*; 032 033import javax.inject.Inject; 034import java.io.ByteArrayOutputStream; 035import java.io.IOException; 036import java.nio.ByteBuffer; 037import java.util.ArrayList; 038import java.util.List; 039 040/** 041 * Serialize and deserialize Vortex message to/from byte array. 042 */ 043@Private 044@DriverSide 045@Unstable 046public final class VortexAvroUtils { 047 private final AggregateFunctionRepository aggregateFunctionRepository; 048 049 @Inject 050 private VortexAvroUtils(final AggregateFunctionRepository aggregateFunctionRepository) { 051 this.aggregateFunctionRepository = aggregateFunctionRepository; 052 } 053 054 /** 055 * Serialize VortexRequest to byte array. 056 * @param vortexRequest Vortex request message to serialize. 057 * @return Serialized byte array. 058 */ 059 public byte[] toBytes(final VortexRequest vortexRequest) { 060 // Convert VortexRequest message to Avro message. 061 final AvroVortexRequest avroVortexRequest; 062 switch (vortexRequest.getType()) { 063 case ExecuteAggregateTasklet: 064 final TaskletAggregateExecutionRequest taskletAggregateExecutionRequest = 065 (TaskletAggregateExecutionRequest) vortexRequest; 066 // TODO[REEF-1113]: Handle serialization failure separately in Vortex 067 final byte[] serializedInputForAggregate = 068 aggregateFunctionRepository.getFunction(taskletAggregateExecutionRequest.getAggregateFunctionId()) 069 .getInputCodec().encode(taskletAggregateExecutionRequest.getInput()); 070 avroVortexRequest = AvroVortexRequest.newBuilder() 071 .setRequestType(AvroRequestType.AggregateExecute) 072 .setTaskletRequest( 073 AvroTaskletAggregateExecutionRequest.newBuilder() 074 .setAggregateFunctionId(taskletAggregateExecutionRequest.getAggregateFunctionId()) 075 .setSerializedInput(ByteBuffer.wrap(serializedInputForAggregate)) 076 .setTaskletId(taskletAggregateExecutionRequest.getTaskletId()) 077 .build()) 078 .build(); 079 break; 080 case AggregateTasklets: 081 final TaskletAggregationRequest taskletAggregationRequest = (TaskletAggregationRequest) vortexRequest; 082 083 // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction 084 final byte[] serializedAggregateFunction = SerializationUtils.serialize( 085 taskletAggregationRequest.getAggregateFunction()); 086 final byte[] serializedFunctionForAggregation = SerializationUtils.serialize( 087 taskletAggregationRequest.getFunction()); 088 final byte[] serializedPolicy = SerializationUtils.serialize( 089 taskletAggregationRequest.getPolicy()); 090 avroVortexRequest = AvroVortexRequest.newBuilder() 091 .setRequestType(AvroRequestType.Aggregate) 092 .setTaskletRequest(AvroTaskletAggregationRequest.newBuilder() 093 .setAggregateFunctionId(taskletAggregationRequest.getAggregateFunctionId()) 094 .setSerializedAggregateFunction(ByteBuffer.wrap(serializedAggregateFunction)) 095 .setSerializedUserFunction(ByteBuffer.wrap(serializedFunctionForAggregation)) 096 .setSerializedPolicy(ByteBuffer.wrap(serializedPolicy)) 097 .build()) 098 .build(); 099 break; 100 case ExecuteTasklet: 101 final TaskletExecutionRequest taskletExecutionRequest = (TaskletExecutionRequest) vortexRequest; 102 // The following TODOs are sub-issues of cleaning up Serializable in Vortex (REEF-504). 103 // The purpose is to reduce serialization cost, which leads to bottleneck in Master. 104 // Temporarily those are left as TODOs, but will be addressed in separate PRs. 105 final VortexFunction vortexFunction = taskletExecutionRequest.getFunction(); 106 // TODO[REEF-1113]: Handle serialization failure separately in Vortex 107 final byte[] serializedInput = vortexFunction.getInputCodec().encode(taskletExecutionRequest.getInput()); 108 // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction 109 final byte[] serializedFunction = SerializationUtils.serialize(vortexFunction); 110 avroVortexRequest = AvroVortexRequest.newBuilder() 111 .setRequestType(AvroRequestType.ExecuteTasklet) 112 .setTaskletRequest( 113 AvroTaskletExecutionRequest.newBuilder() 114 .setTaskletId(taskletExecutionRequest.getTaskletId()) 115 .setSerializedInput(ByteBuffer.wrap(serializedInput)) 116 .setSerializedUserFunction(ByteBuffer.wrap(serializedFunction)) 117 .build()) 118 .build(); 119 break; 120 case CancelTasklet: 121 final TaskletCancellationRequest taskletCancellationRequest = (TaskletCancellationRequest) vortexRequest; 122 avroVortexRequest = AvroVortexRequest.newBuilder() 123 .setRequestType(AvroRequestType.CancelTasklet) 124 .setTaskletRequest( 125 AvroTaskletCancellationRequest.newBuilder() 126 .setTaskletId(taskletCancellationRequest.getTaskletId()) 127 .build()) 128 .build(); 129 break; 130 default: 131 throw new RuntimeException("Undefined message type"); 132 } 133 134 // Serialize the Avro message to byte array. 135 return toBytes(avroVortexRequest, AvroVortexRequest.class); 136 } 137 138 /** 139 * Serialize WorkerReport to byte array. 140 * @param workerReport Worker report message to serialize. 141 * @return Serialized byte array. 142 */ 143 public byte[] toBytes(final WorkerReport workerReport) { 144 final List<AvroTaskletReport> workerTaskletReports = new ArrayList<>(); 145 146 for (final TaskletReport taskletReport : workerReport.getTaskletReports()) { 147 final AvroTaskletReport avroTaskletReport; 148 switch (taskletReport.getType()) { 149 case TaskletResult: 150 final TaskletResultReport taskletResultReport = (TaskletResultReport) taskletReport; 151 avroTaskletReport = AvroTaskletReport.newBuilder() 152 .setReportType(AvroReportType.TaskletResult) 153 .setTaskletReport( 154 AvroTaskletResultReport.newBuilder() 155 .setTaskletId(taskletResultReport.getTaskletId()) 156 .setSerializedOutput(ByteBuffer.wrap(taskletResultReport.getSerializedResult())) 157 .build()) 158 .build(); 159 break; 160 case TaskletAggregationResult: 161 final TaskletAggregationResultReport taskletAggregationResultReport = 162 (TaskletAggregationResultReport) taskletReport; 163 avroTaskletReport = AvroTaskletReport.newBuilder() 164 .setReportType(AvroReportType.TaskletAggregationResult) 165 .setTaskletReport( 166 AvroTaskletAggregationResultReport.newBuilder() 167 .setTaskletIds(taskletAggregationResultReport.getTaskletIds()) 168 .setSerializedOutput(ByteBuffer.wrap(taskletAggregationResultReport.getSerializedResult())) 169 .build()) 170 .build(); 171 break; 172 case TaskletCancelled: 173 final TaskletCancelledReport taskletCancelledReport = (TaskletCancelledReport) taskletReport; 174 avroTaskletReport = AvroTaskletReport.newBuilder() 175 .setReportType(AvroReportType.TaskletCancelled) 176 .setTaskletReport( 177 AvroTaskletCancelledReport.newBuilder() 178 .setTaskletId(taskletCancelledReport.getTaskletId()) 179 .build()) 180 .build(); 181 break; 182 case TaskletFailure: 183 final TaskletFailureReport taskletFailureReport = (TaskletFailureReport) taskletReport; 184 final byte[] serializedException = SerializationUtils.serialize(taskletFailureReport.getException()); 185 avroTaskletReport = AvroTaskletReport.newBuilder() 186 .setReportType(AvroReportType.TaskletFailure) 187 .setTaskletReport( 188 AvroTaskletFailureReport.newBuilder() 189 .setTaskletId(taskletFailureReport.getTaskletId()) 190 .setSerializedException(ByteBuffer.wrap(serializedException)) 191 .build()) 192 .build(); 193 break; 194 case TaskletAggregationFailure: 195 final TaskletAggregationFailureReport taskletAggregationFailureReport = 196 (TaskletAggregationFailureReport) taskletReport; 197 final byte[] serializedAggregationException = 198 SerializationUtils.serialize(taskletAggregationFailureReport.getException()); 199 avroTaskletReport = AvroTaskletReport.newBuilder() 200 .setReportType(AvroReportType.TaskletAggregationFailure) 201 .setTaskletReport( 202 AvroTaskletAggregationFailureReport.newBuilder() 203 .setTaskletIds(taskletAggregationFailureReport.getTaskletIds()) 204 .setSerializedException(ByteBuffer.wrap(serializedAggregationException)) 205 .build()) 206 .build(); 207 break; 208 default: 209 throw new RuntimeException("Undefined message type"); 210 } 211 212 workerTaskletReports.add(avroTaskletReport); 213 } 214 215 // Convert WorkerReport message to Avro message. 216 final AvroWorkerReport avroWorkerReport = AvroWorkerReport.newBuilder() 217 .setTaskletReports(workerTaskletReports) 218 .build(); 219 220 // Serialize the Avro message to byte array. 221 return toBytes(avroWorkerReport, AvroWorkerReport.class); 222 } 223 224 /** 225 * Deserialize byte array to VortexRequest. 226 * @param bytes Byte array to deserialize. 227 * @return De-serialized VortexRequest. 228 */ 229 public VortexRequest toVortexRequest(final byte[] bytes) { 230 final AvroVortexRequest avroVortexRequest = toAvroObject(bytes, AvroVortexRequest.class); 231 232 final VortexRequest vortexRequest; 233 switch (avroVortexRequest.getRequestType()) { 234 case AggregateExecute: 235 final AvroTaskletAggregateExecutionRequest taskletAggregateExecutionRequest = 236 (AvroTaskletAggregateExecutionRequest)avroVortexRequest.getTaskletRequest(); 237 vortexRequest = new TaskletAggregateExecutionRequest<>(taskletAggregateExecutionRequest.getTaskletId(), 238 taskletAggregateExecutionRequest.getAggregateFunctionId(), 239 aggregateFunctionRepository.getFunction(taskletAggregateExecutionRequest.getAggregateFunctionId()) 240 .getInputCodec().decode(taskletAggregateExecutionRequest.getSerializedInput().array())); 241 break; 242 case Aggregate: 243 final AvroTaskletAggregationRequest taskletAggregationRequest = 244 (AvroTaskletAggregationRequest)avroVortexRequest.getTaskletRequest(); 245 final VortexAggregateFunction aggregateFunction = 246 (VortexAggregateFunction) SerializationUtils.deserialize( 247 taskletAggregationRequest.getSerializedAggregateFunction().array()); 248 final VortexFunction functionForAggregation = 249 (VortexFunction) SerializationUtils.deserialize( 250 taskletAggregationRequest.getSerializedUserFunction().array()); 251 final VortexAggregatePolicy policy = 252 (VortexAggregatePolicy) SerializationUtils.deserialize( 253 taskletAggregationRequest.getSerializedPolicy().array()); 254 vortexRequest = new TaskletAggregationRequest<>(taskletAggregationRequest.getAggregateFunctionId(), 255 aggregateFunction, functionForAggregation, policy); 256 break; 257 case ExecuteTasklet: 258 final AvroTaskletExecutionRequest taskletExecutionRequest = 259 (AvroTaskletExecutionRequest)avroVortexRequest.getTaskletRequest(); 260 // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction 261 final VortexFunction function = 262 (VortexFunction) SerializationUtils.deserialize( 263 taskletExecutionRequest.getSerializedUserFunction().array()); 264 // TODO[REEF-1113]: Handle serialization failure separately in Vortex 265 vortexRequest = new TaskletExecutionRequest(taskletExecutionRequest.getTaskletId(), function, 266 function.getInputCodec().decode(taskletExecutionRequest.getSerializedInput().array())); 267 break; 268 case CancelTasklet: 269 final AvroTaskletCancellationRequest taskletCancellationRequest = 270 (AvroTaskletCancellationRequest)avroVortexRequest.getTaskletRequest(); 271 vortexRequest = new TaskletCancellationRequest(taskletCancellationRequest.getTaskletId()); 272 break; 273 default: 274 throw new RuntimeException("Undefined VortexRequest type"); 275 } 276 return vortexRequest; 277 } 278 279 /** 280 * Deserialize byte array to WorkerReport. 281 * @param bytes Byte array to deserialize. 282 * @return De-serialized WorkerReport. 283 */ 284 public WorkerReport toWorkerReport(final byte[] bytes) { 285 final AvroWorkerReport avroWorkerReport = toAvroObject(bytes, AvroWorkerReport.class); 286 final List<TaskletReport> workerTaskletReports = new ArrayList<>(); 287 288 for (final AvroTaskletReport avroTaskletReport : avroWorkerReport.getTaskletReports()) { 289 final TaskletReport taskletReport; 290 291 switch (avroTaskletReport.getReportType()) { 292 case TaskletResult: 293 final AvroTaskletResultReport taskletResultReport = 294 (AvroTaskletResultReport) avroTaskletReport.getTaskletReport(); 295 taskletReport = new TaskletResultReport(taskletResultReport.getTaskletId(), 296 taskletResultReport.getSerializedOutput().array()); 297 break; 298 case TaskletAggregationResult: 299 final AvroTaskletAggregationResultReport taskletAggregationResultReport = 300 (AvroTaskletAggregationResultReport) avroTaskletReport.getTaskletReport(); 301 taskletReport = 302 new TaskletAggregationResultReport(taskletAggregationResultReport.getTaskletIds(), 303 taskletAggregationResultReport.getSerializedOutput().array()); 304 break; 305 case TaskletCancelled: 306 final AvroTaskletCancelledReport taskletCancelledReport = 307 (AvroTaskletCancelledReport) avroTaskletReport.getTaskletReport(); 308 taskletReport = new TaskletCancelledReport(taskletCancelledReport.getTaskletId()); 309 break; 310 case TaskletFailure: 311 final AvroTaskletFailureReport taskletFailureReport = 312 (AvroTaskletFailureReport) avroTaskletReport.getTaskletReport(); 313 final Exception exception = 314 (Exception) SerializationUtils.deserialize(taskletFailureReport.getSerializedException().array()); 315 taskletReport = new TaskletFailureReport(taskletFailureReport.getTaskletId(), exception); 316 break; 317 case TaskletAggregationFailure: 318 final AvroTaskletAggregationFailureReport taskletAggregationFailureReport = 319 (AvroTaskletAggregationFailureReport) avroTaskletReport.getTaskletReport(); 320 final Exception aggregationException = 321 (Exception) SerializationUtils.deserialize( 322 taskletAggregationFailureReport.getSerializedException().array()); 323 taskletReport = 324 new TaskletAggregationFailureReport(taskletAggregationFailureReport.getTaskletIds(), aggregationException); 325 break; 326 default: 327 throw new RuntimeException("Undefined TaskletReport type"); 328 } 329 330 workerTaskletReports.add(taskletReport); 331 } 332 333 return new WorkerReport(workerTaskletReports); 334 } 335 336 /** 337 * Serialize Avro object to byte array. 338 * @param avroObject Avro object to serialize. 339 * @param theClass Class of the Avro object. 340 * @param <T> Type of the Avro object. 341 * @return Serialized byte array. 342 */ 343 private <T> byte[] toBytes(final T avroObject, final Class<T> theClass) { 344 final DatumWriter<T> reportWriter = new SpecificDatumWriter<>(theClass); 345 final byte[] theBytes; 346 try (final ByteArrayOutputStream out = new ByteArrayOutputStream()) { 347 final BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null); 348 reportWriter.write(avroObject, encoder); 349 encoder.flush(); 350 out.flush(); 351 theBytes = out.toByteArray(); 352 return theBytes; 353 } catch (IOException e) { 354 throw new RuntimeException(e); 355 } 356 } 357 358 /** 359 * Deserialize byte array to Avro object. 360 * @param bytes Byte array to deserialize. 361 * @param theClass Class of the Avro object. 362 * @param <T> Type of the Avro object. 363 * @return Avro object de-serialized from byte array. 364 */ 365 private <T> T toAvroObject(final byte[] bytes, final Class<T> theClass) { 366 final BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(bytes, null); 367 final SpecificDatumReader<T> reader = new SpecificDatumReader<>(theClass); 368 try { 369 return reader.read(null, decoder); 370 } catch (IOException e) { 371 throw new RuntimeException(e); 372 } 373 } 374}