This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.vortex.common;
020
021import org.apache.avro.io.*;
022import org.apache.avro.specific.SpecificDatumReader;
023import org.apache.avro.specific.SpecificDatumWriter;
024import org.apache.commons.lang.SerializationUtils;
025import org.apache.reef.annotations.Unstable;
026import org.apache.reef.annotations.audience.DriverSide;
027import org.apache.reef.annotations.audience.Private;
028import org.apache.reef.vortex.api.VortexAggregateFunction;
029import org.apache.reef.vortex.api.VortexAggregatePolicy;
030import org.apache.reef.vortex.api.VortexFunction;
031import org.apache.reef.vortex.common.avro.*;
032
033import javax.inject.Inject;
034import java.io.ByteArrayOutputStream;
035import java.io.IOException;
036import java.nio.ByteBuffer;
037import java.util.ArrayList;
038import java.util.List;
039
040/**
041 * Serialize and deserialize Vortex message to/from byte array.
042 */
043@Private
044@DriverSide
045@Unstable
046public final class VortexAvroUtils {
047  private final AggregateFunctionRepository aggregateFunctionRepository;
048
049  @Inject
050  private VortexAvroUtils(final AggregateFunctionRepository aggregateFunctionRepository) {
051    this.aggregateFunctionRepository = aggregateFunctionRepository;
052  }
053
054  /**
055   * Serialize VortexRequest to byte array.
056   * @param vortexRequest Vortex request message to serialize.
057   * @return Serialized byte array.
058   */
059  public byte[] toBytes(final VortexRequest vortexRequest) {
060    // Convert VortexRequest message to Avro message.
061    final AvroVortexRequest avroVortexRequest;
062    switch (vortexRequest.getType()) {
063    case ExecuteAggregateTasklet:
064      final TaskletAggregateExecutionRequest taskletAggregateExecutionRequest =
065          (TaskletAggregateExecutionRequest) vortexRequest;
066      // TODO[REEF-1113]: Handle serialization failure separately in Vortex
067      final byte[] serializedInputForAggregate =
068        aggregateFunctionRepository.getFunction(taskletAggregateExecutionRequest.getAggregateFunctionId())
069            .getInputCodec().encode(taskletAggregateExecutionRequest.getInput());
070      avroVortexRequest = AvroVortexRequest.newBuilder()
071          .setRequestType(AvroRequestType.AggregateExecute)
072          .setTaskletRequest(
073              AvroTaskletAggregateExecutionRequest.newBuilder()
074                  .setAggregateFunctionId(taskletAggregateExecutionRequest.getAggregateFunctionId())
075                  .setSerializedInput(ByteBuffer.wrap(serializedInputForAggregate))
076                  .setTaskletId(taskletAggregateExecutionRequest.getTaskletId())
077                  .build())
078          .build();
079      break;
080    case AggregateTasklets:
081      final TaskletAggregationRequest taskletAggregationRequest = (TaskletAggregationRequest) vortexRequest;
082
083      // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction
084      final byte[] serializedAggregateFunction = SerializationUtils.serialize(
085          taskletAggregationRequest.getAggregateFunction());
086      final byte[] serializedFunctionForAggregation = SerializationUtils.serialize(
087          taskletAggregationRequest.getFunction());
088      final byte[] serializedPolicy = SerializationUtils.serialize(
089          taskletAggregationRequest.getPolicy());
090      avroVortexRequest = AvroVortexRequest.newBuilder()
091          .setRequestType(AvroRequestType.Aggregate)
092          .setTaskletRequest(AvroTaskletAggregationRequest.newBuilder()
093              .setAggregateFunctionId(taskletAggregationRequest.getAggregateFunctionId())
094              .setSerializedAggregateFunction(ByteBuffer.wrap(serializedAggregateFunction))
095              .setSerializedUserFunction(ByteBuffer.wrap(serializedFunctionForAggregation))
096              .setSerializedPolicy(ByteBuffer.wrap(serializedPolicy))
097              .build())
098          .build();
099      break;
100    case ExecuteTasklet:
101      final TaskletExecutionRequest taskletExecutionRequest = (TaskletExecutionRequest) vortexRequest;
102      // The following TODOs are sub-issues of cleaning up Serializable in Vortex (REEF-504).
103      // The purpose is to reduce serialization cost, which leads to bottleneck in Master.
104      // Temporarily those are left as TODOs, but will be addressed in separate PRs.
105      final VortexFunction vortexFunction = taskletExecutionRequest.getFunction();
106      // TODO[REEF-1113]: Handle serialization failure separately in Vortex
107      final byte[] serializedInput = vortexFunction.getInputCodec().encode(taskletExecutionRequest.getInput());
108      // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction
109      final byte[] serializedFunction = SerializationUtils.serialize(vortexFunction);
110      avroVortexRequest = AvroVortexRequest.newBuilder()
111          .setRequestType(AvroRequestType.ExecuteTasklet)
112          .setTaskletRequest(
113              AvroTaskletExecutionRequest.newBuilder()
114                  .setTaskletId(taskletExecutionRequest.getTaskletId())
115                  .setSerializedInput(ByteBuffer.wrap(serializedInput))
116                  .setSerializedUserFunction(ByteBuffer.wrap(serializedFunction))
117                  .build())
118          .build();
119      break;
120    case CancelTasklet:
121      final TaskletCancellationRequest taskletCancellationRequest = (TaskletCancellationRequest) vortexRequest;
122      avroVortexRequest = AvroVortexRequest.newBuilder()
123          .setRequestType(AvroRequestType.CancelTasklet)
124          .setTaskletRequest(
125              AvroTaskletCancellationRequest.newBuilder()
126                  .setTaskletId(taskletCancellationRequest.getTaskletId())
127                  .build())
128          .build();
129      break;
130    default:
131      throw new RuntimeException("Undefined message type");
132    }
133
134    // Serialize the Avro message to byte array.
135    return toBytes(avroVortexRequest, AvroVortexRequest.class);
136  }
137
138  /**
139   * Serialize WorkerReport to byte array.
140   * @param workerReport Worker report message to serialize.
141   * @return Serialized byte array.
142   */
143  public byte[] toBytes(final WorkerReport workerReport) {
144    final List<AvroTaskletReport> workerTaskletReports = new ArrayList<>();
145
146    for (final TaskletReport taskletReport : workerReport.getTaskletReports()) {
147      final AvroTaskletReport avroTaskletReport;
148      switch (taskletReport.getType()) {
149      case TaskletResult:
150        final TaskletResultReport taskletResultReport = (TaskletResultReport) taskletReport;
151        avroTaskletReport = AvroTaskletReport.newBuilder()
152            .setReportType(AvroReportType.TaskletResult)
153            .setTaskletReport(
154                AvroTaskletResultReport.newBuilder()
155                    .setTaskletId(taskletResultReport.getTaskletId())
156                    .setSerializedOutput(ByteBuffer.wrap(taskletResultReport.getSerializedResult()))
157                    .build())
158            .build();
159        break;
160      case TaskletAggregationResult:
161        final TaskletAggregationResultReport taskletAggregationResultReport =
162            (TaskletAggregationResultReport) taskletReport;
163        avroTaskletReport = AvroTaskletReport.newBuilder()
164            .setReportType(AvroReportType.TaskletAggregationResult)
165            .setTaskletReport(
166                AvroTaskletAggregationResultReport.newBuilder()
167                    .setTaskletIds(taskletAggregationResultReport.getTaskletIds())
168                    .setSerializedOutput(ByteBuffer.wrap(taskletAggregationResultReport.getSerializedResult()))
169                    .build())
170            .build();
171        break;
172      case TaskletCancelled:
173        final TaskletCancelledReport taskletCancelledReport = (TaskletCancelledReport) taskletReport;
174        avroTaskletReport = AvroTaskletReport.newBuilder()
175            .setReportType(AvroReportType.TaskletCancelled)
176            .setTaskletReport(
177                AvroTaskletCancelledReport.newBuilder()
178                    .setTaskletId(taskletCancelledReport.getTaskletId())
179                    .build())
180            .build();
181        break;
182      case TaskletFailure:
183        final TaskletFailureReport taskletFailureReport = (TaskletFailureReport) taskletReport;
184        final byte[] serializedException = SerializationUtils.serialize(taskletFailureReport.getException());
185        avroTaskletReport = AvroTaskletReport.newBuilder()
186            .setReportType(AvroReportType.TaskletFailure)
187            .setTaskletReport(
188                AvroTaskletFailureReport.newBuilder()
189                    .setTaskletId(taskletFailureReport.getTaskletId())
190                    .setSerializedException(ByteBuffer.wrap(serializedException))
191                    .build())
192            .build();
193        break;
194      case TaskletAggregationFailure:
195        final TaskletAggregationFailureReport taskletAggregationFailureReport =
196            (TaskletAggregationFailureReport) taskletReport;
197        final byte[] serializedAggregationException =
198            SerializationUtils.serialize(taskletAggregationFailureReport.getException());
199        avroTaskletReport = AvroTaskletReport.newBuilder()
200            .setReportType(AvroReportType.TaskletAggregationFailure)
201            .setTaskletReport(
202                AvroTaskletAggregationFailureReport.newBuilder()
203                    .setTaskletIds(taskletAggregationFailureReport.getTaskletIds())
204                    .setSerializedException(ByteBuffer.wrap(serializedAggregationException))
205                    .build())
206            .build();
207        break;
208      default:
209        throw new RuntimeException("Undefined message type");
210      }
211
212      workerTaskletReports.add(avroTaskletReport);
213    }
214
215    // Convert WorkerReport message to Avro message.
216    final AvroWorkerReport avroWorkerReport = AvroWorkerReport.newBuilder()
217        .setTaskletReports(workerTaskletReports)
218        .build();
219
220    // Serialize the Avro message to byte array.
221    return toBytes(avroWorkerReport, AvroWorkerReport.class);
222  }
223
224  /**
225   * Deserialize byte array to VortexRequest.
226   * @param bytes Byte array to deserialize.
227   * @return De-serialized VortexRequest.
228   */
229  public VortexRequest toVortexRequest(final byte[] bytes) {
230    final AvroVortexRequest avroVortexRequest = toAvroObject(bytes, AvroVortexRequest.class);
231
232    final VortexRequest vortexRequest;
233    switch (avroVortexRequest.getRequestType()) {
234    case AggregateExecute:
235      final AvroTaskletAggregateExecutionRequest taskletAggregateExecutionRequest =
236          (AvroTaskletAggregateExecutionRequest)avroVortexRequest.getTaskletRequest();
237      vortexRequest = new TaskletAggregateExecutionRequest<>(taskletAggregateExecutionRequest.getTaskletId(),
238          taskletAggregateExecutionRequest.getAggregateFunctionId(),
239          aggregateFunctionRepository.getFunction(taskletAggregateExecutionRequest.getAggregateFunctionId())
240              .getInputCodec().decode(taskletAggregateExecutionRequest.getSerializedInput().array()));
241      break;
242    case Aggregate:
243      final AvroTaskletAggregationRequest taskletAggregationRequest =
244          (AvroTaskletAggregationRequest)avroVortexRequest.getTaskletRequest();
245      final VortexAggregateFunction aggregateFunction =
246          (VortexAggregateFunction) SerializationUtils.deserialize(
247              taskletAggregationRequest.getSerializedAggregateFunction().array());
248      final VortexFunction functionForAggregation =
249          (VortexFunction) SerializationUtils.deserialize(
250              taskletAggregationRequest.getSerializedUserFunction().array());
251      final VortexAggregatePolicy policy =
252          (VortexAggregatePolicy) SerializationUtils.deserialize(
253              taskletAggregationRequest.getSerializedPolicy().array());
254      vortexRequest = new TaskletAggregationRequest<>(taskletAggregationRequest.getAggregateFunctionId(),
255          aggregateFunction, functionForAggregation, policy);
256      break;
257    case ExecuteTasklet:
258      final AvroTaskletExecutionRequest taskletExecutionRequest =
259          (AvroTaskletExecutionRequest)avroVortexRequest.getTaskletRequest();
260      // TODO[REEF-1003]: Use reflection instead of serialization when launching VortexFunction
261      final VortexFunction function =
262          (VortexFunction) SerializationUtils.deserialize(
263              taskletExecutionRequest.getSerializedUserFunction().array());
264      // TODO[REEF-1113]: Handle serialization failure separately in Vortex
265      vortexRequest = new TaskletExecutionRequest(taskletExecutionRequest.getTaskletId(), function,
266         function.getInputCodec().decode(taskletExecutionRequest.getSerializedInput().array()));
267      break;
268    case CancelTasklet:
269      final AvroTaskletCancellationRequest taskletCancellationRequest =
270          (AvroTaskletCancellationRequest)avroVortexRequest.getTaskletRequest();
271      vortexRequest = new TaskletCancellationRequest(taskletCancellationRequest.getTaskletId());
272      break;
273    default:
274      throw new RuntimeException("Undefined VortexRequest type");
275    }
276    return vortexRequest;
277  }
278
279  /**
280   * Deserialize byte array to WorkerReport.
281   * @param bytes Byte array to deserialize.
282   * @return De-serialized WorkerReport.
283   */
284  public WorkerReport toWorkerReport(final byte[] bytes) {
285    final AvroWorkerReport avroWorkerReport = toAvroObject(bytes, AvroWorkerReport.class);
286    final List<TaskletReport> workerTaskletReports = new ArrayList<>();
287
288    for (final AvroTaskletReport avroTaskletReport : avroWorkerReport.getTaskletReports()) {
289      final TaskletReport taskletReport;
290
291      switch (avroTaskletReport.getReportType()) {
292      case TaskletResult:
293        final AvroTaskletResultReport taskletResultReport =
294            (AvroTaskletResultReport) avroTaskletReport.getTaskletReport();
295        taskletReport = new TaskletResultReport(taskletResultReport.getTaskletId(),
296            taskletResultReport.getSerializedOutput().array());
297        break;
298      case TaskletAggregationResult:
299        final AvroTaskletAggregationResultReport taskletAggregationResultReport =
300            (AvroTaskletAggregationResultReport) avroTaskletReport.getTaskletReport();
301        taskletReport =
302            new TaskletAggregationResultReport(taskletAggregationResultReport.getTaskletIds(),
303                taskletAggregationResultReport.getSerializedOutput().array());
304        break;
305      case TaskletCancelled:
306        final AvroTaskletCancelledReport taskletCancelledReport =
307            (AvroTaskletCancelledReport) avroTaskletReport.getTaskletReport();
308        taskletReport = new TaskletCancelledReport(taskletCancelledReport.getTaskletId());
309        break;
310      case TaskletFailure:
311        final AvroTaskletFailureReport taskletFailureReport =
312            (AvroTaskletFailureReport) avroTaskletReport.getTaskletReport();
313        final Exception exception =
314            (Exception) SerializationUtils.deserialize(taskletFailureReport.getSerializedException().array());
315        taskletReport = new TaskletFailureReport(taskletFailureReport.getTaskletId(), exception);
316        break;
317      case TaskletAggregationFailure:
318        final AvroTaskletAggregationFailureReport taskletAggregationFailureReport =
319            (AvroTaskletAggregationFailureReport) avroTaskletReport.getTaskletReport();
320        final Exception aggregationException =
321            (Exception) SerializationUtils.deserialize(
322                taskletAggregationFailureReport.getSerializedException().array());
323        taskletReport =
324            new TaskletAggregationFailureReport(taskletAggregationFailureReport.getTaskletIds(), aggregationException);
325        break;
326      default:
327        throw new RuntimeException("Undefined TaskletReport type");
328      }
329
330      workerTaskletReports.add(taskletReport);
331    }
332
333    return new WorkerReport(workerTaskletReports);
334  }
335
336  /**
337   * Serialize Avro object to byte array.
338   * @param avroObject Avro object to serialize.
339   * @param theClass Class of the Avro object.
340   * @param <T> Type of the Avro object.
341   * @return Serialized byte array.
342   */
343  private <T> byte[] toBytes(final T avroObject, final Class<T> theClass) {
344    final DatumWriter<T> reportWriter = new SpecificDatumWriter<>(theClass);
345    final byte[] theBytes;
346    try (final ByteArrayOutputStream out = new ByteArrayOutputStream()) {
347      final BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);
348      reportWriter.write(avroObject, encoder);
349      encoder.flush();
350      out.flush();
351      theBytes = out.toByteArray();
352      return theBytes;
353    } catch (IOException e) {
354      throw new RuntimeException(e);
355    }
356  }
357
358  /**
359   * Deserialize byte array to Avro object.
360   * @param bytes Byte array to deserialize.
361   * @param theClass Class of the Avro object.
362   * @param <T> Type of the Avro object.
363   * @return Avro object de-serialized from byte array.
364   */
365  private <T> T toAvroObject(final byte[] bytes, final Class<T> theClass) {
366    final BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(bytes, null);
367    final SpecificDatumReader<T> reader = new SpecificDatumReader<>(theClass);
368    try {
369      return reader.read(null, decoder);
370    } catch (IOException e) {
371      throw new RuntimeException(e);
372    }
373  }
374}