/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.examples.suspend;

import org.apache.reef.io.checkpoint.CheckpointID;
import org.apache.reef.io.checkpoint.CheckpointService;
import org.apache.reef.io.checkpoint.CheckpointService.CheckpointReadChannel;
import org.apache.reef.io.checkpoint.CheckpointService.CheckpointWriteChannel;
import org.apache.reef.io.checkpoint.fs.FSCheckpointID;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.task.Task;
import org.apache.reef.task.TaskMessage;
import org.apache.reef.task.TaskMessageSource;
import org.apache.reef.task.events.SuspendEvent;
import org.apache.reef.util.Optional;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.remote.impl.ObjectSerializableCodec;

import javax.inject.Inject;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Simple do-nothing task that can send messages to the Driver and can be suspended/resumed.
 * <p>
 * The task counts cycles from the current counter value up to {@code numCycles}, waiting on its
 * own monitor between cycles. {@link SuspendHandler} sets the {@code suspended} flag and notifies
 * that monitor; the task then stores the counter in a checkpoint and returns the serialized
 * checkpoint ID. All access to {@code counter} and {@code suspended} happens while holding the
 * monitor of the task instance.
 */
@Unit
public class SuspendTestTask implements Task, TaskMessageSource {

  /**
   * Standard java logger.
   */
  private static final Logger LOG = Logger.getLogger(SuspendTestTask.class.getName());

  /**
   * Service used to create, commit, open and delete the checkpoints that hold the counter.
   */
  private final CheckpointService checkpointService;

  /**
   * number of cycles to run in the task.
   */
  private final int numCycles;

  /**
   * delay in milliseconds between cycles in the task.
   */
  private final int delay;

  /**
   * Codec to serialize/deserialize counter values for the updates.
   */
  private final ObjectSerializableCodec<Integer> codecInt = new ObjectSerializableCodec<>();

  /**
   * Codec to serialize/deserialize checkpoint IDs for suspend/resume.
   */
  @SuppressWarnings("checkstyle:diamondoperatorforvariabledefinition")
  private final ObjectWritableCodec<CheckpointID> codecCheckpoint =
      new ObjectWritableCodec<CheckpointID>(FSCheckpointID.class);

  /**
   * Current value of the counter.
   */
  private int counter = 0;

  /**
   * True if the suspend message has been received, false otherwise.
   */
  private boolean suspended = false;

  /**
   * Task constructor: invoked by TANG.
   *
   * @param checkpointService service used to save and restore the counter on suspend/resume.
   * @param numCycles number of cycles to run in the task.
   * @param delay delay in seconds between cycles in the task.
   */
  @Inject
  public SuspendTestTask(
      final CheckpointService checkpointService,
      @Parameter(Launch.NumCycles.class) final int numCycles,
      @Parameter(Launch.Delay.class) final int delay) {
    this.checkpointService = checkpointService;
    this.numCycles = numCycles;
    // The parameter is given in seconds; the field is kept in milliseconds for Object.wait().
    this.delay = delay * 1000;
  }

  /**
   * Main method of the task: run cycle from the current counter to numCycles,
   * and wait for the configured delay on each cycle.
   *
   * @param memento serialized checkpoint ID of a previously suspended run.
   *                Null or empty array for initial run.
   * @return serialized checkpoint ID if the task was suspended,
   *         serialized version of the counter otherwise.
   * @throws IOException if the checkpoint cannot be read or written.
   * @throws InterruptedException declared by the checkpoint operations used for save/restore.
   */
  @Override
  public synchronized byte[] call(final byte[] memento) throws IOException, InterruptedException {

    LOG.log(Level.INFO, "Start: {0} counter: {1}/{2}",
        new Object[]{this, this.counter, this.numCycles});

    if (memento != null && memento.length > 0) {
      this.restore(memento);
    }

    this.suspended = false;
    for (; this.counter < this.numCycles && !this.suspended; ++this.counter) {
      try {
        LOG.log(Level.INFO, "Run: {0} counter: {1}/{2} sleep: {3}",
            new Object[]{this, this.counter, this.numCycles, this.delay});
        // Object.wait(0) means "wait until notified", so a zero delay must skip the wait
        // instead of blocking the task until it gets suspended.
        if (this.delay != 0) {
          // wait() releases the monitor, which lets SuspendHandler set the flag and notify us.
          this.wait(this.delay);
        }
      } catch (final InterruptedException ex) {
        // An interrupt only cuts the current cycle short; the loop carries on with the next one.
        LOG.log(Level.INFO, "{0} interrupted. counter: {1}: {2}",
            new Object[]{this, this.counter, ex});
      }
    }

    return this.suspended ? this.save() : this.codecInt.encode(this.counter);
  }

  /**
   * Update driver on current state of the task.
   *
   * @return task message that carries the serialized version of the counter
   *         and uses the class name of the task as its source ID.
   */
  @Override
  public synchronized Optional<TaskMessage> getMessage() {
    LOG.log(Level.INFO, "Message from Task {0} to the Driver: counter: {1}",
        new Object[]{this, this.counter});
    return Optional.of(TaskMessage.from(SuspendTestTask.class.getName(), this.codecInt.encode(this.counter)));
  }

  /**
   * Save current state of the task in the checkpoint.
   *
   * @return checkpoint ID (serialized)
   * @throws IOException if the checkpoint cannot be written or committed.
   * @throws InterruptedException declared by the checkpoint service operations.
   */
  private synchronized byte[] save() throws IOException, InterruptedException {
    try (final CheckpointWriteChannel channel = this.checkpointService.create()) {
      channel.write(ByteBuffer.wrap(this.codecInt.encode(this.counter)));
      return this.codecCheckpoint.encode(this.checkpointService.commit(channel));
    }
  }

  /**
   * Restore the task state from the given checkpoint and delete the checkpoint afterwards.
   *
   * @param memento serialized checkpoint ID
   * @throws IOException if the checkpoint cannot be read, or ends before the counter is complete.
   * @throws InterruptedException declared by the checkpoint service operations.
   */
  private synchronized void restore(final byte[] memento) throws IOException, InterruptedException {
    final CheckpointID checkpointId = this.codecCheckpoint.decode(memento);
    try (final CheckpointReadChannel channel = this.checkpointService.open(checkpointId)) {
      // The buffer is sized by encoding the current counter.
      // NOTE(review): this assumes every counter value encodes to the same number of bytes - confirm.
      final ByteBuffer buffer = ByteBuffer.wrap(this.codecInt.encode(this.counter));
      // A single read() may deliver fewer bytes than requested. Because the buffer starts out
      // holding the encoding of the current counter, decoding after a short read could silently
      // yield stale data, so keep reading until the buffer is full.
      while (buffer.hasRemaining()) {
        if (channel.read(buffer) < 0) {
          throw new IOException("Checkpoint " + checkpointId + " ended after "
              + buffer.position() + " of " + buffer.capacity() + " bytes");
        }
      }
      this.counter = this.codecInt.decode(buffer.array());
    }
    this.checkpointService.delete(checkpointId);
  }

  /**
   * Handler for suspend event.
   * Marks the enclosing task as suspended and wakes it up if it is waiting between cycles.
   */
  public class SuspendHandler implements EventHandler<SuspendEvent> {

    /**
     * Request suspension of the enclosing task.
     *
     * @param suspendEvent suspend request; its optional message is only used for logging.
     */
    @Override
    public void onNext(final SuspendEvent suspendEvent) {
      // The payload is optional and only its size is logged,
      // so an absent message must not prevent the suspension.
      final Optional<byte[]> message = suspendEvent.get();
      final int messageLength = message.isPresent() ? message.get().length : 0;
      synchronized (SuspendTestTask.this) {
        LOG.log(Level.INFO, "Suspend: {0} with: {1} bytes; counter: {2}",
            new Object[]{this, messageLength, SuspendTestTask.this.counter});
        SuspendTestTask.this.suspended = true;
        SuspendTestTask.this.notify();
      }
    }
  }

}