2012年10月8日 星期一

Manages the certificates received from SSL (HTTPS) websites.


#ifndef SSLACCESSCONTROLLER_H
#define SSLACCESSCONTROLLER_H

#include "Noncopyable.h"

#include <WinCrypt.h>

namespace WebCore
{  
class ResourceHandle;
class ResourceResponse;

// Reviews certificates for main-frame HTTPS loads made through WinInet.
// Certificates the user has already accepted or rejected are looked up in
// SSLCertStoreManager; otherwise the ResourceHandle's client is asked, and its
// answer comes back through continueVisitSSLSite() / cancelVisitSSLSite().
// SSL errors and client-certificate requests are reported to the client too.
class SSLAccessController :public Noncopyable
{
public:
    SSLAccessController();

    // Called when a server certificate is available for hInternet.
    // Returns true for loads that need no review (not HTTPS, or not the main
    // frame) and for certificates previously accepted; false when the
    // certificate cannot be read, was previously rejected, or cannot be
    // described to the client. Otherwise the client is notified and the value
    // of m_continueVisit afterwards is returned.
    bool didReceiveCertificate(ResourceHandle& handle, HANDLE hInternet);

    // Called when WinInet reports the SSL error errorID for this request.
    // Non-main-frame requests are resent with security checks disabled;
    // main-frame requests go through receiveSSLError().
    void didReceiveSSLError(ResourceHandle& handle, long errorID, HANDLE hInternet);

    // Called when the server requests a client certificate; reports errorID
    // to the client with a response that carries no certificate details.
    void didReceiveClientAuthNeededError(ResourceHandle& handle, long errorID, HANDLE hInternet);

    // The client chose to proceed. For a server-certificate prompt the
    // certificate is stored as accepted. For a client-certificate prompt a
    // non-NULL `certificate` is passed to WinInet as
    // INTERNET_OPTION_CLIENT_CERT_CONTEXT. A pending SSL error or
    // client-certificate request is then answered by resending the request.
    void continueVisitSSLSite(ResourceHandle& handle, HANDLE hInternet, void* certificate);

    // The client chose not to proceed. For a server-certificate prompt the
    // certificate is stored as rejected.
    void cancelVisitSSLSite(ResourceHandle& handle);

private:
   
    // Clears the per-prompt flags (m_continueVisit, m_hasSSLError,
    // m_clientAuthNeeded) back to their defaults.
    void resetFlagMembers();

    // Resends handle's current request with its m_noSecurityChecks
    // temporarily replaced by noSecurityChecks.
    void sendCurrentRequest(ResourceHandle& handle, bool noSecurityChecks);

    // Main-frame HTTPS part of didReceiveSSLError(): honours a stored
    // decision, otherwise reports the error with the certificate details.
    void receiveSSLError(ResourceHandle& handle, long errorID, HANDLE hInternet);

    // Sends the certificate details to the client; false if they cannot be
    // queried from WinInet.
    bool notifyCertificate(ResourceHandle& handle, HANDLE hInternet);

    // Wraps errorID in a ResourceError and hands it, with response, to the
    // client's didReceiveSSLError().
    void sendSSLErrorToUI(const ResourceResponse& response, ResourceHandle& handle, long errorID, HANDLE hInternet);

private:
    // Decision written by continueVisitSSLSite() / cancelVisitSSLSite() and
    // returned by didReceiveCertificate() after the client was notified.
    bool m_continueVisit;
    // Set by didReceiveSSLError() for main-frame requests.
    bool m_hasSSLError;
    // Set by didReceiveClientAuthNeededError().
    bool m_clientAuthNeeded;
    // Server certificate fetched by the latest didReceiveCertificate() /
    // receiveSSLError(). NOTE(review): those functions release their
    // reference before returning, but the pointer is kept and read later by
    // continueVisitSSLSite() / cancelVisitSSLSite() — confirm it is still
    // valid at that point.
    PCCERT_CONTEXT m_certContext;
};

}

#endif //SSLACCESSCONTROLLER_H



#include "config.h"
#include "SSLAccessController.h"
#include "ResourceHandle.h"
#include "ResourceHandleInternal.h"
#include "ResourceHandleClient.h"
#include "ResourceError.h"
#include "WinInetManager.h"
#include "SSLCertStoreManager.h"

namespace WebCore
{

class CertContextDeleter
{
public:
    CertContextDeleter(PCCERT_CONTEXT cert)
    :m_cert(cert){}

    ~CertContextDeleter(){CertFreeCertificateContext(m_cert);}

private:
    PCCERT_CONTEXT m_cert;
};


// Starts with no certificate and with the default flag state.
// m_certContext must start out NULL: resetFlagMembers() does not set it, and
// continueVisitSSLSite()/cancelVisitSSLSite() hand it to saveCertInStore(),
// whose only guard is a NULL check, so an indeterminate value would be used
// if either ran before a certificate was fetched.
SSLAccessController::SSLAccessController()
    : m_certContext(NULL)
{
    resetFlagMembers();
}

// True only for HTTPS requests that target the main frame; every other
// request is outside the certificate review done by this controller.
static bool needToHandleSSLMessage(ResourceHandle& handle)
{
    const bool isHttps = handle.request().url().protocol().lower() == "https";
    if (!isHttps)
        return false;

    return handle.request().targetType() == ResourceRequestBase::TargetIsMainFrame;
}

// Asks WinInet for the server certificate of the request behind hInternet.
// Returns NULL if the query fails. A non-NULL result is released by the
// callers in this file with CertFreeCertificateContext() (via
// CertContextDeleter).
static PCCERT_CONTEXT getCertFromWininet(HANDLE hInternet)
{
    PCCERT_CONTEXT serverCert = NULL;
    DWORD bufferSize = sizeof(serverCert);

    if (!WinInetManager::InternetQueryOptionW(hInternet, INTERNET_OPTION_SERVER_CERT_CONTEXT, &serverCert, &bufferSize))
        return NULL;

    return serverCert;
}

// Looks cert up among the certificates remembered by SSLCertStoreManager.
// Returns whether it was found. Callers treat inAcceptingStore as "stored as
// accepted" (true) vs "stored as rejected" (false) — presumably only
// meaningful when found; confirm against SSLCertStoreManager.
static bool findCertInStore(bool& inAcceptingStore, PCCERT_CONTEXT cert)
{
    return SSLCertStoreManager::getInstance().findCertInStore(inAcceptingStore, cert);
}

static bool saveCertInStore(bool inAcceptingStore, PCCERT_CONTEXT cert)
{
    if (NULL == cert)
        return false;
    return SSLCertStoreManager::getInstance().saveCertInStore(inAcceptingStore, cert);
}

static void buildResourceResponse(ResourceResponse& response, INTERNET_CERTIFICATE_INFO* info)
{
    response.setMimeType("text/html");
    response.setTextEncodingName("UTF-16");
    response.setSSLPeerCertificate(info);
}

// Queries the certificate details of the request behind handle into *info
// and, on success, fills response with them (see buildResourceResponse()).
// Returns false if info is NULL or the query fails.
// On return *info is either fully zeroed or filled by WinInet, so it is
// always safe to pass to releaseCertificateInfo().
static bool buildResponseWithCertificate(HINTERNET handle, ResourceResponse& response, INTERNET_CERTIFICATE_INFO* info)
{
    if (!info)
        return false;

    // releaseCertificateInfo() LocalFree()s every string member, and the
    // WinInet documentation describes several of them as not currently used.
    // Any member the query leaves unset must therefore be NULL rather than
    // uninitialized stack memory.
    *info = INTERNET_CERTIFICATE_INFO();

    DWORD certInfoLength = sizeof(INTERNET_CERTIFICATE_INFO);
    if (!WinInetManager::InternetQueryOptionW(handle,
        INTERNET_OPTION_SECURITY_CERTIFICATE_STRUCT,
        info,
        &certInfoLength))
    {
        return false;
    }

    buildResourceResponse(response, info);
    return true;
}

// Frees, with LocalFree(), the non-NULL strings that WinInet placed in an
// INTERNET_CERTIFICATE_INFO. The pointer members themselves are not reset.
static void releaseCertificateInfo(INTERNET_CERTIFICATE_INFO* info)
{
    LPTSTR const ownedStrings[] = {
        info->lpszEncryptionAlgName,
        info->lpszIssuerInfo,
        info->lpszProtocolName,
        info->lpszSignatureAlgName,
        info->lpszSubjectInfo
    };
    LPTSTR const* const end = ownedStrings + sizeof(ownedStrings) / sizeof(ownedStrings[0]);

    for (LPTSTR const* it = ownedStrings; it != end; ++it)
    {
        if (*it != NULL)
            LocalFree(*it);
    }
}

// Sends the server certificate details of hInternet to handle's client via
// didReceiveCertificate(), with handle set as the challenge's
// authentication client.
// Returns false, without contacting the client, if the details cannot be
// queried from WinInet; true otherwise.
bool SSLAccessController::notifyCertificate(ResourceHandle& handle, HANDLE hInternet)
{
    ResourceResponse response;
    INTERNET_CERTIFICATE_INFO certificateInfo;
    const bool haveDetails = buildResponseWithCertificate(hInternet, response, &certificateInfo);
    if (!haveDetails)
        return false;

    ResourceHandleInternal* internal = handle.d.get();
    internal->m_currentWebChallenge.setAuthenticationClient(&handle);
    handle.client()->didReceiveCertificate(&handle, response, internal->m_currentWebChallenge);

    // The strings inside certificateInfo belong to us once the query succeeded.
    releaseCertificateInfo(&certificateInfo);
    return true;
}

// Reviews the server certificate of hInternet.
// Returns true when no review is needed (not HTTPS, or not the main frame)
// and when the certificate was stored as accepted. Returns false when it
// cannot be read, was stored as rejected, or cannot be described to the
// client. Otherwise the client is notified, and the value of m_continueVisit
// afterwards is returned; continueVisitSSLSite() / cancelVisitSSLSite() are
// expected to update it.
bool SSLAccessController::didReceiveCertificate(ResourceHandle& handle, HANDLE hInternet)
{
    if (!needToHandleSSLMessage(handle))
        return true;

    m_certContext = getCertFromWininet(hInternet);
    if (!m_certContext)
        return false;

    // Release our reference to the context when this function returns.
    // NOTE(review): m_certContext keeps the pointer afterwards and is read by
    // continueVisitSSLSite() / cancelVisitSSLSite(); confirm it is still valid
    // if those run after this function has returned.
    CertContextDeleter certGuard(m_certContext);

    bool inAcceptingStore = false;
    if (findCertInStore(inAcceptingStore, m_certContext))
        return inAcceptingStore;

    // Start from the default decision so a refusal recorded for an earlier
    // certificate cannot leak into this one.
    // NOTE(review): this relies on the client answering synchronously inside
    // notifyCertificate(); an asynchronous answer arrives too late, and the
    // default (continue) is returned.
    m_continueVisit = true;
    if (!notifyCertificate(handle, hInternet))
        return false;

    return m_continueVisit;
}

void SSLAccessController::sendSSLErrorToUI(const ResourceResponse& response, ResourceHandle& handle, long errorID, HANDLE hInternet)
{
    ResourceHandleInternal *handleInternal = handle.d.get();

    handleInternal->m_currentWebChallenge.setAuthenticationClient(&handle);

    ResourceError resourceError("Error", WebURLErrorUnknown, handle.request().url().string(), "");
    resourceError.setPlatformCustomError(errorID);

    //Send error info to UI
    handle.client()->didReceiveSSLError(&handle, response, resourceError, handleInternal->m_currentWebChallenge);
}



void SSLAccessController::receiveSSLError(ResourceHandle& handle, long errorID, HANDLE hInternet)
{
    if (!needToHandleSSLMessage(handle))
        return;

    m_certContext = getCertFromWininet(hInternet);
    if (!m_certContext)
        return;

    CertContextDeleter temp(m_certContext);

    bool inAcceptingStore = false;
    if (findCertInStore(inAcceptingStore, m_certContext))
    {
        //the certificate is already in store.
        if (!inAcceptingStore) // do nothing
            return;

        if (inAcceptingStore) // Continue visiting
            return sendCurrentRequest(handle, true);
    }

    ResourceResponse response;
    INTERNET_CERTIFICATE_INFO certificateInfo;
    if(!buildResponseWithCertificate(hInternet, response, &certificateInfo))
        return;

    sendSSLErrorToUI(response, handle, errorID, hInternet);
   
    releaseCertificateInfo(&certificateInfo);
}

// Entry point for SSL errors reported by WinInet. Main-frame requests are
// flagged (m_hasSSLError) and handed to receiveSSLError().
void SSLAccessController::didReceiveSSLError(ResourceHandle& handle, long errorID, HANDLE hInternet)
{
    const bool isMainFrame = handle.request().targetType() == ResourceRequestBase::TargetIsMainFrame;
    if (isMainFrame)
    {
        m_hasSSLError = true;
        receiveSSLError(handle, errorID, hInternet);
        return;
    }

    // NOTE(review): any other request is resent with security checks
    // disabled, so certificate errors on those requests are never shown to
    // the user — confirm this is intended.
    sendCurrentRequest(handle, true);
}

// The server asked for a client certificate: flag the pending request
// (m_clientAuthNeeded) and report errorID to the client with a response
// that carries no certificate details.
void SSLAccessController::didReceiveClientAuthNeededError(ResourceHandle& handle, long errorID, HANDLE hInternet)
{
    m_clientAuthNeeded = true;

    ResourceResponse responseWithoutCert;
    buildResourceResponse(responseWithoutCert, NULL);
    sendSSLErrorToUI(responseWithoutCert, handle, errorID, hInternet);
}

// Resends handle's current request with its m_noSecurityChecks temporarily
// set to noSecurityChecks. The previous value is restored afterwards.
void SSLAccessController::sendCurrentRequest(ResourceHandle& handle, bool noSecurityChecks)
{
    const bool savedNoSecurityChecks = handle.d->m_noSecurityChecks;

    handle.d->m_noSecurityChecks = noSecurityChecks;
    handle.sendCurrentRequest();

    // Later sends on this handle go back to their usual setting.
    handle.d->m_noSecurityChecks = savedNoSecurityChecks;
}

void SSLAccessController::continueVisitSSLSite(ResourceHandle& handle, HANDLE hInternet, void* certificate)
{
    m_continueVisit = true;
if(!m_clientAuthNeeded)
saveCertInStore(true, m_certContext);

    if (certificate != NULL && m_clientAuthNeeded)
    {
        WinInetManager::InternetSetOptionW(
            hInternet,
            INTERNET_OPTION_CLIENT_CERT_CONTEXT,
            (LPVOID)certificate,
            sizeof(CERT_CONTEXT));
    }

    if (m_hasSSLError)
    {
        sendCurrentRequest(handle, true);
    }

    if (m_clientAuthNeeded)
    {
        sendCurrentRequest(handle, false);
    }

    resetFlagMembers();
}

// The client chose not to proceed with the pending SSL prompt.
// For a server-certificate prompt the certificate is stored as rejected.
// The per-prompt flags are reset and the refusal is left in m_continueVisit
// for didReceiveCertificate() to return.
void SSLAccessController::cancelVisitSSLSite(ResourceHandle& handle)
{
    if (!m_clientAuthNeeded)
        saveCertInStore(false, m_certContext);

    // resetFlagMembers() sets m_continueVisit back to true, so the refusal
    // must be recorded after it. Previously it was written first and
    // immediately overwritten, so a cancel never stopped the load.
    resetFlagMembers();
    m_continueVisit = false;
}

void SSLAccessController::resetFlagMembers()
{
    m_continueVisit = true;
    m_hasSSLError = false;
    m_clientAuthNeeded = false;
}

}

沒有留言:

張貼留言