View Javadoc

1   /**
2    * Licensed to jclouds, Inc. (jclouds) under one or more
3    * contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  jclouds licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.jclouds.sshj;
20  
21  import static com.google.common.base.Preconditions.checkArgument;
22  import static com.google.common.base.Preconditions.checkNotNull;
23  import static com.google.common.base.Preconditions.checkState;
24  import static com.google.common.base.Predicates.instanceOf;
25  import static com.google.common.base.Predicates.or;
26  import static com.google.common.base.Throwables.getCausalChain;
27  import static com.google.common.collect.Iterables.any;
28  import static org.jclouds.crypto.SshKeys.fingerprintPrivateKey;
29  import static org.jclouds.crypto.SshKeys.sha1PrivateKey;
30  
31  import java.io.IOException;
32  import java.io.InputStream;
33  import java.net.ConnectException;
34  import java.net.SocketTimeoutException;
35  import java.util.concurrent.TimeUnit;
36  
37  import javax.annotation.PostConstruct;
38  import javax.annotation.PreDestroy;
39  import javax.annotation.Resource;
40  import javax.inject.Named;
41  
42  import net.schmizz.sshj.common.IOUtils;
43  import net.schmizz.sshj.connection.ConnectionException;
44  import net.schmizz.sshj.connection.channel.direct.Session;
45  import net.schmizz.sshj.connection.channel.direct.Session.Command;
46  import net.schmizz.sshj.sftp.SFTPClient;
47  import net.schmizz.sshj.transport.TransportException;
48  import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
49  import net.schmizz.sshj.userauth.UserAuthException;
50  import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
51  import net.schmizz.sshj.xfer.InMemorySourceFile;
52  
53  import org.apache.commons.io.input.ProxyInputStream;
54  import org.jclouds.compute.domain.ExecResponse;
55  import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
56  import org.jclouds.io.Payload;
57  import org.jclouds.io.Payloads;
58  import org.jclouds.logging.Logger;
59  import org.jclouds.net.IPSocket;
60  import org.jclouds.rest.AuthorizationException;
61  import org.jclouds.ssh.SshClient;
62  import org.jclouds.ssh.SshException;
63  import org.jclouds.util.Throwables2;
64  
65  import com.google.common.annotations.VisibleForTesting;
66  import com.google.common.base.Predicate;
67  import com.google.common.base.Predicates;
68  import com.google.common.base.Splitter;
69  import com.google.common.base.Throwables;
70  import com.google.inject.Inject;
71  
72  /**
73   * This class needs refactoring. It is not thread safe.
74   * 
75   * @author Adrian Cole
76   */
77  @SuppressWarnings("unchecked")
78  public class SshjSshClient implements SshClient {
79  
80     private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
81  
82        private final SFTPClient sftp;
83  
84        private CloseFtpChannelOnCloseInputStream(InputStream proxy, SFTPClient sftp) {
85           super(proxy);
86           this.sftp = sftp;
87        }
88  
89        @Override
90        public void close() throws IOException {
91           super.close();
92           if (sftp != null)
93              sftp.close();
94        }
95     }
96  
97     private final String host;
98     private final int port;
99     private final String username;
100    private final String password;
101    private final String toString;
102 
103    @Inject(optional = true)
104    @Named("jclouds.ssh.max-retries")
105    @VisibleForTesting
106    int sshRetries = 5;
107 
108    @Inject(optional = true)
109    @Named("jclouds.ssh.retry-auth")
110    @VisibleForTesting
111    boolean retryAuth;
112 
113    @Inject(optional = true)
114    @Named("jclouds.ssh.retryable-messages")
115    @VisibleForTesting
116    String retryableMessages = "";
117 
118    @Inject(optional = true)
119    @Named("jclouds.ssh.retry-predicate")
120    // NOTE cannot retry io exceptions, as SSHException is a part of the chain
121    private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
122             instanceOf(ConnectException.class), instanceOf(SocketTimeoutException.class),
123             instanceOf(TransportException.class));
124 
125    @Resource
126    @Named("jclouds.ssh")
127    protected Logger logger = Logger.NULL;
128 
129    private net.schmizz.sshj.SSHClient ssh;
130    private final byte[] privateKey;
131    final byte[] emptyPassPhrase = new byte[0];
132    private final int timeoutMillis;
133    private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
134 
135    public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
136             String username, String password, byte[] privateKey) {
137       this.host = checkNotNull(socket, "socket").getAddress();
138       checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
139       checkArgument(password != null || privateKey != null, "you must specify a password or a key");
140       this.port = socket.getPort();
141       this.username = checkNotNull(username, "username");
142       this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
143       this.timeoutMillis = timeout;
144       this.password = password;
145       this.privateKey = privateKey;
146       if (privateKey == null) {
147          this.toString = String.format("%s:password@%s:%d", username, host, port);
148       } else {
149          String fingerPrint = fingerprintPrivateKey(new String(privateKey));
150          String sha1 = sha1PrivateKey(new String(privateKey));
151          this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host,
152                   port);
153       }
154    }
155 
156    @Override
157    public void put(String path, String contents) {
158       put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
159    }
160 
161    private void checkConnected() {
162       checkState(ssh != null && ssh.isConnected(), String.format("(%s) ssh not connected!", toString()));
163    }
164 
165    public static interface Connection<T> {
166       void clear() throws Exception;
167 
168       T create() throws Exception;
169    }
170 
171    Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
172 
173       @Override
174       public void clear() {
175          if (ssh != null && ssh.isConnected()) {
176             try {
177                ssh.disconnect();
178             } catch (IOException e) {
179                Throwables.propagate(e);
180             }
181             ssh = null;
182          }
183       }
184 
185       @Override
186       public net.schmizz.sshj.SSHClient create() throws Exception {
187          net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
188          ssh.addHostKeyVerifier(new PromiscuousVerifier());
189          if (timeoutMillis != 0) {
190             ssh.setTimeout(timeoutMillis);
191             ssh.setConnectTimeout(timeoutMillis);
192          }
193          ssh.connect(host, port);
194          if (password != null) {
195             ssh.authPassword(username, password);
196          } else {
197             OpenSSHKeyFile key = new OpenSSHKeyFile();
198             key.init(new String(privateKey), null);
199             ssh.authPublickey(username, key);
200          }
201          return ssh;
202       }
203 
204       @Override
205       public String toString() {
206          return String.format("SSHClient(timeout=%d)", timeoutMillis);
207       }
208    };
209 
210    private void backoffForAttempt(int retryAttempt, String message) {
211       backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
212    }
213 
214    protected <T, C extends Connection<T>> T acquire(C connection) {
215       String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
216       for (int i = 0; i < sshRetries; i++) {
217          try {
218             connection.clear();
219             logger.debug(">> (%s) acquiring %s", toString(), connection);
220             T returnVal = connection.create();
221             logger.debug("<< (%s) acquired %s", toString(), returnVal);
222             return returnVal;
223          } catch (Exception from) {
224             try {
225                connection.clear();
226             } catch (Exception e1) {
227                logger.warn(from, "<< (%s) error closing connection", toString());
228             }
229             if (i + 1 == sshRetries) {
230                logger.error(from, "<< " + errorMessage + ": out of retries %d", sshRetries);
231                throw propagate(from, errorMessage);
232             } else if (Throwables2.getFirstThrowableOfType(from, IllegalStateException.class) != null) {
233                logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
234                disconnect();
235                backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
236                connect();
237                continue;
238             } else if (shouldRetry(from)) {
239                logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
240                backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
241                continue;
242             } else {
243                logger.error(from, "<< " + errorMessage + ": exception not retryable");
244                throw propagate(from, errorMessage);
245             }
246          }
247       }
248       assert false : "should not reach here";
249       return null;
250    }
251 
252    @PostConstruct
253    public void connect() {
254       try {
255          ssh = acquire(sshConnection);
256       } catch (Exception e) {
257          Throwables.propagate(e);
258       }
259    }
260 
261    Connection<SFTPClient> sftpConnection = new Connection<SFTPClient>() {
262 
263       private SFTPClient sftp;
264 
265       @Override
266       public void clear() {
267          if (sftp != null)
268             try {
269                sftp.close();
270             } catch (IOException e) {
271                Throwables.propagate(e);
272             }
273       }
274 
275       @Override
276       public SFTPClient create() throws IOException {
277          checkConnected();
278          sftp = ssh.newSFTPClient();
279          return sftp;
280       }
281 
282       @Override
283       public String toString() {
284          return "SFTPClient()";
285       }
286    };
287 
288    class GetConnection implements Connection<Payload> {
289       private final String path;
290       private SFTPClient sftp;
291 
292       GetConnection(String path) {
293          this.path = checkNotNull(path, "path");
294       }
295 
296       @Override
297       public void clear() throws IOException {
298          if (sftp != null)
299             sftp.close();
300       }
301 
302       @Override
303       public Payload create() throws Exception {
304          sftp = acquire(sftpConnection);
305          return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
306                   .getInputStream(), sftp));
307       }
308 
309       @Override
310       public String toString() {
311          return "Payload(path=[" + path + "])";
312       }
313    };
314 
315    public Payload get(String path) {
316       return acquire(new GetConnection(path));
317    }
318 
319    class PutConnection implements Connection<Void> {
320       private final String path;
321       private final Payload contents;
322       private SFTPClient sftp;
323 
324       PutConnection(String path, Payload contents) {
325          this.path = checkNotNull(path, "path");
326          this.contents = checkNotNull(contents, "contents");
327       }
328 
329       @Override
330       public void clear() {
331          if (sftp != null)
332             try {
333                sftp.close();
334             } catch (IOException e) {
335                Throwables.propagate(e);
336             }
337       }
338 
339       @Override
340       public Void create() throws Exception {
341          sftp = acquire(sftpConnection);
342          try {
343             sftp.put(new InMemorySourceFile() {
344 
345                @Override
346                public String getName() {
347                   return path;
348                }
349 
350                @Override
351                public long getLength() {
352                   return contents.getContentMetadata().getContentLength();
353                }
354 
355                @Override
356                public InputStream getInputStream() throws IOException {
357                   return checkNotNull(contents.getInput(), "inputstream for path %s", path);
358                }
359 
360             }, path);
361          } finally {
362             contents.release();
363          }
364          return null;
365       }
366 
367       @Override
368       public String toString() {
369          return "Put(path=[" + path + "])";
370       }
371    };
372 
373    @Override
374    public void put(String path, Payload contents) {
375       acquire(new PutConnection(path, contents));
376    }
377 
378    @VisibleForTesting
379    boolean shouldRetry(Exception from) {
380       Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable> or(retryPredicate,
381                instanceOf(AuthorizationException.class)) : retryPredicate;
382       if (any(getCausalChain(from), predicate))
383          return true;
384       if (!retryableMessages.equals(""))
385          return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
386       return false;
387    }
388 
389    @VisibleForTesting
390    Predicate<String> causalChainHasMessageContaining(final Exception from) {
391       return new Predicate<String>() {
392 
393          @Override
394          public boolean apply(final String input) {
395             return any(getCausalChain(from), new Predicate<Throwable>() {
396 
397                @Override
398                public boolean apply(Throwable arg0) {
399                   return arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1;
400                }
401 
402             });
403          }
404 
405       };
406    }
407 
408    @VisibleForTesting
409    SshException propagate(Exception e, String message) {
410       message += ": " + e.getMessage();
411       logger.error(e, "<< " + message);
412       if (e instanceof UserAuthException)
413          throw new AuthorizationException("(" + toString() + ") " + message, e);
414       throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
415                "(" + toString() + ") " + message, e);
416    }
417 
418    @Override
419    public String toString() {
420       return toString;
421    }
422 
423    @PreDestroy
424    public void disconnect() {
425       try {
426          sshConnection.clear();
427       } catch (Exception e) {
428          Throwables.propagate(e);
429       }
430    }
431 
432    protected Connection<Session> execConnection() {
433 
434       return new Connection<Session>() {
435 
436          private Session session = null;
437 
438          @Override
439          public void clear() throws TransportException, ConnectionException {
440             if (session != null)
441                session.close();
442          }
443 
444          @Override
445          public Session create() throws Exception {
446             checkConnected();
447             session = ssh.startSession();
448             session.allocateDefaultPTY();
449             return session;
450          }
451 
452          @Override
453          public String toString() {
454             return "Session()";
455          }
456       };
457 
458    }
459 
460    class ExecConnection implements Connection<ExecResponse> {
461       private final String command;
462       private Session session;
463 
464       ExecConnection(String command) {
465          this.command = checkNotNull(command, "command");
466       }
467 
468       @Override
469       public void clear() throws TransportException, ConnectionException {
470          if (session != null)
471             session.close();
472       }
473 
474       @Override
475       public ExecResponse create() throws Exception {
476          try {
477             session = acquire(execConnection());
478             Command output = session.exec(checkNotNull(command, "command"));
479             String outputString = IOUtils.readFully(output.getInputStream()).toString();
480             output.join(timeoutMillis, TimeUnit.SECONDS);
481             int errorStatus = output.getExitStatus();
482             String errorString = IOUtils.readFully(output.getErrorStream()).toString();
483             return new ExecResponse(outputString, errorString, errorStatus);
484          } finally {
485             clear();
486          }
487       }
488 
489       @Override
490       public String toString() {
491          return "ExecResponse(command=[" + command + "])";
492       }
493    }
494 
495    public ExecResponse exec(String command) {
496       return acquire(new ExecConnection(command));
497    }
498 
499    @Override
500    public String getHostAddress() {
501       return this.host;
502    }
503 
504    @Override
505    public String getUsername() {
506       return this.username;
507    }
508 
509 }