connection.cpp

00001 /*
00002   Copyright (c) 2004-2006 by Jakob Schroeter <js@camaya.net>
00003   This file is part of the gloox library. http://camaya.net/gloox
00004 
00005   This software is distributed under a license. The full license
00006   agreement can be found in the file LICENSE in this distribution.
00007   This software may not be copied, modified, sold or distributed
00008   other than expressed in the named license agreement.
00009 
00010   This software is distributed without any warranty.
00011 */
00012 
00013 
00014 
00015 #include "gloox.h"
00016 
00017 #include "compression.h"
00018 #include "connection.h"
00019 #include "dns.h"
00020 #include "logsink.h"
00021 #include "prep.h"
00022 #include "parser.h"
00023 
00024 #ifdef __MINGW32__
00025 #include <winsock.h>
00026 #endif
00027 
00028 #ifndef WIN32
00029 #include <sys/types.h>
00030 #include <sys/socket.h>
00031 #include <sys/select.h>
00032 #include <unistd.h>
00033 #else
00034 #include <winsock.h>
00035 #endif
00036 
00037 #if defined( _MSC_VER ) && ( _MSC_VER >= 1300 )
00038 #define strcasecmp stricmp
00039 #endif
00040 
00041 #include <time.h>
00042 
00043 #include <string>
00044 
00045 namespace gloox
00046 {
00047 
00048   Connection::Connection( Parser *parser, const LogSink& logInstance, const std::string& server,
00049                           int port )
00050     : m_parser( parser ), m_state ( StateDisconnected ), m_disconnect ( ConnNoError ),
00051       m_logInstance( logInstance ), m_compression( 0 ), m_buf( 0 ),
00052       m_server( Prep::idna( server ) ), m_port( port ), m_socket( -1 ), m_bufsize( 17000 ),
00053       m_cancel( true ), m_secure( false ), m_fdRequested( false ), m_enableCompression( false )
00054   {
00055     m_buf = (char*)calloc( m_bufsize + 1, sizeof( char ) );
00056 #ifdef USE_OPENSSL
00057     m_ssl = 0;
00058 #endif
00059   }
00060 
00061   Connection::~Connection()
00062   {
00063     cleanup();
00064     free( m_buf );
00065     m_buf = 0;
00066   }
00067 
00068 #ifdef HAVE_TLS
00069   void Connection::setClientCert( const std::string& clientKey, const std::string& clientCerts )
00070   {
00071     m_clientKey = clientKey;
00072     m_clientCerts = clientCerts;
00073   }
00074 #endif
00075 
00076 #if defined( USE_OPENSSL )
00077   bool Connection::tlsHandshake()
00078   {
00079     SSL_library_init();
00080     SSL_CTX *sslCTX = SSL_CTX_new( TLSv1_client_method() );
00081     if( !sslCTX )
00082       return false;
00083 
00084     if( !SSL_CTX_set_cipher_list( sslCTX, "HIGH:MEDIUM:AES:@STRENGTH" ) )
00085       return false;
00086 
00087     StringList::const_iterator it = m_cacerts.begin();
00088     for( ; it != m_cacerts.end(); ++it )
00089       SSL_CTX_load_verify_locations( sslCTX, (*it).c_str(), NULL );
00090 
00091     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00092     {
00093       SSL_CTX_use_certificate_chain_file( sslCTX, m_clientCerts.c_str() );
00094       SSL_CTX_use_PrivateKey_file( sslCTX, m_clientKey.c_str(), SSL_FILETYPE_PEM );
00095     }
00096 
00097     m_ssl = SSL_new( sslCTX );
00098     SSL_set_connect_state( m_ssl );
00099 
00100     BIO *socketBio = BIO_new_socket( m_socket, BIO_NOCLOSE );
00101     if( !socketBio )
00102       return false;
00103 
00104     SSL_set_bio( m_ssl, socketBio, socketBio );
00105     SSL_set_mode( m_ssl, SSL_MODE_AUTO_RETRY );
00106 
00107     if( !SSL_connect( m_ssl ) )
00108       return false;
00109 
00110     m_secure = true;
00111 
00112     int res = SSL_get_verify_result( m_ssl );
00113     if( res != X509_V_OK )
00114       m_certInfo.status = CertInvalid;
00115     else
00116       m_certInfo.status = CertOk;
00117 
00118     X509 *peer;
00119     peer = SSL_get_peer_certificate( m_ssl );
00120     if( peer )
00121     {
00122       char peer_CN[256];
00123       X509_NAME_get_text_by_NID( X509_get_issuer_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00124       m_certInfo.issuer = peer_CN;
00125       X509_NAME_get_text_by_NID( X509_get_subject_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00126       m_certInfo.server = peer_CN;
00127       if( strcasecmp( peer_CN, m_server.c_str() ) )
00128         m_certInfo.status |= CertWrongPeer;
00129     }
00130     else
00131     {
00132       m_certInfo.status = CertInvalid;
00133     }
00134 
00135     const char *tmp;
00136     tmp = SSL_get_cipher_name( m_ssl );
00137     if( tmp )
00138       m_certInfo.cipher = tmp;
00139 
00140     tmp = SSL_get_cipher_version( m_ssl );
00141     if( tmp )
00142       m_certInfo.protocol = tmp;
00143 
00144     return true;
00145   }
00146 
00147 #elif defined( USE_GNUTLS )
00148   bool Connection::tlsHandshake()
00149   {
00150     const int protocolPriority[] = { GNUTLS_TLS1, GNUTLS_SSL3, 0 };
00151     const int kxPriority[]       = { GNUTLS_KX_RSA, 0 };
00152     const int cipherPriority[]   = { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
00153                                              GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR, 0 };
00154     const int compPriority[]     = { GNUTLS_COMP_ZLIB, GNUTLS_COMP_NULL, 0 };
00155     const int macPriority[]      = { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
00156 
00157     if( gnutls_global_init() != 0 )
00158       return false;
00159 
00160     if( gnutls_certificate_allocate_credentials( &m_credentials ) < 0 )
00161       return false;
00162 
00163     StringList::const_iterator it = m_cacerts.begin();
00164     for( ; it != m_cacerts.end(); ++it )
00165       gnutls_certificate_set_x509_trust_file( m_credentials, (*it).c_str(), GNUTLS_X509_FMT_PEM );
00166 
00167     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00168     {
00169       gnutls_certificate_set_x509_key_file( m_credentials, m_clientKey.c_str(),
00170                                             m_clientCerts.c_str(), GNUTLS_X509_FMT_PEM );
00171     }
00172 
00173     if( gnutls_init( &m_session, GNUTLS_CLIENT ) != 0 )
00174     {
00175       gnutls_certificate_free_credentials( m_credentials );
00176       return false;
00177     }
00178 
00179     gnutls_protocol_set_priority( m_session, protocolPriority );
00180     gnutls_cipher_set_priority( m_session, cipherPriority );
00181     gnutls_compression_set_priority( m_session, compPriority );
00182     gnutls_kx_set_priority( m_session, kxPriority );
00183     gnutls_mac_set_priority( m_session, macPriority );
00184     gnutls_credentials_set( m_session, GNUTLS_CRD_CERTIFICATE, m_credentials );
00185 
00186     gnutls_transport_set_ptr( m_session, (gnutls_transport_ptr_t)m_socket );
00187     if( gnutls_handshake( m_session ) != 0 )
00188     {
00189       gnutls_deinit( m_session );
00190       gnutls_certificate_free_credentials( m_credentials );
00191       return false;
00192     }
00193     gnutls_certificate_free_ca_names( m_credentials );
00194 
00195     m_secure = true;
00196 
00197     unsigned int status;
00198     bool error = false;
00199 
00200     if( gnutls_certificate_verify_peers2( m_session, &status ) < 0 )
00201       error = true;
00202 
00203     m_certInfo.status = 0;
00204     if( status & GNUTLS_CERT_INVALID )
00205       m_certInfo.status |= CertInvalid;
00206     if( status & GNUTLS_CERT_SIGNER_NOT_FOUND )
00207       m_certInfo.status |= CertSignerUnknown;
00208     if( status & GNUTLS_CERT_REVOKED )
00209       m_certInfo.status |= CertRevoked;
00210     if( status & GNUTLS_CERT_SIGNER_NOT_CA )
00211       m_certInfo.status |= CertSignerNotCa;
00212     const gnutls_datum_t* certList = 0;
00213     unsigned int certListSize;
00214     if( !error && ( ( certList = gnutls_certificate_get_peers( m_session, &certListSize ) ) == 0 ) )
00215       error = true;
00216 
00217     gnutls_x509_crt_t *cert = new gnutls_x509_crt_t[certListSize+1];
00218     for( unsigned int i=0; !error && ( i<certListSize ); ++i )
00219     {
00220       if( !error && ( gnutls_x509_crt_init( &cert[i] ) < 0 ) )
00221         error = true;
00222       if( !error && ( gnutls_x509_crt_import( cert[i], &certList[i], GNUTLS_X509_FMT_DER ) < 0 ) )
00223         error = true;
00224     }
00225 
00226     if( ( gnutls_x509_crt_check_issuer( cert[certListSize-1], cert[certListSize-1] ) > 0 )
00227          && certListSize > 0 )
00228       certListSize--;
00229 
00230     bool chain = true;
00231     for( unsigned int i=1; !error && ( i<certListSize ); ++i )
00232     {
00233       chain = error = !verifyAgainst( cert[i-1], cert[i] );
00234     }
00235     if( !chain )
00236       m_certInfo.status |= CertInvalid;
00237     m_certInfo.chain = chain;
00238 
00239     m_certInfo.chain = verifyAgainstCAs( cert[certListSize], 0 /*CAList*/, 0 /*CAListSize*/ );
00240 
00241     int t = (int)gnutls_x509_crt_get_expiration_time( cert[0] );
00242     if( t == -1 )
00243       error = true;
00244     else if( t < time( 0 ) )
00245       m_certInfo.status |= CertExpired;
00246     m_certInfo.date_from = t;
00247 
00248     t = (int)gnutls_x509_crt_get_activation_time( cert[0] );
00249     if( t == -1 )
00250       error = true;
00251     else if( t > time( 0 ) )
00252       m_certInfo.status |= CertNotActive;
00253     m_certInfo.date_to = t;
00254 
00255     char name[64];
00256     size_t nameSize = sizeof( name );
00257     gnutls_x509_crt_get_issuer_dn( cert[0], name, &nameSize );
00258     m_certInfo.issuer = name;
00259 
00260     nameSize = sizeof( name );
00261     gnutls_x509_crt_get_dn( cert[0], name, &nameSize );
00262     m_certInfo.server = name;
00263 
00264     const char* info;
00265     info = gnutls_compression_get_name( gnutls_compression_get( m_session ) );
00266     if( info )
00267       m_certInfo.compression = info;
00268 
00269     info = gnutls_mac_get_name( gnutls_mac_get( m_session ) );
00270     if( info )
00271       m_certInfo.mac = info;
00272 
00273     info = gnutls_cipher_get_name( gnutls_cipher_get( m_session ) );
00274     if( info )
00275       m_certInfo.cipher = info;
00276 
00277     info = gnutls_protocol_get_name( gnutls_protocol_get_version( m_session ) );
00278     if( info )
00279       m_certInfo.protocol = info;
00280 
00281     if( !gnutls_x509_crt_check_hostname( cert[0], m_server.c_str() ) )
00282       m_certInfo.status |= CertWrongPeer;
00283 
00284     for( unsigned int i=0; i<certListSize; ++i )
00285       gnutls_x509_crt_deinit( cert[i] );
00286 
00287     delete[] cert;
00288 
00289     return true;
00290   }
00291 
00292   bool Connection::verifyAgainst( gnutls_x509_crt_t cert, gnutls_x509_crt_t issuer )
00293   {
00294     unsigned int result;
00295     gnutls_x509_crt_verify( cert, &issuer, 1, 0, &result );
00296     if( result & GNUTLS_CERT_INVALID )
00297       return false;
00298 
00299     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00300       return false;
00301 
00302     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00303       return false;
00304 
00305     return true;
00306   }
00307 
00308   bool Connection::verifyAgainstCAs( gnutls_x509_crt_t cert, gnutls_x509_crt_t *CAList, int CAListSize )
00309   {
00310     unsigned int result;
00311     gnutls_x509_crt_verify( cert, CAList, CAListSize, GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT, &result );
00312     if( result & GNUTLS_CERT_INVALID )
00313       return false;
00314 
00315     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00316       return false;
00317 
00318     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00319       return false;
00320 
00321     return true;
00322   }
00323 #endif
00324 
00325 #ifdef HAVE_ZLIB
00326   bool Connection::initCompression( StreamFeature method )
00327   {
00328     delete m_compression;
00329     m_compression = 0;
00330     m_compression = new Compression( method );
00331     return true;
00332   }
00333 
00334   void Connection::enableCompression()
00335   {
00336     if( !m_compression )
00337       return;
00338 
00339     m_enableCompression = true;
00340   }
00341 #endif
00342 
00343   ConnectionState Connection::connect()
00344   {
00345     if( m_socket != -1 && m_state >= StateConnecting )
00346     {
00347       return m_state;
00348     }
00349 
00350     m_state = StateConnecting;
00351 
00352     if( m_port == -1 )
00353       m_socket = DNS::connect( m_server, m_logInstance );
00354     else
00355       m_socket = DNS::connect( m_server, m_port, m_logInstance );
00356 
00357     if( m_socket < 0 )
00358     {
00359       switch( m_socket )
00360       {
00361         case -DNS::DNS_COULD_NOT_CONNECT:
00362           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not connect" );
00363           break;
00364         case -DNS::DNS_NO_HOSTS_FOUND:
00365           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: no hosts found" );
00366           break;
00367         case -DNS::DNS_COULD_NOT_RESOLVE:
00368           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not resolve" );
00369           break;
00370       }
00371       cleanup();
00372     }
00373     else
00374       m_state = StateConnected;
00375 
00376     m_cancel = false;
00377     return m_state;
00378   }
00379 
00380   void Connection::disconnect( ConnectionError e )
00381   {
00382     m_disconnect = e;
00383     m_cancel = true;
00384 
00385     if( m_fdRequested )
00386       cleanup();
00387   }
00388 
00389   int Connection::fileDescriptor()
00390   {
00391     m_fdRequested = true;
00392     return m_socket;
00393   }
00394 
00395   ConnectionError Connection::recv( int timeout )
00396   {
00397     if( m_cancel )
00398     {
00399       ConnectionError e = m_disconnect;
00400       cleanup();
00401       return e;
00402     }
00403 
00404     if( m_socket == -1 )
00405       return ConnNotConnected;
00406 
00407     if( !m_fdRequested )
00408     {
00409       fd_set fds;
00410       struct timeval tv;
00411 
00412       FD_ZERO( &fds );
00413       FD_SET( m_socket, &fds );
00414 
00415       tv.tv_sec = timeout;
00416       tv.tv_usec = 0;
00417 
00418       if( select( m_socket + 1, &fds, 0, 0, timeout == -1 ? 0 : &tv ) < 0 )
00419         return ConnIoError;
00420 
00421       if( !FD_ISSET( m_socket, &fds ) )
00422         return ConnNoError;
00423     }
00424 
00425     // optimize(?): recv returns the size. set size+1 = \0
00426     memset( m_buf, '\0', m_bufsize + 1 );
00427     int size = 0;
00428 #if defined( USE_GNUTLS )
00429     if( m_secure )
00430     {
00431       size = gnutls_record_recv( m_session, m_buf, m_bufsize );
00432     }
00433     else
00434 #elif defined( USE_OPENSSL )
00435     if( m_secure )
00436     {
00437       size = SSL_read( m_ssl, m_buf, m_bufsize );
00438     }
00439     else
00440 #endif
00441     {
00442 #ifdef SKYOS
00443       size = ::recv( m_socket, (unsigned char*)m_buf, m_bufsize, 0 );
00444 #else
00445       size = ::recv( m_socket, m_buf, m_bufsize, 0 );
00446 #endif
00447     }
00448 
00449     if( size < 0 )
00450     {
00451       // error
00452       return ConnIoError;
00453     }
00454     else if( size == 0 )
00455     {
00456       // connection closed
00457       return ConnUserDisconnected;
00458     }
00459     else
00460     {
00461       std::string buf;
00462       if( m_compression && m_enableCompression )
00463       {
00464         buf.assign( m_buf, size );
00465         buf = m_compression->decompress( buf );
00466       }
00467       else
00468         buf.assign( m_buf, strlen( m_buf ) );
00469 
00470       Parser::ParserState ret = m_parser->feed( buf );
00471       if( ret != Parser::PARSER_OK )
00472       {
00473         cleanup();
00474         switch( ret )
00475         {
00476           case Parser::PARSER_BADXML:
00477             m_logInstance.log( LogLevelError, LogAreaClassConnection, "XML parse error" );
00478             break;
00479           case Parser::PARSER_NOMEM:
00480             m_logInstance.log( LogLevelError, LogAreaClassConnection, "memory allocation error" );
00481             break;
00482           default:
00483             break;
00484         }
00485         return ConnIoError;
00486       }
00487     }
00488 
00489     return ConnNoError;
00490   }
00491 
00492   ConnectionError Connection::receive()
00493   {
00494     if( m_socket == -1 || !m_parser )
00495       return ConnNotConnected;
00496 
00497     while( !m_cancel )
00498     {
00499       ConnectionError r = recv( 1 );
00500       if( r != ConnNoError )
00501         return r;
00502     }
00503     cleanup();
00504 
00505     return m_disconnect;
00506   }
00507 
00508   void Connection::send( const std::string& data )
00509   {
00510     if( data.empty() || ( m_socket == -1 ) )
00511       return;
00512 
00513     std::string xml;
00514     if( m_compression && m_enableCompression )
00515       xml = m_compression->compress( data );
00516     else
00517       xml = data;
00518 
00519 #if defined( USE_GNUTLS )
00520     if( m_secure )
00521     {
00522       int ret;
00523       size_t len = xml.length();
00524       do
00525       {
00526         ret = gnutls_record_send( m_session, xml.c_str(), len );
00527       }
00528       while( ( ret == GNUTLS_E_AGAIN ) || ( ret == GNUTLS_E_INTERRUPTED ) );
00529     }
00530     else
00531 #elif defined( USE_OPENSSL )
00532     if( m_secure )
00533     {
00534       int ret;
00535       size_t len = xml.length();
00536       ret = SSL_write( m_ssl, xml.c_str(), len );
00537     }
00538     else
00539 #endif
00540     {
00541       size_t num = 0;
00542       size_t len = xml.length();
00543       while( num < len )
00544       {
00545 #ifdef SKYOS
00546         num += ::send( m_socket, (unsigned char*)(xml.c_str()+num), len - num, 0 );
00547 #else
00548         num += ::send( m_socket, (xml.c_str()+num), len - num, 0 );
00549 #endif
00550       }
00551     }
00552   }
00553 
00554   void Connection::cleanup()
00555   {
00556 #if defined( USE_GNUTLS )
00557     if( m_secure )
00558     {
00559       gnutls_bye( m_session, GNUTLS_SHUT_RDWR );
00560       gnutls_deinit( m_session );
00561       gnutls_certificate_free_credentials( m_credentials );
00562       gnutls_global_deinit();
00563     }
00564 #elif defined( USE_OPENSSL )
00565     if( m_secure )
00566     {
00567       SSL_shutdown( m_ssl );
00568       SSL_free( m_ssl );
00569     }
00570 #endif
00571 
00572     if( m_socket != -1 )
00573     {
00574 #ifdef WIN32
00575       closesocket( m_socket );
00576 #else
00577       close( m_socket );
00578 #endif
00579       m_socket = -1;
00580     }
00581     m_state = StateDisconnected;
00582     m_disconnect = ConnNoError;
00583     m_enableCompression = false;
00584     m_secure = false;
00585     m_cancel = true;
00586     m_fdRequested = false;
00587   }
00588 
00589 }

Generated on Wed Sep 13 21:33:46 2006 for gloox by  doxygen 1.4.7