View Javadoc

1   /**
2    * Licensed to jclouds, Inc. (jclouds) under one or more
3    * contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  jclouds licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.jclouds.ssh.jsch;
20  
21  import static com.google.common.base.Preconditions.checkArgument;
22  import static com.google.common.base.Preconditions.checkNotNull;
23  import static com.google.common.base.Preconditions.checkState;
24  import static com.google.common.base.Predicates.instanceOf;
25  import static com.google.common.base.Predicates.or;
26  import static com.google.common.base.Throwables.getCausalChain;
27  import static com.google.common.collect.Iterables.any;
28  import static org.jclouds.crypto.SshKeys.fingerprintPrivateKey;
29  import static org.jclouds.crypto.SshKeys.sha1PrivateKey;
30  
31  import java.io.IOException;
32  import java.io.InputStream;
33  import java.net.ConnectException;
34  import java.util.Arrays;
35  
36  import javax.annotation.PostConstruct;
37  import javax.annotation.PreDestroy;
38  import javax.annotation.Resource;
39  import javax.inject.Named;
40  
41  import org.apache.commons.io.input.ProxyInputStream;
42  import org.apache.commons.io.output.ByteArrayOutputStream;
43  import org.jclouds.compute.domain.ExecResponse;
44  import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
45  import org.jclouds.io.Payload;
46  import org.jclouds.io.Payloads;
47  import org.jclouds.logging.Logger;
48  import org.jclouds.net.IPSocket;
49  import org.jclouds.rest.AuthorizationException;
50  import org.jclouds.ssh.SshClient;
51  import org.jclouds.ssh.SshException;
52  import org.jclouds.util.CredentialUtils;
53  import org.jclouds.util.Strings2;
54  
55  import com.google.common.annotations.VisibleForTesting;
56  import com.google.common.base.Predicate;
57  import com.google.common.base.Predicates;
58  import com.google.common.base.Splitter;
59  import com.google.common.io.Closeables;
60  import com.google.inject.Inject;
61  import com.jcraft.jsch.ChannelExec;
62  import com.jcraft.jsch.ChannelSftp;
63  import com.jcraft.jsch.JSch;
64  import com.jcraft.jsch.JSchException;
65  import com.jcraft.jsch.Session;
66  
67  /**
68   * This class needs refactoring. It is not thread safe.
69   * 
70   * @author Adrian Cole
71   */
72  public class JschSshClient implements SshClient {
73  
74     private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
75  
76        private final ChannelSftp sftp;
77  
78        private CloseFtpChannelOnCloseInputStream(InputStream proxy, ChannelSftp sftp) {
79           super(proxy);
80           this.sftp = sftp;
81        }
82  
83        @Override
84        public void close() throws IOException {
85           super.close();
86           if (sftp != null)
87              sftp.disconnect();
88        }
89     }
90  
91     private final String host;
92     private final int port;
93     private final String username;
94     private final String password;
95     private final String toString;
96  
97     @Inject(optional = true)
98     @Named("jclouds.ssh.max-retries")
99     @VisibleForTesting
100    int sshRetries = 5;
101 
102    @Inject(optional = true)
103    @Named("jclouds.ssh.retry-auth")
104    @VisibleForTesting
105    boolean retryAuth;
106 
107    @Inject(optional = true)
108    @Named("jclouds.ssh.retryable-messages")
109    @VisibleForTesting
110    String retryableMessages = "failed to send channel request,channel is not opened,invalid data,End of IO Stream Read,Connection reset,connection is closed by foreign host,socket is not established";
111 
112    @Inject(optional = true)
113    @Named("jclouds.ssh.retry-predicate")
114    Predicate<Throwable> retryPredicate = or(instanceOf(ConnectException.class), instanceOf(IOException.class));
115 
116    @Resource
117    @Named("jclouds.ssh")
118    protected Logger logger = Logger.NULL;
119 
120    private Session session;
121    private final byte[] privateKey;
122    final byte[] emptyPassPhrase = new byte[0];
123    private final int timeout;
124    private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
125 
126    public JschSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
127          String username, String password, byte[] privateKey) {
128       this.host = checkNotNull(socket, "socket").getAddress();
129       checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
130       checkArgument(password != null || privateKey != null, "you must specify a password or a key");
131       this.port = socket.getPort();
132       this.username = checkNotNull(username, "username");
133       this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
134       this.timeout = timeout;
135       this.password = password;
136       this.privateKey = privateKey;
137       if ( privateKey==null ) {
138           this.toString = String.format("%s:password@%s:%d", username, host, port);
139       } else {
140           String fingerPrint = fingerprintPrivateKey(new String(privateKey));
141           String sha1 = sha1PrivateKey(new String(privateKey));
142           this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host,
143                  port);
144       }
145    }
146 
147    @Override
148    public void put(String path, String contents) {
149       put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
150    }
151 
152    private void checkConnected() {
153       checkState(session != null && session.isConnected(), String.format("(%s) Session not connected!", toString()));
154    }
155 
156    public static interface Connection<T> {
157       void clear();
158 
159       T create() throws Exception;
160    }
161 
162    Connection<Session> sessionConnection = new Connection<Session>() {
163 
164       @Override
165       public void clear() {
166          if (session != null && session.isConnected()) {
167             session.disconnect();
168             session = null;
169          }
170       }
171 
172       @Override
173       public Session create() throws Exception {
174          JSch jsch = new JSch();
175          session = jsch.getSession(username, host, port);
176          if (timeout != 0)
177             session.setTimeout(timeout);
178          if (password != null) {
179             session.setPassword(password);
180          } else {
181             // jsch wipes out your private key
182             if (CredentialUtils.isPrivateKeyEncrypted(privateKey)) {
183                throw new IllegalArgumentException("JschSshClientModule does not support private keys that require a passphrase");
184             }
185             jsch.addIdentity(username, Arrays.copyOf(privateKey, privateKey.length), null, emptyPassPhrase);
186          }
187          java.util.Properties config = new java.util.Properties();
188          config.put("StrictHostKeyChecking", "no");
189          session.setConfig(config);
190          session.connect(timeout);
191          return session;
192       }
193 
194       @Override
195       public String toString() {
196          return String.format("Session(timeout=%d)", timeout);
197       }
198    };
199 
200    protected <T, C extends Connection<T>> T acquire(C connection) {
201       connection.clear();
202       String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
203       for (int i = 0; i < sshRetries; i++) {
204          try {
205             logger.debug(">> (%s) acquiring %s", toString(), connection);
206             T returnVal = connection.create();
207             logger.debug("<< (%s) acquired %s", toString(), returnVal);
208             return returnVal;
209          } catch (Exception from) {
210             connection.clear();
211 
212             if (i + 1 == sshRetries) {
213                throw propagate(from, errorMessage);
214             } else if (shouldRetry(from)) {
215                logger.warn(from, "<< " + errorMessage + ": " + from.getMessage());
216                backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
217                continue;
218             }
219          }
220       }
221       assert false : "should not reach here";
222       return null;
223    }
224 
225    @PostConstruct
226    public void connect() {
227       acquire(sessionConnection);
228    }
229 
230    Connection<ChannelSftp> sftpConnection = new Connection<ChannelSftp>() {
231 
232       private ChannelSftp sftp;
233 
234       @Override
235       public void clear() {
236          if (sftp != null)
237             sftp.disconnect();
238       }
239 
240       @Override
241       public ChannelSftp create() throws JSchException {
242          checkConnected();
243          String channel = "sftp";
244          sftp = (ChannelSftp) session.openChannel(channel);
245          sftp.connect();
246          return sftp;
247       }
248 
249       @Override
250       public String toString() {
251          return "ChannelSftp()";
252       }
253    };
254 
255    class GetConnection implements Connection<Payload> {
256       private final String path;
257       private ChannelSftp sftp;
258 
259       GetConnection(String path) {
260          this.path = checkNotNull(path, "path");
261       }
262 
263       @Override
264       public void clear() {
265          if (sftp != null)
266             sftp.disconnect();
267       }
268 
269       @Override
270       public Payload create() throws Exception {
271          sftp = acquire(sftpConnection);
272          return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.get(path), sftp));
273       }
274 
275       @Override
276       public String toString() {
277          return "Payload(path=[" + path + "])";
278       }
279    };
280 
281    public Payload get(String path) {
282       return acquire(new GetConnection(path));
283    }
284 
285    class PutConnection implements Connection<Void> {
286       private final String path;
287       private final Payload contents;
288       private ChannelSftp sftp;
289 
290       PutConnection(String path, Payload contents) {
291          this.path = checkNotNull(path, "path");
292          this.contents = checkNotNull(contents, "contents");
293       }
294 
295       @Override
296       public void clear() {
297          if (sftp != null)
298             sftp.disconnect();
299       }
300 
301       @Override
302       public Void create() throws Exception {
303          sftp = acquire(sftpConnection);
304          InputStream is = checkNotNull(contents.getInput(), "inputstream for path %s", path);
305          try {
306             sftp.put(is, path);
307          } finally {
308             Closeables.closeQuietly(contents);
309          }
310          return null;
311       }
312 
313       @Override
314       public String toString() {
315          return "Put(path=[" + path + "])";
316       }
317    };
318 
319    @Override
320    public void put(String path, Payload contents) {
321       acquire(new PutConnection(path, contents));
322    }
323 
324    @VisibleForTesting
325    boolean shouldRetry(Exception from) {
326       Predicate<Throwable> predicate = retryAuth ?  Predicates.<Throwable>or(retryPredicate, instanceOf(AuthorizationException.class))
327             : retryPredicate;
328       if (any(getCausalChain(from), predicate))
329          return true;
330       if (!retryableMessages.equals(""))
331          return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
332       return false;
333    }
334 
335    @VisibleForTesting
336    Predicate<String> causalChainHasMessageContaining(final Exception from) {
337       return new Predicate<String>() {
338 
339          @Override
340          public boolean apply(final String input) {
341             return any(getCausalChain(from), new Predicate<Throwable>() {
342 
343                @Override
344                public boolean apply(Throwable arg0) {
345                   return arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1;
346                }
347 
348             });
349          }
350 
351       };
352    }
353 
354    private void backoffForAttempt(int retryAttempt, String message) {
355       backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
356    }
357 
358    SshException propagate(Exception e, String message) {
359       message += ": " + e.getMessage();
360       if (e.getMessage() != null && e.getMessage().indexOf("Auth fail") != -1)
361          throw new AuthorizationException("(" + toString() + ") " + message, e);
362       throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
363             "(" + toString() + ") " + message, e);
364    }
365 
366    @Override
367    public String toString() {
368       return toString;
369    }
370 
371    @PreDestroy
372    public void disconnect() {
373       sessionConnection.clear();
374    }
375 
376    protected Connection<ChannelExec> execConnection(final String command) {
377       checkNotNull(command, "command");
378       return new Connection<ChannelExec>() {
379 
380          private ChannelExec executor = null;
381 
382          @Override
383          public void clear() {
384             if (executor != null)
385                executor.disconnect();
386          }
387 
388          @Override
389          public ChannelExec create() throws Exception {
390             checkConnected();
391             String channel = "exec";
392             executor = (ChannelExec) session.openChannel(channel);
393             executor.setPty(true);
394             executor.setCommand(command);
395             ByteArrayOutputStream error = new ByteArrayOutputStream();
396             executor.setErrStream(error);
397             executor.connect();
398             return executor;
399          }
400 
401          @Override
402          public String toString() {
403             return "ChannelExec()";
404          }
405       };
406 
407    }
408 
409    class ExecConnection implements Connection<ExecResponse> {
410       private final String command;
411       private ChannelExec executor;
412 
413       ExecConnection(String command) {
414          this.command = checkNotNull(command, "command");
415       }
416 
417       @Override
418       public void clear() {
419          if (executor != null)
420             executor.disconnect();
421       }
422 
423       @Override
424       public ExecResponse create() throws Exception {
425          try {
426             executor = acquire(execConnection(command));
427             String outputString = Strings2.toStringAndClose(executor.getInputStream());
428             int errorStatus = executor.getExitStatus();
429             int i = 0;
430             String message = String.format("bad status -1 %s", toString());
431             while ((errorStatus = executor.getExitStatus()) == -1 && i < JschSshClient.this.sshRetries) {
432                logger.warn("<< " + message);
433                backoffForAttempt(++i, message);
434             }
435             if (errorStatus == -1)
436                throw new SshException(message);
437             // be careful as this can hang reading
438             // com.jcraft.jsch.Channel$MyPipedInputStream when there's a slow
439             // network connection
440             // String errorString =
441             // Strings2.toStringAndClose(executor.getErrStream());
442             String errorString = "";
443             return new ExecResponse(outputString, errorString, errorStatus);
444          } finally {
445             clear();
446          }
447       }
448 
449       @Override
450       public String toString() {
451          return "ExecResponse(command=[" + command + "])";
452       }
453    }
454 
455    public ExecResponse exec(String command) {
456       return acquire(new ExecConnection(command));
457    }
458 
459    @Override
460    public String getHostAddress() {
461       return this.host;
462    }
463 
464    @Override
465    public String getUsername() {
466       return this.username;
467    }
468 
469 }