View Javadoc

1   /**
2    *
3    * Copyright (C) 2011 Cloud Conscious, LLC. <info@cloudconscious.com>
4    *
5    * ====================================================================
6    * Licensed under the Apache License, Version 2.0 (the "License");
7    * you may not use this file except in compliance with the License.
8    * You may obtain a copy of the License at
9    *
10   * http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   * ====================================================================
18   */
19  package org.jclouds.sshj;
20  
21  import static com.google.common.base.Preconditions.checkArgument;
22  import static com.google.common.base.Preconditions.checkNotNull;
23  import static com.google.common.base.Preconditions.checkState;
24  import static com.google.common.base.Predicates.instanceOf;
25  import static com.google.common.base.Predicates.or;
26  import static com.google.common.base.Throwables.getCausalChain;
27  import static com.google.common.collect.Iterables.any;
28  
29  import java.io.IOException;
30  import java.io.InputStream;
31  import java.util.concurrent.TimeUnit;
32  
33  import javax.annotation.PostConstruct;
34  import javax.annotation.PreDestroy;
35  import javax.annotation.Resource;
36  import javax.inject.Named;
37  
38  import net.schmizz.sshj.common.IOUtils;
39  import net.schmizz.sshj.connection.ConnectionException;
40  import net.schmizz.sshj.connection.channel.direct.Session;
41  import net.schmizz.sshj.connection.channel.direct.Session.Command;
42  import net.schmizz.sshj.sftp.SFTPClient;
43  import net.schmizz.sshj.transport.TransportException;
44  import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
45  import net.schmizz.sshj.userauth.UserAuthException;
46  import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
47  import net.schmizz.sshj.xfer.InMemorySourceFile;
48  
49  import org.apache.commons.io.input.ProxyInputStream;
50  import org.jclouds.compute.domain.ExecResponse;
51  import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
52  import org.jclouds.io.Payload;
53  import org.jclouds.io.Payloads;
54  import org.jclouds.logging.Logger;
55  import org.jclouds.net.IPSocket;
56  import org.jclouds.rest.AuthorizationException;
57  import org.jclouds.ssh.SshClient;
58  import org.jclouds.ssh.SshException;
59  
60  import com.google.common.annotations.VisibleForTesting;
61  import com.google.common.base.Predicate;
62  import com.google.common.base.Predicates;
63  import com.google.common.base.Splitter;
64  import com.google.common.base.Throwables;
65  import com.google.inject.Inject;
66  
67  /**
68   * This class needs refactoring. It is not thread safe.
69   * 
70   * @author Adrian Cole
71   */
72  public class SshjSshClient implements SshClient {
73  
74     private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
75  
76        private final SFTPClient sftp;
77  
78        private CloseFtpChannelOnCloseInputStream(InputStream proxy, SFTPClient sftp) {
79           super(proxy);
80           this.sftp = sftp;
81        }
82  
83        @Override
84        public void close() throws IOException {
85           super.close();
86           if (sftp != null)
87              sftp.close();
88        }
89     }
90  
91     private final String host;
92     private final int port;
93     private final String username;
94     private final String password;
95  
96     @Inject(optional = true)
97     @Named("jclouds.ssh.max-retries")
98     @VisibleForTesting
99     int sshRetries = 5;
100    
101    @Inject(optional = true)
102    @Named("jclouds.ssh.retry-auth")
103    @VisibleForTesting
104    boolean retryAuth;
105    
106    @Inject(optional = true)
107    @Named("jclouds.ssh.retryable-messages")
108    @VisibleForTesting
109    String retryableMessages = "";
110 
111    @Inject(optional = true)
112    @Named("jclouds.ssh.retry-predicate")
113    // NOTE cannot retry io exceptions, as SSHException is a part of the chain
114    private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
115          instanceOf(TransportException.class));
116 
117    @Resource
118    @Named("jclouds.ssh")
119    protected Logger logger = Logger.NULL;
120 
121    private net.schmizz.sshj.SSHClient ssh;
122    private final byte[] privateKey;
123    final byte[] emptyPassPhrase = new byte[0];
124    private final int timeoutMillis;
125    private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
126 
127    public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
128          String username, String password, byte[] privateKey) {
129       this.host = checkNotNull(socket, "socket").getAddress();
130       checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
131       checkArgument(password != null || privateKey != null, "you must specify a password or a key");
132       this.port = socket.getPort();
133       this.username = checkNotNull(username, "username");
134       this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
135       this.timeoutMillis = timeout;
136       this.password = password;
137       this.privateKey = privateKey;
138    }
139 
140    @Override
141    public void put(String path, String contents) {
142       put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
143    }
144 
145    private void checkConnected() {
146       checkState(ssh != null && ssh.isConnected(), String.format("(%s) ssh not connected!", toString()));
147    }
148 
149    public static interface Connection<T> {
150       void clear() throws Exception;
151 
152       T create() throws Exception;
153    }
154 
155    Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
156 
157       @Override
158       public void clear() {
159          if (ssh != null && ssh.isConnected()) {
160             try {
161                ssh.disconnect();
162             } catch (IOException e) {
163                Throwables.propagate(e);
164             }
165             ssh = null;
166          }
167       }
168 
169       @Override
170       public net.schmizz.sshj.SSHClient create() throws Exception {
171          net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
172          ssh.addHostKeyVerifier(new PromiscuousVerifier());
173          if (timeoutMillis != 0) {
174             ssh.setTimeout(timeoutMillis);
175             ssh.setConnectTimeout(timeoutMillis);
176          }
177          ssh.connect(host, port);
178          if (password != null) {
179             ssh.authPassword(username, password);
180          } else {
181             OpenSSHKeyFile key = new OpenSSHKeyFile();
182             key.init(new String(privateKey), null);
183             ssh.authPublickey(username, key);
184          }
185          return ssh;
186       }
187 
188       @Override
189       public String toString() {
190          return String.format("SSHClient(%s)", SshjSshClient.this.toString());
191       }
192    };
193 
194    private void backoffForAttempt(int retryAttempt, String message) {
195       backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
196    }
197 
198    protected <T, C extends Connection<T>> T acquire(C connection) {
199       String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
200       for (int i = 0; i < sshRetries; i++) {
201          try {
202             connection.clear();
203             logger.debug(">> (%s) acquiring %s", toString(), connection);
204             T returnVal = connection.create();
205             logger.debug("<< (%s) acquired %s", toString(), returnVal);
206             return returnVal;
207          } catch (Exception from) {
208             try {
209                connection.clear();
210             } catch (Exception e1) {
211                logger.warn(from, "<< (%s) error closing connection", toString());
212             }
213             if (i + 1 == sshRetries) {
214                throw propagate(from, errorMessage);
215             } else if (shouldRetry(from)) {
216                logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
217                backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
218                continue;
219             }
220          }
221       }
222       assert false : "should not reach here";
223       return null;
224    }
225 
226    @PostConstruct
227    public void connect() {
228       try {
229          ssh = acquire(sshConnection);
230       } catch (Exception e) {
231          Throwables.propagate(e);
232       }
233    }
234 
235    Connection<SFTPClient> sftpConnection = new Connection<SFTPClient>() {
236 
237       private SFTPClient sftp;
238 
239       @Override
240       public void clear() {
241          if (sftp != null)
242             try {
243                sftp.close();
244             } catch (IOException e) {
245                Throwables.propagate(e);
246             }
247       }
248 
249       @Override
250       public SFTPClient create() throws IOException {
251          checkConnected();
252          sftp = ssh.newSFTPClient();
253          return sftp;
254       }
255 
256       @Override
257       public String toString() {
258          return "SFTPClient(" + SshjSshClient.this.toString() + ")";
259       }
260    };
261 
262    class GetConnection implements Connection<Payload> {
263       private final String path;
264       private SFTPClient sftp;
265 
266       GetConnection(String path) {
267          this.path = checkNotNull(path, "path");
268       }
269 
270       @Override
271       public void clear() throws IOException {
272          if (sftp != null)
273             sftp.close();
274       }
275 
276       @Override
277       public Payload create() throws Exception {
278          sftp = acquire(sftpConnection);
279          return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
280                .getInputStream(), sftp));
281       }
282 
283       @Override
284       public String toString() {
285          return "Payload(" + SshjSshClient.this.toString() + ")[" + path + "]";
286       }
287    };
288 
289    public Payload get(String path) {
290       return acquire(new GetConnection(path));
291    }
292 
293    class PutConnection implements Connection<Void> {
294       private final String path;
295       private final Payload contents;
296       private SFTPClient sftp;
297 
298       PutConnection(String path, Payload contents) {
299          this.path = checkNotNull(path, "path");
300          this.contents = checkNotNull(contents, "contents");
301       }
302 
303       @Override
304       public void clear() {
305          if (sftp != null)
306             try {
307                sftp.close();
308             } catch (IOException e) {
309                Throwables.propagate(e);
310             }
311       }
312 
313       @Override
314       public Void create() throws Exception {
315          sftp = acquire(sftpConnection);
316          try {
317             sftp.put(new InMemorySourceFile() {
318 
319                @Override
320                public String getName() {
321                   return path;
322                }
323 
324                @Override
325                public long getLength() {
326                   return contents.getContentMetadata().getContentLength();
327                }
328 
329                @Override
330                public InputStream getInputStream() throws IOException {
331                   return checkNotNull(contents.getInput(), "inputstream for path %s", path);
332                }
333 
334             }, path);
335          } finally {
336             contents.release();
337          }
338          return null;
339       }
340 
341       @Override
342       public String toString() {
343          return "Put(" + SshjSshClient.this.toString() + ")[" + path + "]";
344       }
345    };
346 
347    @Override
348    public void put(String path, Payload contents) {
349       acquire(new PutConnection(path, contents));
350    }
351 
352    @VisibleForTesting
353    boolean shouldRetry(Exception from) {
354       Predicate<Throwable> predicate = retryAuth ?  Predicates.<Throwable>or(retryPredicate, instanceOf(AuthorizationException.class))
355             : retryPredicate;
356       if (any(getCausalChain(from), predicate))
357          return true;
358       if (!retryableMessages.equals(""))
359          return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
360       return false;
361    }
362 
363    @VisibleForTesting
364    Predicate<String> causalChainHasMessageContaining(final Exception from) {
365       return new Predicate<String>() {
366 
367          @Override
368          public boolean apply(final String input) {
369             return any(getCausalChain(from), new Predicate<Throwable>() {
370 
371                @Override
372                public boolean apply(Throwable arg0) {
373                   return arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1;
374                }
375 
376             });
377          }
378 
379       };
380    }
381 
382    @VisibleForTesting
383    SshException propagate(Exception e, String message) {
384       message += ": " + e.getMessage();
385       logger.error(e, "<< " + message);
386       if (e instanceof UserAuthException)
387          throw new AuthorizationException("(" + toString() + ") " + message, e);
388       throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
389             "(" + toString() + ") " + message, e);
390    }
391 
392    @Override
393    public String toString() {
394       return String.format("%s@%s:%d", username, host, port);
395    }
396 
397    @PreDestroy
398    public void disconnect() {
399       try {
400          sshConnection.clear();
401       } catch (Exception e) {
402          Throwables.propagate(e);
403       }
404    }
405 
406    protected Connection<Session> execConnection() {
407 
408       return new Connection<Session>() {
409 
410          private Session session = null;
411 
412          @Override
413          public void clear() throws TransportException, ConnectionException {
414             if (session != null)
415                session.close();
416          }
417 
418          @Override
419          public Session create() throws Exception {
420             checkConnected();
421             session = ssh.startSession();
422             session.allocateDefaultPTY();
423             return session;
424          }
425 
426          @Override
427          public String toString() {
428             return "Session(" + SshjSshClient.this.toString() + ")";
429          }
430       };
431 
432    }
433 
434    class ExecConnection implements Connection<ExecResponse> {
435       private final String command;
436       private Session session;
437 
438       ExecConnection(String command) {
439          this.command = checkNotNull(command, "command");
440       }
441 
442       @Override
443       public void clear() throws TransportException, ConnectionException {
444          if (session != null)
445             session.close();
446       }
447 
448       @Override
449       public ExecResponse create() throws Exception {
450          try {
451             session = acquire(execConnection());
452             Command output = session.exec(checkNotNull(command, "command"));
453             String outputString = IOUtils.readFully(output.getInputStream()).toString();
454             output.join(timeoutMillis, TimeUnit.SECONDS);
455             int errorStatus = output.getExitStatus();
456             String errorString = IOUtils.readFully(output.getErrorStream()).toString();
457             return new ExecResponse(outputString, errorString, errorStatus);
458          } finally {
459             clear();
460          }
461       }
462 
463       @Override
464       public String toString() {
465          return "ExecResponse(" + SshjSshClient.this.toString() + ")[" + command + "]";
466       }
467    }
468 
469    public ExecResponse exec(String command) {
470       return acquire(new ExecConnection(command));
471    }
472 
473    @Override
474    public String getHostAddress() {
475       return this.host;
476    }
477 
478    @Override
479    public String getUsername() {
480       return this.username;
481    }
482 
483 }