1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.jclouds.sshj;
20
21 import static com.google.common.base.Preconditions.checkArgument;
22 import static com.google.common.base.Preconditions.checkNotNull;
23 import static com.google.common.base.Preconditions.checkState;
24 import static com.google.common.base.Predicates.instanceOf;
25 import static com.google.common.base.Predicates.or;
26 import static com.google.common.base.Throwables.getCausalChain;
27 import static com.google.common.collect.Iterables.any;
28 import static org.jclouds.crypto.SshKeys.fingerprintPrivateKey;
29 import static org.jclouds.crypto.SshKeys.sha1PrivateKey;
30
31 import java.io.IOException;
32 import java.io.InputStream;
33 import java.net.ConnectException;
34 import java.net.SocketTimeoutException;
35 import java.util.concurrent.TimeUnit;
36
37 import javax.annotation.PostConstruct;
38 import javax.annotation.PreDestroy;
39 import javax.annotation.Resource;
40 import javax.inject.Named;
41
42 import net.schmizz.sshj.common.IOUtils;
43 import net.schmizz.sshj.connection.ConnectionException;
44 import net.schmizz.sshj.connection.channel.direct.Session;
45 import net.schmizz.sshj.connection.channel.direct.Session.Command;
46 import net.schmizz.sshj.sftp.SFTPClient;
47 import net.schmizz.sshj.transport.TransportException;
48 import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
49 import net.schmizz.sshj.userauth.UserAuthException;
50 import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
51 import net.schmizz.sshj.xfer.InMemorySourceFile;
52
53 import org.apache.commons.io.input.ProxyInputStream;
54 import org.jclouds.compute.domain.ExecResponse;
55 import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
56 import org.jclouds.io.Payload;
57 import org.jclouds.io.Payloads;
58 import org.jclouds.logging.Logger;
59 import org.jclouds.net.IPSocket;
60 import org.jclouds.rest.AuthorizationException;
61 import org.jclouds.ssh.SshClient;
62 import org.jclouds.ssh.SshException;
63 import org.jclouds.util.Throwables2;
64
65 import com.google.common.annotations.VisibleForTesting;
66 import com.google.common.base.Predicate;
67 import com.google.common.base.Predicates;
68 import com.google.common.base.Splitter;
69 import com.google.common.base.Throwables;
70 import com.google.inject.Inject;
71
72
73
74
75
76
77 @SuppressWarnings("unchecked")
78 public class SshjSshClient implements SshClient {
79
80 private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
81
82 private final SFTPClient sftp;
83
84 private CloseFtpChannelOnCloseInputStream(InputStream proxy, SFTPClient sftp) {
85 super(proxy);
86 this.sftp = sftp;
87 }
88
89 @Override
90 public void close() throws IOException {
91 super.close();
92 if (sftp != null)
93 sftp.close();
94 }
95 }
96
97 private final String host;
98 private final int port;
99 private final String username;
100 private final String password;
101 private final String toString;
102
103 @Inject(optional = true)
104 @Named("jclouds.ssh.max-retries")
105 @VisibleForTesting
106 int sshRetries = 5;
107
108 @Inject(optional = true)
109 @Named("jclouds.ssh.retry-auth")
110 @VisibleForTesting
111 boolean retryAuth;
112
113 @Inject(optional = true)
114 @Named("jclouds.ssh.retryable-messages")
115 @VisibleForTesting
116 String retryableMessages = "";
117
118 @Inject(optional = true)
119 @Named("jclouds.ssh.retry-predicate")
120
121 private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
122 instanceOf(ConnectException.class), instanceOf(SocketTimeoutException.class),
123 instanceOf(TransportException.class));
124
125 @Resource
126 @Named("jclouds.ssh")
127 protected Logger logger = Logger.NULL;
128
129 private net.schmizz.sshj.SSHClient ssh;
130 private final byte[] privateKey;
131 final byte[] emptyPassPhrase = new byte[0];
132 private final int timeoutMillis;
133 private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
134
135 public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
136 String username, String password, byte[] privateKey) {
137 this.host = checkNotNull(socket, "socket").getAddress();
138 checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
139 checkArgument(password != null || privateKey != null, "you must specify a password or a key");
140 this.port = socket.getPort();
141 this.username = checkNotNull(username, "username");
142 this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
143 this.timeoutMillis = timeout;
144 this.password = password;
145 this.privateKey = privateKey;
146 if (privateKey == null) {
147 this.toString = String.format("%s:password@%s:%d", username, host, port);
148 } else {
149 String fingerPrint = fingerprintPrivateKey(new String(privateKey));
150 String sha1 = sha1PrivateKey(new String(privateKey));
151 this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host,
152 port);
153 }
154 }
155
156 @Override
157 public void put(String path, String contents) {
158 put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
159 }
160
161 private void checkConnected() {
162 checkState(ssh != null && ssh.isConnected(), String.format("(%s) ssh not connected!", toString()));
163 }
164
165 public static interface Connection<T> {
166 void clear() throws Exception;
167
168 T create() throws Exception;
169 }
170
171 Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
172
173 @Override
174 public void clear() {
175 if (ssh != null && ssh.isConnected()) {
176 try {
177 ssh.disconnect();
178 } catch (IOException e) {
179 Throwables.propagate(e);
180 }
181 ssh = null;
182 }
183 }
184
185 @Override
186 public net.schmizz.sshj.SSHClient create() throws Exception {
187 net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
188 ssh.addHostKeyVerifier(new PromiscuousVerifier());
189 if (timeoutMillis != 0) {
190 ssh.setTimeout(timeoutMillis);
191 ssh.setConnectTimeout(timeoutMillis);
192 }
193 ssh.connect(host, port);
194 if (password != null) {
195 ssh.authPassword(username, password);
196 } else {
197 OpenSSHKeyFile key = new OpenSSHKeyFile();
198 key.init(new String(privateKey), null);
199 ssh.authPublickey(username, key);
200 }
201 return ssh;
202 }
203
204 @Override
205 public String toString() {
206 return String.format("SSHClient(timeout=%d)", timeoutMillis);
207 }
208 };
209
210 private void backoffForAttempt(int retryAttempt, String message) {
211 backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
212 }
213
214 protected <T, C extends Connection<T>> T acquire(C connection) {
215 String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
216 for (int i = 0; i < sshRetries; i++) {
217 try {
218 connection.clear();
219 logger.debug(">> (%s) acquiring %s", toString(), connection);
220 T returnVal = connection.create();
221 logger.debug("<< (%s) acquired %s", toString(), returnVal);
222 return returnVal;
223 } catch (Exception from) {
224 try {
225 connection.clear();
226 } catch (Exception e1) {
227 logger.warn(from, "<< (%s) error closing connection", toString());
228 }
229 if (i + 1 == sshRetries) {
230 logger.error(from, "<< " + errorMessage + ": out of retries %d", sshRetries);
231 throw propagate(from, errorMessage);
232 } else if (Throwables2.getFirstThrowableOfType(from, IllegalStateException.class) != null) {
233 logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
234 disconnect();
235 backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
236 connect();
237 continue;
238 } else if (shouldRetry(from)) {
239 logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
240 backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
241 continue;
242 } else {
243 logger.error(from, "<< " + errorMessage + ": exception not retryable");
244 throw propagate(from, errorMessage);
245 }
246 }
247 }
248 assert false : "should not reach here";
249 return null;
250 }
251
252 @PostConstruct
253 public void connect() {
254 try {
255 ssh = acquire(sshConnection);
256 } catch (Exception e) {
257 Throwables.propagate(e);
258 }
259 }
260
261 Connection<SFTPClient> sftpConnection = new Connection<SFTPClient>() {
262
263 private SFTPClient sftp;
264
265 @Override
266 public void clear() {
267 if (sftp != null)
268 try {
269 sftp.close();
270 } catch (IOException e) {
271 Throwables.propagate(e);
272 }
273 }
274
275 @Override
276 public SFTPClient create() throws IOException {
277 checkConnected();
278 sftp = ssh.newSFTPClient();
279 return sftp;
280 }
281
282 @Override
283 public String toString() {
284 return "SFTPClient()";
285 }
286 };
287
288 class GetConnection implements Connection<Payload> {
289 private final String path;
290 private SFTPClient sftp;
291
292 GetConnection(String path) {
293 this.path = checkNotNull(path, "path");
294 }
295
296 @Override
297 public void clear() throws IOException {
298 if (sftp != null)
299 sftp.close();
300 }
301
302 @Override
303 public Payload create() throws Exception {
304 sftp = acquire(sftpConnection);
305 return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
306 .getInputStream(), sftp));
307 }
308
309 @Override
310 public String toString() {
311 return "Payload(path=[" + path + "])";
312 }
313 };
314
315 public Payload get(String path) {
316 return acquire(new GetConnection(path));
317 }
318
319 class PutConnection implements Connection<Void> {
320 private final String path;
321 private final Payload contents;
322 private SFTPClient sftp;
323
324 PutConnection(String path, Payload contents) {
325 this.path = checkNotNull(path, "path");
326 this.contents = checkNotNull(contents, "contents");
327 }
328
329 @Override
330 public void clear() {
331 if (sftp != null)
332 try {
333 sftp.close();
334 } catch (IOException e) {
335 Throwables.propagate(e);
336 }
337 }
338
339 @Override
340 public Void create() throws Exception {
341 sftp = acquire(sftpConnection);
342 try {
343 sftp.put(new InMemorySourceFile() {
344
345 @Override
346 public String getName() {
347 return path;
348 }
349
350 @Override
351 public long getLength() {
352 return contents.getContentMetadata().getContentLength();
353 }
354
355 @Override
356 public InputStream getInputStream() throws IOException {
357 return checkNotNull(contents.getInput(), "inputstream for path %s", path);
358 }
359
360 }, path);
361 } finally {
362 contents.release();
363 }
364 return null;
365 }
366
367 @Override
368 public String toString() {
369 return "Put(path=[" + path + "])";
370 }
371 };
372
373 @Override
374 public void put(String path, Payload contents) {
375 acquire(new PutConnection(path, contents));
376 }
377
378 @VisibleForTesting
379 boolean shouldRetry(Exception from) {
380 Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable> or(retryPredicate,
381 instanceOf(AuthorizationException.class)) : retryPredicate;
382 if (any(getCausalChain(from), predicate))
383 return true;
384 if (!retryableMessages.equals(""))
385 return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
386 return false;
387 }
388
389 @VisibleForTesting
390 Predicate<String> causalChainHasMessageContaining(final Exception from) {
391 return new Predicate<String>() {
392
393 @Override
394 public boolean apply(final String input) {
395 return any(getCausalChain(from), new Predicate<Throwable>() {
396
397 @Override
398 public boolean apply(Throwable arg0) {
399 return arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1;
400 }
401
402 });
403 }
404
405 };
406 }
407
408 @VisibleForTesting
409 SshException propagate(Exception e, String message) {
410 message += ": " + e.getMessage();
411 logger.error(e, "<< " + message);
412 if (e instanceof UserAuthException)
413 throw new AuthorizationException("(" + toString() + ") " + message, e);
414 throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
415 "(" + toString() + ") " + message, e);
416 }
417
418 @Override
419 public String toString() {
420 return toString;
421 }
422
423 @PreDestroy
424 public void disconnect() {
425 try {
426 sshConnection.clear();
427 } catch (Exception e) {
428 Throwables.propagate(e);
429 }
430 }
431
432 protected Connection<Session> execConnection() {
433
434 return new Connection<Session>() {
435
436 private Session session = null;
437
438 @Override
439 public void clear() throws TransportException, ConnectionException {
440 if (session != null)
441 session.close();
442 }
443
444 @Override
445 public Session create() throws Exception {
446 checkConnected();
447 session = ssh.startSession();
448 session.allocateDefaultPTY();
449 return session;
450 }
451
452 @Override
453 public String toString() {
454 return "Session()";
455 }
456 };
457
458 }
459
460 class ExecConnection implements Connection<ExecResponse> {
461 private final String command;
462 private Session session;
463
464 ExecConnection(String command) {
465 this.command = checkNotNull(command, "command");
466 }
467
468 @Override
469 public void clear() throws TransportException, ConnectionException {
470 if (session != null)
471 session.close();
472 }
473
474 @Override
475 public ExecResponse create() throws Exception {
476 try {
477 session = acquire(execConnection());
478 Command output = session.exec(checkNotNull(command, "command"));
479 String outputString = IOUtils.readFully(output.getInputStream()).toString();
480 output.join(timeoutMillis, TimeUnit.SECONDS);
481 int errorStatus = output.getExitStatus();
482 String errorString = IOUtils.readFully(output.getErrorStream()).toString();
483 return new ExecResponse(outputString, errorString, errorStatus);
484 } finally {
485 clear();
486 }
487 }
488
489 @Override
490 public String toString() {
491 return "ExecResponse(command=[" + command + "])";
492 }
493 }
494
495 public ExecResponse exec(String command) {
496 return acquire(new ExecConnection(command));
497 }
498
499 @Override
500 public String getHostAddress() {
501 return this.host;
502 }
503
504 @Override
505 public String getUsername() {
506 return this.username;
507 }
508
509 }