1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.jclouds.sshj;
20
21 import static com.google.common.base.Preconditions.checkArgument;
22 import static com.google.common.base.Preconditions.checkNotNull;
23 import static com.google.common.base.Preconditions.checkState;
24 import static com.google.common.base.Predicates.instanceOf;
25 import static com.google.common.base.Predicates.or;
26 import static com.google.common.base.Throwables.getCausalChain;
27 import static com.google.common.collect.Iterables.any;
28
29 import java.io.IOException;
30 import java.io.InputStream;
31 import java.util.concurrent.TimeUnit;
32
33 import javax.annotation.PostConstruct;
34 import javax.annotation.PreDestroy;
35 import javax.annotation.Resource;
36 import javax.inject.Named;
37
38 import net.schmizz.sshj.common.IOUtils;
39 import net.schmizz.sshj.connection.ConnectionException;
40 import net.schmizz.sshj.connection.channel.direct.Session;
41 import net.schmizz.sshj.connection.channel.direct.Session.Command;
42 import net.schmizz.sshj.sftp.SFTPClient;
43 import net.schmizz.sshj.transport.TransportException;
44 import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
45 import net.schmizz.sshj.userauth.UserAuthException;
46 import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
47 import net.schmizz.sshj.xfer.InMemorySourceFile;
48
49 import org.apache.commons.io.input.ProxyInputStream;
50 import org.jclouds.compute.domain.ExecResponse;
51 import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
52 import org.jclouds.io.Payload;
53 import org.jclouds.io.Payloads;
54 import org.jclouds.logging.Logger;
55 import org.jclouds.net.IPSocket;
56 import org.jclouds.rest.AuthorizationException;
57 import org.jclouds.ssh.SshClient;
58 import org.jclouds.ssh.SshException;
59
60 import com.google.common.annotations.VisibleForTesting;
61 import com.google.common.base.Predicate;
62 import com.google.common.base.Predicates;
63 import com.google.common.base.Splitter;
64 import com.google.common.base.Throwables;
65 import com.google.inject.Inject;
66
67
68
69
70
71
72 public class SshjSshClient implements SshClient {
73
74 private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
75
76 private final SFTPClient sftp;
77
78 private CloseFtpChannelOnCloseInputStream(InputStream proxy, SFTPClient sftp) {
79 super(proxy);
80 this.sftp = sftp;
81 }
82
83 @Override
84 public void close() throws IOException {
85 super.close();
86 if (sftp != null)
87 sftp.close();
88 }
89 }
90
91 private final String host;
92 private final int port;
93 private final String username;
94 private final String password;
95
96 @Inject(optional = true)
97 @Named("jclouds.ssh.max-retries")
98 @VisibleForTesting
99 int sshRetries = 5;
100
101 @Inject(optional = true)
102 @Named("jclouds.ssh.retry-auth")
103 @VisibleForTesting
104 boolean retryAuth;
105
106 @Inject(optional = true)
107 @Named("jclouds.ssh.retryable-messages")
108 @VisibleForTesting
109 String retryableMessages = "";
110
111 @Inject(optional = true)
112 @Named("jclouds.ssh.retry-predicate")
113
114 private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
115 instanceOf(TransportException.class));
116
117 @Resource
118 @Named("jclouds.ssh")
119 protected Logger logger = Logger.NULL;
120
121 private net.schmizz.sshj.SSHClient ssh;
122 private final byte[] privateKey;
123 final byte[] emptyPassPhrase = new byte[0];
124 private final int timeoutMillis;
125 private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
126
127 public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
128 String username, String password, byte[] privateKey) {
129 this.host = checkNotNull(socket, "socket").getAddress();
130 checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
131 checkArgument(password != null || privateKey != null, "you must specify a password or a key");
132 this.port = socket.getPort();
133 this.username = checkNotNull(username, "username");
134 this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
135 this.timeoutMillis = timeout;
136 this.password = password;
137 this.privateKey = privateKey;
138 }
139
140 @Override
141 public void put(String path, String contents) {
142 put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
143 }
144
145 private void checkConnected() {
146 checkState(ssh != null && ssh.isConnected(), String.format("(%s) ssh not connected!", toString()));
147 }
148
149 public static interface Connection<T> {
150 void clear() throws Exception;
151
152 T create() throws Exception;
153 }
154
155 Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
156
157 @Override
158 public void clear() {
159 if (ssh != null && ssh.isConnected()) {
160 try {
161 ssh.disconnect();
162 } catch (IOException e) {
163 Throwables.propagate(e);
164 }
165 ssh = null;
166 }
167 }
168
169 @Override
170 public net.schmizz.sshj.SSHClient create() throws Exception {
171 net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
172 ssh.addHostKeyVerifier(new PromiscuousVerifier());
173 if (timeoutMillis != 0) {
174 ssh.setTimeout(timeoutMillis);
175 ssh.setConnectTimeout(timeoutMillis);
176 }
177 ssh.connect(host, port);
178 if (password != null) {
179 ssh.authPassword(username, password);
180 } else {
181 OpenSSHKeyFile key = new OpenSSHKeyFile();
182 key.init(new String(privateKey), null);
183 ssh.authPublickey(username, key);
184 }
185 return ssh;
186 }
187
188 @Override
189 public String toString() {
190 return String.format("SSHClient(%s)", SshjSshClient.this.toString());
191 }
192 };
193
194 private void backoffForAttempt(int retryAttempt, String message) {
195 backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
196 }
197
198 protected <T, C extends Connection<T>> T acquire(C connection) {
199 String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
200 for (int i = 0; i < sshRetries; i++) {
201 try {
202 connection.clear();
203 logger.debug(">> (%s) acquiring %s", toString(), connection);
204 T returnVal = connection.create();
205 logger.debug("<< (%s) acquired %s", toString(), returnVal);
206 return returnVal;
207 } catch (Exception from) {
208 try {
209 connection.clear();
210 } catch (Exception e1) {
211 logger.warn(from, "<< (%s) error closing connection", toString());
212 }
213 if (i + 1 == sshRetries) {
214 throw propagate(from, errorMessage);
215 } else if (shouldRetry(from)) {
216 logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
217 backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
218 continue;
219 }
220 }
221 }
222 assert false : "should not reach here";
223 return null;
224 }
225
226 @PostConstruct
227 public void connect() {
228 try {
229 ssh = acquire(sshConnection);
230 } catch (Exception e) {
231 Throwables.propagate(e);
232 }
233 }
234
235 Connection<SFTPClient> sftpConnection = new Connection<SFTPClient>() {
236
237 private SFTPClient sftp;
238
239 @Override
240 public void clear() {
241 if (sftp != null)
242 try {
243 sftp.close();
244 } catch (IOException e) {
245 Throwables.propagate(e);
246 }
247 }
248
249 @Override
250 public SFTPClient create() throws IOException {
251 checkConnected();
252 sftp = ssh.newSFTPClient();
253 return sftp;
254 }
255
256 @Override
257 public String toString() {
258 return "SFTPClient(" + SshjSshClient.this.toString() + ")";
259 }
260 };
261
262 class GetConnection implements Connection<Payload> {
263 private final String path;
264 private SFTPClient sftp;
265
266 GetConnection(String path) {
267 this.path = checkNotNull(path, "path");
268 }
269
270 @Override
271 public void clear() throws IOException {
272 if (sftp != null)
273 sftp.close();
274 }
275
276 @Override
277 public Payload create() throws Exception {
278 sftp = acquire(sftpConnection);
279 return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
280 .getInputStream(), sftp));
281 }
282
283 @Override
284 public String toString() {
285 return "Payload(" + SshjSshClient.this.toString() + ")[" + path + "]";
286 }
287 };
288
289 public Payload get(String path) {
290 return acquire(new GetConnection(path));
291 }
292
293 class PutConnection implements Connection<Void> {
294 private final String path;
295 private final Payload contents;
296 private SFTPClient sftp;
297
298 PutConnection(String path, Payload contents) {
299 this.path = checkNotNull(path, "path");
300 this.contents = checkNotNull(contents, "contents");
301 }
302
303 @Override
304 public void clear() {
305 if (sftp != null)
306 try {
307 sftp.close();
308 } catch (IOException e) {
309 Throwables.propagate(e);
310 }
311 }
312
313 @Override
314 public Void create() throws Exception {
315 sftp = acquire(sftpConnection);
316 try {
317 sftp.put(new InMemorySourceFile() {
318
319 @Override
320 public String getName() {
321 return path;
322 }
323
324 @Override
325 public long getLength() {
326 return contents.getContentMetadata().getContentLength();
327 }
328
329 @Override
330 public InputStream getInputStream() throws IOException {
331 return checkNotNull(contents.getInput(), "inputstream for path %s", path);
332 }
333
334 }, path);
335 } finally {
336 contents.release();
337 }
338 return null;
339 }
340
341 @Override
342 public String toString() {
343 return "Put(" + SshjSshClient.this.toString() + ")[" + path + "]";
344 }
345 };
346
347 @Override
348 public void put(String path, Payload contents) {
349 acquire(new PutConnection(path, contents));
350 }
351
352 @VisibleForTesting
353 boolean shouldRetry(Exception from) {
354 Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable>or(retryPredicate, instanceOf(AuthorizationException.class))
355 : retryPredicate;
356 if (any(getCausalChain(from), predicate))
357 return true;
358 if (!retryableMessages.equals(""))
359 return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
360 return false;
361 }
362
363 @VisibleForTesting
364 Predicate<String> causalChainHasMessageContaining(final Exception from) {
365 return new Predicate<String>() {
366
367 @Override
368 public boolean apply(final String input) {
369 return any(getCausalChain(from), new Predicate<Throwable>() {
370
371 @Override
372 public boolean apply(Throwable arg0) {
373 return arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1;
374 }
375
376 });
377 }
378
379 };
380 }
381
382 @VisibleForTesting
383 SshException propagate(Exception e, String message) {
384 message += ": " + e.getMessage();
385 logger.error(e, "<< " + message);
386 if (e instanceof UserAuthException)
387 throw new AuthorizationException("(" + toString() + ") " + message, e);
388 throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
389 "(" + toString() + ") " + message, e);
390 }
391
392 @Override
393 public String toString() {
394 return String.format("%s@%s:%d", username, host, port);
395 }
396
397 @PreDestroy
398 public void disconnect() {
399 try {
400 sshConnection.clear();
401 } catch (Exception e) {
402 Throwables.propagate(e);
403 }
404 }
405
406 protected Connection<Session> execConnection() {
407
408 return new Connection<Session>() {
409
410 private Session session = null;
411
412 @Override
413 public void clear() throws TransportException, ConnectionException {
414 if (session != null)
415 session.close();
416 }
417
418 @Override
419 public Session create() throws Exception {
420 checkConnected();
421 session = ssh.startSession();
422 session.allocateDefaultPTY();
423 return session;
424 }
425
426 @Override
427 public String toString() {
428 return "Session(" + SshjSshClient.this.toString() + ")";
429 }
430 };
431
432 }
433
434 class ExecConnection implements Connection<ExecResponse> {
435 private final String command;
436 private Session session;
437
438 ExecConnection(String command) {
439 this.command = checkNotNull(command, "command");
440 }
441
442 @Override
443 public void clear() throws TransportException, ConnectionException {
444 if (session != null)
445 session.close();
446 }
447
448 @Override
449 public ExecResponse create() throws Exception {
450 try {
451 session = acquire(execConnection());
452 Command output = session.exec(checkNotNull(command, "command"));
453 String outputString = IOUtils.readFully(output.getInputStream()).toString();
454 output.join(timeoutMillis, TimeUnit.SECONDS);
455 int errorStatus = output.getExitStatus();
456 String errorString = IOUtils.readFully(output.getErrorStream()).toString();
457 return new ExecResponse(outputString, errorString, errorStatus);
458 } finally {
459 clear();
460 }
461 }
462
463 @Override
464 public String toString() {
465 return "ExecResponse(" + SshjSshClient.this.toString() + ")[" + command + "]";
466 }
467 }
468
469 public ExecResponse exec(String command) {
470 return acquire(new ExecConnection(command));
471 }
472
473 @Override
474 public String getHostAddress() {
475 return this.host;
476 }
477
478 @Override
479 public String getUsername() {
480 return this.username;
481 }
482
483 }