001/** 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.suspend; 020 021import org.apache.reef.io.checkpoint.CheckpointID; 022import org.apache.reef.io.checkpoint.CheckpointService; 023import org.apache.reef.io.checkpoint.CheckpointService.CheckpointReadChannel; 024import org.apache.reef.io.checkpoint.CheckpointService.CheckpointWriteChannel; 025import org.apache.reef.io.checkpoint.fs.FSCheckpointID; 026import org.apache.reef.tang.annotations.Parameter; 027import org.apache.reef.tang.annotations.Unit; 028import org.apache.reef.task.Task; 029import org.apache.reef.task.TaskMessage; 030import org.apache.reef.task.TaskMessageSource; 031import org.apache.reef.task.events.SuspendEvent; 032import org.apache.reef.util.Optional; 033import org.apache.reef.wake.EventHandler; 034import org.apache.reef.wake.remote.impl.ObjectSerializableCodec; 035 036import javax.inject.Inject; 037import java.io.IOException; 038import java.nio.ByteBuffer; 039import java.util.logging.Level; 040import java.util.logging.Logger; 041 042/** 043 * Simple do-nothing task that can send messages to the Driver and can be suspended/resumed. 044 */ 045@Unit 046public class SuspendTestTask implements Task, TaskMessageSource { 047 048 /** 049 * Standard java logger. 050 */ 051 private static final Logger LOG = Logger.getLogger(SuspendTestTask.class.getName()); 052 private final CheckpointService checkpointService; 053 /** 054 * number of cycles to run in the task. 055 */ 056 private final int numCycles; 057 /** 058 * delay in milliseconds between cycles in the task. 059 */ 060 private final int delay; 061 /** 062 * Codec to serialize/deserialize counter values for the updates. 063 */ 064 private final ObjectSerializableCodec<Integer> codecInt = new ObjectSerializableCodec<>(); 065 /** 066 * Codec to serialize/deserialize checkpoint IDs for suspend/resume. 067 */ 068 private final ObjectWritableCodec<CheckpointID> codecCheckpoint = 069 new ObjectWritableCodec<CheckpointID>(FSCheckpointID.class); 070 /** 071 * Current value of the counter. 072 */ 073 private int counter = 0; 074 /** 075 * True if the suspend message has been received, false otherwise. 076 */ 077 private boolean suspended = false; 078 079 /** 080 * Task constructor: invoked by TANG. 081 * 082 * @param numCycles number of cycles to run in the task. 083 * @param delay delay in seconds between cycles in the task. 084 */ 085 @Inject 086 public SuspendTestTask( 087 final CheckpointService checkpointService, 088 @Parameter(Launch.NumCycles.class) final int numCycles, 089 @Parameter(Launch.Delay.class) final int delay) { 090 this.checkpointService = checkpointService; 091 this.numCycles = numCycles; 092 this.delay = delay * 1000; 093 } 094 095 /** 096 * Main method of the task: run cycle from 0 to numCycles, 097 * and sleep for delay seconds on each cycle. 098 * 099 * @param memento serialized version of the counter. 100 * Empty array for initial run, but can contain value for resumed job. 101 * @return serialized version of the counter. 102 */ 103 @Override 104 public synchronized byte[] call(final byte[] memento) throws IOException, InterruptedException { 105 106 LOG.log(Level.INFO, "Start: {0} counter: {1}/{2}", 107 new Object[]{this, this.counter, this.numCycles}); 108 109 if (memento != null && memento.length > 0) { 110 this.restore(memento); 111 } 112 113 this.suspended = false; 114 for (; this.counter < this.numCycles && !this.suspended; ++this.counter) { 115 try { 116 LOG.log(Level.INFO, "Run: {0} counter: {1}/{2} sleep: {3}", 117 new Object[]{this, this.counter, this.numCycles, this.delay}); 118 this.wait(this.delay); 119 } catch (final InterruptedException ex) { 120 LOG.log(Level.INFO, "{0} interrupted. counter: {1}: {2}", 121 new Object[]{this, this.counter, ex}); 122 } 123 } 124 125 return this.suspended ? this.save() : this.codecInt.encode(this.counter); 126 } 127 128 /** 129 * Update driver on current state of the task. 130 * 131 * @return serialized version of the counter. 132 */ 133 @Override 134 public synchronized Optional<TaskMessage> getMessage() { 135 LOG.log(Level.INFO, "Message from Task {0} to the Driver: counter: {1}", 136 new Object[]{this, this.counter}); 137 return Optional.of(TaskMessage.from(SuspendTestTask.class.getName(), this.codecInt.encode(this.counter))); 138 } 139 140 /** 141 * Save current state of the task in the checkpoint. 142 * 143 * @return checkpoint ID (serialized) 144 */ 145 private synchronized byte[] save() throws IOException, InterruptedException { 146 try (final CheckpointWriteChannel channel = this.checkpointService.create()) { 147 channel.write(ByteBuffer.wrap(this.codecInt.encode(this.counter))); 148 return this.codecCheckpoint.encode(this.checkpointService.commit(channel)); 149 } 150 } 151 152 /** 153 * Restore the task state from the given checkpoint. 154 * 155 * @param memento serialized checkpoint ID 156 */ 157 private synchronized void restore(final byte[] memento) throws IOException, InterruptedException { 158 final CheckpointID checkpointId = this.codecCheckpoint.decode(memento); 159 try (final CheckpointReadChannel channel = this.checkpointService.open(checkpointId)) { 160 final ByteBuffer buffer = ByteBuffer.wrap(this.codecInt.encode(this.counter)); 161 channel.read(buffer); 162 this.counter = this.codecInt.decode(buffer.array()); 163 } 164 this.checkpointService.delete(checkpointId); 165 } 166 167 public class SuspendHandler implements EventHandler<SuspendEvent> { 168 169 @Override 170 public void onNext(SuspendEvent suspendEvent) { 171 final byte[] message = suspendEvent.get().get(); 172 LOG.log(Level.INFO, "Suspend: {0} with: {1} bytes; counter: {2}", 173 new Object[]{this, message.length, SuspendTestTask.this.counter}); 174 SuspendTestTask.this.suspended = true; 175 SuspendTestTask.this.notify(); 176 } 177 } 178 179}