001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.tests.fail.driver; 020 021import org.apache.reef.driver.context.*; 022import org.apache.reef.driver.evaluator.*; 023import org.apache.reef.driver.task.*; 024import org.apache.reef.tang.annotations.Name; 025import org.apache.reef.tang.annotations.NamedParameter; 026import org.apache.reef.tang.annotations.Parameter; 027import org.apache.reef.tang.annotations.Unit; 028import org.apache.reef.tang.exceptions.BindException; 029import org.apache.reef.tests.library.exceptions.DriverSideFailure; 030import org.apache.reef.tests.library.exceptions.SimulatedDriverFailure; 031import org.apache.reef.wake.EventHandler; 032import org.apache.reef.wake.remote.impl.ObjectSerializableCodec; 033import org.apache.reef.wake.time.Clock; 034import org.apache.reef.wake.time.event.Alarm; 035import org.apache.reef.wake.time.event.StartTime; 036import org.apache.reef.wake.time.event.StopTime; 037 038import javax.inject.Inject; 039import javax.xml.bind.DatatypeConverter; 040import java.util.Arrays; 041import java.util.logging.Level; 042import java.util.logging.Logger; 043 044import static org.apache.reef.tests.fail.driver.FailDriver.ExpectedMessage.RequiredFlag.OPTIONAL; 045import static org.apache.reef.tests.fail.driver.FailDriver.ExpectedMessage.RequiredFlag.REQUIRED; 046 047/** 048 * Driver which fails on receiving certain message class. 049 */ 050@Unit 051public final class FailDriver { 052 053 private static final Logger LOG = Logger.getLogger(FailDriver.class.getName()); 054 private static final ObjectSerializableCodec<String> CODEC = new ObjectSerializableCodec<>(); 055 private static final byte[] HELLO_STR = CODEC.encode("MESSAGE::HELLO"); 056 /** 057 * Send message to the Task MSG_DELAY milliseconds after start. 058 */ 059 private static final int MSG_DELAY = 1000; 060 private static final ExpectedMessage[] EVENT_SEQUENCE = { 061 new ExpectedMessage(FailDriver.class, REQUIRED), 062 new ExpectedMessage(StartTime.class, REQUIRED), 063 new ExpectedMessage(AllocatedEvaluator.class, REQUIRED), 064 new ExpectedMessage(FailedEvaluator.class, OPTIONAL), 065 new ExpectedMessage(ActiveContext.class, REQUIRED), 066 new ExpectedMessage(ContextMessage.class, OPTIONAL), 067 new ExpectedMessage(FailedContext.class, OPTIONAL), 068 new ExpectedMessage(RunningTask.class, REQUIRED), 069 new ExpectedMessage(Alarm.class, REQUIRED), 070 new ExpectedMessage(TaskMessage.class, REQUIRED), 071 new ExpectedMessage(Alarm.class, REQUIRED), 072 new ExpectedMessage(SuspendedTask.class, REQUIRED), 073 new ExpectedMessage(RunningTask.class, REQUIRED), 074 new ExpectedMessage(Alarm.class, REQUIRED), 075 new ExpectedMessage(FailedTask.class, OPTIONAL), 076 new ExpectedMessage(CompletedTask.class, REQUIRED), 077 new ExpectedMessage(ClosedContext.class, OPTIONAL), 078 new ExpectedMessage(CompletedEvaluator.class, REQUIRED), 079 new ExpectedMessage(StopTime.class, REQUIRED) 080 }; 081 private final transient Class<?> failMsgClass; 082 private final transient EvaluatorRequestor requestor; 083 private final transient Clock clock; 084 private transient RunningTask task = null; 085 private transient int expectIdx = 0; 086 private transient DriverState state = DriverState.INIT; 087 088 @Inject 089 public FailDriver(@Parameter(FailMsgClassName.class) final String failMsgClassName, 090 final EvaluatorRequestor requestor, final Clock clock) 091 throws ClassNotFoundException { 092 this.failMsgClass = ClassLoader.getSystemClassLoader().loadClass(failMsgClassName); 093 this.requestor = requestor; 094 this.clock = clock; 095 this.checkMsgOrder(this); 096 } 097 098 /** 099 * Check if observer methods are called in the right order 100 * and generate an exception at the given point in the message sequence. 101 * 102 * @param msg a message from one of the observers. 103 * @throws SimulatedDriverFailure if failMsgClass matches the message class. 104 * @throws DriverSideFailure if messages are out of order. 105 */ 106 private void checkMsgOrder(final Object msg) throws SimulatedDriverFailure, DriverSideFailure { 107 108 final String msgClassName = msg.getClass().getName(); 109 LOG.log(Level.FINE, "At {0} {1}:{2}", new Object[] {this.state, this.expectIdx, msgClassName}); 110 111 if (this.state == DriverState.FAILED) { 112 // If already failed, do not do anything 113 return; 114 } 115 116 // Simulate failure at this step? 117 if (this.failMsgClass.isInstance(msg)) { 118 this.state = DriverState.FAILED; 119 } 120 121 // Make sure events arrive in the right order (specified in EVENT_SEQUENCE): 122 boolean notFound = true; 123 for (; this.expectIdx < EVENT_SEQUENCE.length; ++this.expectIdx) { 124 if (EVENT_SEQUENCE[expectIdx].msgClass.isInstance(msg)) { 125 notFound = false; 126 break; 127 } else if (EVENT_SEQUENCE[expectIdx].requiredFlag == REQUIRED) { 128 break; 129 } 130 } 131 132 if (notFound) { 133 LOG.log(Level.SEVERE, "Event out of sequence: {0} {1}:{2}", 134 new Object[] {this.state, this.expectIdx, msgClassName}); 135 throw new DriverSideFailure("Event out of sequence: " + msgClassName); 136 } 137 138 LOG.log(Level.INFO, "{0}: send: {1} got: {2}", new Object[] { 139 this.state, EVENT_SEQUENCE[this.expectIdx], msgClassName}); 140 141 ++this.expectIdx; 142 143 if (this.state == DriverState.FAILED) { 144 final SimulatedDriverFailure ex = new SimulatedDriverFailure( 145 "Simulated Failure at FailDriver :: " + msgClassName); 146 LOG.log(Level.INFO, "Simulated Failure:", ex); 147 throw ex; 148 } 149 } 150 151 private enum DriverState {INIT, SEND_MSG, SUSPEND, RESUME, CLOSE, FAILED} 152 153 /** 154 * Name of the message class to specify the failing message handler. 155 */ 156 @NamedParameter(doc = "Full name of the message class to fail on", short_name = "fail") 157 public static final class FailMsgClassName implements Name<String> { 158 } 159 160 /** 161 * Expected message class. 162 */ 163 static final class ExpectedMessage { 164 165 private final transient Class<?> msgClass; 166 private final transient RequiredFlag requiredFlag; 167 private final transient String repr; 168 169 private ExpectedMessage(final Class<?> clazz, final RequiredFlag requiredFlag) { 170 this.msgClass = clazz; 171 this.requiredFlag = requiredFlag; 172 this.repr = this.msgClass.getSimpleName() + ":" + this.requiredFlag; 173 } 174 175 @Override 176 public String toString() { 177 return this.repr; 178 } 179 180 /** 181 * "Required" flag for message class. 182 */ 183 enum RequiredFlag {OPTIONAL, REQUIRED} 184 } 185 186 final class AllocatedEvaluatorHandler implements EventHandler<AllocatedEvaluator> { 187 @Override 188 public void onNext(final AllocatedEvaluator eval) { 189 checkMsgOrder(eval); 190 try { 191 eval.submitContext(ContextConfiguration.CONF 192 .set(ContextConfiguration.IDENTIFIER, "FailContext_" + eval.getId()) 193 .build()); 194 } catch (final BindException ex) { 195 LOG.log(Level.WARNING, "Context configuration error", ex); 196 throw new RuntimeException(ex); 197 } 198 } 199 } 200 201 final class CompletedEvaluatorHandler implements EventHandler<CompletedEvaluator> { 202 @Override 203 public void onNext(final CompletedEvaluator eval) { 204 checkMsgOrder(eval); 205 // noop 206 } 207 } 208 209 final class FailedEvaluatorHandler implements EventHandler<FailedEvaluator> { 210 @Override 211 public void onNext(final FailedEvaluator eval) { 212 LOG.log(Level.WARNING, "Evaluator failed: " + eval.getId(), eval.getEvaluatorException()); 213 checkMsgOrder(eval); 214 throw new RuntimeException(eval.getEvaluatorException()); 215 } 216 } 217 218 final class ActiveContextHandler implements EventHandler<ActiveContext> { 219 @Override 220 public void onNext(final ActiveContext context) { 221 checkMsgOrder(context); 222 try { 223 context.submitTask(TaskConfiguration.CONF 224 .set(TaskConfiguration.IDENTIFIER, "FailTask_" + context.getId()) 225 .set(TaskConfiguration.TASK, NoopTask.class) 226 .set(TaskConfiguration.ON_MESSAGE, NoopTask.DriverMessageHandler.class) 227 .set(TaskConfiguration.ON_SUSPEND, NoopTask.TaskSuspendHandler.class) 228 .set(TaskConfiguration.ON_CLOSE, NoopTask.TaskCloseHandler.class) 229 .set(TaskConfiguration.ON_TASK_STOP, NoopTask.TaskStopHandler.class) 230 .set(TaskConfiguration.ON_SEND_MESSAGE, NoopTask.class) 231 .build()); 232 } catch (final BindException ex) { 233 LOG.log(Level.WARNING, "Task configuration error", ex); 234 throw new RuntimeException(ex); 235 } 236 } 237 } 238 239 final class ContextMessageHandler implements EventHandler<ContextMessage> { 240 @Override 241 public void onNext(final ContextMessage message) { 242 checkMsgOrder(message); 243 // noop 244 } 245 } 246 247 final class ClosedContextHandler implements EventHandler<ClosedContext> { 248 @Override 249 public void onNext(final ClosedContext context) { 250 checkMsgOrder(context); 251 // noop 252 } 253 } 254 255 final class FailedContextHandler implements EventHandler<FailedContext> { 256 @Override 257 public void onNext(final FailedContext context) { 258 LOG.log(Level.WARNING, "Context failed: " + context.getId(), context.getReason().orElse(null)); 259 checkMsgOrder(context); 260 261 // if (context.getParentContext().isPresent()) { 262 // context.getParentContext().get().close(); 263 // } 264 } 265 } 266 267 final class RunningTaskHandler implements EventHandler<RunningTask> { 268 @Override 269 public void onNext(final RunningTask runningTask) { 270 checkMsgOrder(runningTask); 271 FailDriver.this.task = runningTask; 272 switch (state) { 273 case INIT: 274 state = DriverState.SEND_MSG; 275 break; 276 case RESUME: 277 state = DriverState.CLOSE; 278 break; 279 default: 280 LOG.log(Level.WARNING, "Unexpected state at TaskRuntime: {0}", state); 281 throw new DriverSideFailure("Unexpected state: " + state); 282 } 283 // After a delay, send message or suspend the task: 284 clock.scheduleAlarm(MSG_DELAY, new AlarmHandler()); 285 } 286 } 287 288 final class SuspendedTaskHandler implements EventHandler<SuspendedTask> { 289 @Override 290 public void onNext(final SuspendedTask suspendedTask) { 291 checkMsgOrder(suspendedTask); 292 state = DriverState.RESUME; 293 try { 294 suspendedTask.getActiveContext().submitTask(TaskConfiguration.CONF 295 .set(TaskConfiguration.IDENTIFIER, suspendedTask.getId() + "_RESUMED") 296 .set(TaskConfiguration.TASK, NoopTask.class) 297 .set(TaskConfiguration.ON_MESSAGE, NoopTask.DriverMessageHandler.class) 298 .set(TaskConfiguration.ON_SUSPEND, NoopTask.TaskSuspendHandler.class) 299 .set(TaskConfiguration.ON_CLOSE, NoopTask.TaskCloseHandler.class) 300 .set(TaskConfiguration.ON_TASK_STOP, NoopTask.TaskStopHandler.class) 301 .set(TaskConfiguration.ON_SEND_MESSAGE, NoopTask.class) 302 .set(TaskConfiguration.MEMENTO, DatatypeConverter.printBase64Binary(HELLO_STR)) 303 .build()); 304 } catch (final BindException ex) { 305 LOG.log(Level.SEVERE, "Task configuration error", ex); 306 throw new DriverSideFailure("Task configuration error", ex); 307 } 308 } 309 } 310 311 final class TaskMessageHandler implements EventHandler<TaskMessage> { 312 @Override 313 public void onNext(final TaskMessage msg) { 314 checkMsgOrder(msg); 315 assert Arrays.equals(HELLO_STR, msg.get()); 316 assert state == DriverState.SEND_MSG; 317 state = DriverState.SUSPEND; 318 clock.scheduleAlarm(MSG_DELAY, new AlarmHandler()); 319 } 320 } 321 322 final class FailedTaskHandler implements EventHandler<FailedTask> { 323 @Override 324 public void onNext(final FailedTask failedTask) { 325 LOG.log(Level.WARNING, "Task failed: " + failedTask.getId(), failedTask.getReason().orElse(null)); 326 checkMsgOrder(failedTask); 327 if (failedTask.getActiveContext().isPresent()) { 328 failedTask.getActiveContext().get().close(); 329 } 330 } 331 } 332 333 final class CompletedTaskHandler implements EventHandler<CompletedTask> { 334 @Override 335 public void onNext(final CompletedTask completedTask) { 336 checkMsgOrder(completedTask); 337 completedTask.getActiveContext().close(); 338 } 339 } 340 341 final class StartHandler implements EventHandler<StartTime> { 342 @Override 343 public void onNext(final StartTime time) { 344 FailDriver.this.checkMsgOrder(time); 345 FailDriver.this.requestor.submit(EvaluatorRequest.newBuilder() 346 .setNumber(1).setMemory(128).setNumberOfCores(1).build()); 347 } 348 } 349 350 final class AlarmHandler implements EventHandler<Alarm> { 351 @Override 352 public void onNext(final Alarm time) { 353 FailDriver.this.checkMsgOrder(time); 354 switch (FailDriver.this.state) { 355 case SEND_MSG: 356 FailDriver.this.task.send(HELLO_STR); 357 break; 358 case SUSPEND: 359 FailDriver.this.task.suspend(); 360 break; 361 case CLOSE: 362 FailDriver.this.task.close(); 363 break; 364 default: 365 LOG.log(Level.WARNING, "Unexpected state at AlarmHandler: {0}", FailDriver.this.state); 366 throw new DriverSideFailure("Unexpected state: " + FailDriver.this.state); 367 } 368 } 369 } 370 371 final class StopHandler implements EventHandler<StopTime> { 372 @Override 373 public void onNext(final StopTime time) { 374 FailDriver.this.checkMsgOrder(time); 375 // noop 376 } 377 } 378}