blob: 9235937001335981ecdb4a9f533204bd61ec4bb8 [file] [log] [blame] [edit]
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
	"encoding/base64"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io/ioutil"
	"mime"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/apid/apid-core/cipher"
)
// Format of encrypted values.
const (
	// regEncrypted recognizes "{ALGORITHM/MODE/PADDING}<ciphertext>": three
	// alphanumeric fields inside braces, followed by at least one character.
	regEncrypted = `^\{[0-9A-Za-z]+/[0-9A-Za-z]+/[0-9A-Za-z]+\}.`

	// EncryptAes is the algorithm name written into, and required in, the
	// "{ALGORITHM/...}" prefix of encrypted values.
	EncryptAes = "AES"
)

// Key-server request settings.
const (
	// retrieveEncryptKeyPath is appended to the server base URL.
	retrieveEncryptKeyPath = "/encryptionkey"
	// parameterOrganization is the query parameter carrying the org name.
	parameterOrganization = "organization"
	// configBearerToken is the config key holding the Authorization token.
	configBearerToken = "apigeesync_bearer_token"
	// errorCodeNoKey is the error code, sent with a 404, meaning the
	// organization has no encryption key.
	errorCodeNoKey = "organizations.EncryptionKeyDoesNotExist"
)

// Schedule for background key retrieval (see CreateCipherManager).
const (
	retrieveKeyRetryInterval = 5 * time.Second
	retrieveKeyTimeout       = 5 * time.Minute
)

// Header and media types recognized in key-server error responses.
const (
	headerContentType = "Content-Type"
	typeJson          = "application/json"
	typeXml           = "text/xml"
)

// RegexpEncrypted is the compiled form of regEncrypted; see IsEncrypted.
var RegexpEncrypted = regexp.MustCompile(regEncrypted)
// CreateCipherManager builds a KmsCipherManager that fetches per-organization
// keys from serverUrlBase using client. It starts with an empty cipher cache and
// the default background retry schedule (retrieveKeyRetryInterval between
// attempts, giving up after retrieveKeyTimeout).
func CreateCipherManager(client *http.Client, serverUrlBase string) *KmsCipherManager {
	m := new(KmsCipherManager)
	m.serverUrlBase = serverUrlBase
	m.aes = map[string]*cipher.AesCipher{}
	m.mutex = new(sync.RWMutex)
	m.client = client
	m.interval = retrieveKeyRetryInterval
	m.timeout = retrieveKeyTimeout
	return m
}
// KmsCipherManager caches one AES cipher per organization. Each
// organization's key is fetched from the key server at serverUrlBase, either
// on demand (getAesCipher) or in the background (AddOrgs).
type KmsCipherManager struct {
	// base URL of the key server; retrieveEncryptKeyPath is appended to it
	serverUrlBase string
	// org-level AesCipher map {organization: AesCipher}
	aes map[string]*cipher.AesCipher
	// guards aes
	mutex *sync.RWMutex
	// HTTP client used for key requests
	client *http.Client
	// delay between background retrieval attempts
	interval time.Duration
	// how long background retrieval keeps retrying before giving up
	timeout time.Duration
}
// AddOrgs starts background key retrieval for every organization in orgs,
// one goroutine per organization, using the manager's retry interval and
// timeout. It returns immediately without waiting for any key.
func (c *KmsCipherManager) AddOrgs(orgs []string) {
	for i := range orgs {
		org := orgs[i]
		go c.startRetrieve(org, c.interval, c.timeout)
	}
}
// startRetrieve fetches org's key immediately. If that attempt fails, it
// retries every interval until an attempt succeeds or timeout has elapsed
// since the call began. Success includes retrieveKey reporting that org has
// no key. Errors are logged rather than returned because this runs as a
// background goroutine (see AddOrgs).
func (c *KmsCipherManager) startRetrieve(org string, interval time.Duration, timeout time.Duration) {
	// The deadline covers the first attempt as well as the retries. Stopping
	// it on return releases the timer early instead of holding it until it fires.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	err := c.retrieveKey(org)
	if err == nil {
		return
	}
	log.Error(err)

	ticker := time.NewTicker(interval)
	// A Ticker keeps running until stopped; release it on every exit path.
	defer ticker.Stop()
	for {
		select {
		case <-deadline.C:
			log.Errorf("timeout when retrieving key for org %s", org)
			return
		case <-ticker.C:
			if err := c.retrieveKey(org); err != nil {
				log.Error(err)
				continue
			}
			return
		}
	}
}
// retrieveKey fetches the base64-encoded encryption key for org with a GET to
// serverUrlBase+retrieveEncryptKeyPath?organization=<org>. It decodes the key,
// builds an AES cipher from it and stores the cipher in c.aes under the write lock.
//
// It returns nil in two cases:
//   - a key was stored;
//   - the server answered 404 with errorCodeNoKey. The org has no key, so
//     nothing is cached and retry loops stop.
//
// Every other failure is returned as an error. The caller decides whether to
// log it.
func (c *KmsCipherManager) retrieveKey(org string) error {
	req, err := http.NewRequest(http.MethodGet, c.serverUrlBase+retrieveEncryptKeyPath, nil)
	if err != nil {
		return fmt.Errorf("failed to create retrieving key request for org=%s : %v", org, err)
	}
	pars := req.URL.Query()
	pars[parameterOrganization] = []string{org}
	req.URL.RawQuery = pars.Encode()
	req.Header.Set("Authorization", "Bearer "+services.Config().GetString(configBearerToken))
	log.Debugf("Retrieving key: %s", req.URL.String())
	res, err := c.client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to retrieve key for org=%s : %v", org, err)
	}
	// Close the body on every path. Previously any non-200/non-404 response
	// leaked it. parseErrorResponse also closes it; net/http response bodies
	// tolerate a second Close.
	defer res.Body.Close()

	if res.StatusCode == http.StatusNotFound {
		e, err := parseErrorResponse(res)
		if err != nil {
			// Returned, not logged: callers (startRetrieve, getAesCipher) log it.
			return fmt.Errorf("failed to parse 404 error response for org %s: %v", org, err)
		}
		// This org has no key: report success so retry loops stop.
		if e.Code == errorCodeNoKey {
			log.Debugf("No key is associated with org %v", org)
			return nil
		}
	}
	if res.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to retrieve key for org [%v] with status: %v", org, res.Status)
	}

	key64, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return fmt.Errorf("error reading encryption key: %v", err)
	}
	log.Debugf("Downloaded Encryption Key for org %s", org)
	// Base64 never contains whitespace, so trimming only drops incidental
	// surrounding whitespace such as a trailing newline, which would
	// otherwise make decoding fail.
	key, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(key64)))
	if err != nil {
		return fmt.Errorf("error decoding encryption key: %v", err)
	}
	log.Debugf("Encryption Key successfully retrieved for org %s", org)
	a, err := cipher.CreateAesCipher(key)
	if err != nil {
		return fmt.Errorf("CreateAesCipher error for org [%v] when CreateAesCipher: %v", org, err)
	}
	c.mutex.Lock()
	c.aes[org] = a
	c.mutex.Unlock()
	return nil
}
// getAesCipher returns org's cached cipher. On a cache miss it fetches the key
// synchronously via retrieveKey and then looks again.
//
// The result may be nil. That happens when the fetch fails (the error is
// logged) or when retrieveKey succeeds without storing a cipher, because the
// server reported that org has no key.
func (c *KmsCipherManager) getAesCipher(org string) *cipher.AesCipher {
	cached := func() *cipher.AesCipher {
		c.mutex.RLock()
		defer c.mutex.RUnlock()
		return c.aes[org]
	}

	if hit := cached(); hit != nil {
		return hit
	}

	// Cache miss: fetch now. The lock is not held during the network call.
	if err := c.retrieveKey(org); err != nil {
		log.Errorf("Failed to get encryption key for org=%s : %v", org, err)
		return nil
	}
	return cached()
}
// TryDecryptBase64 decrypts an encrypted value belonging to org.
//
// Input that does not look encrypted (see IsEncrypted) is returned unchanged
// with a nil error. Otherwise the "{AES/MODE/PADDING}" prefix is parsed, the
// remainder is base64-decoded and then decrypted with org's AES cipher.
// Example encrypted input:
// "{AES/ECB/PKCS5Padding}2jX3V3dQ5xB9C9Zl9sqyo8pmkvVP10rkEVPVhmnLHw4=".
//
// On any failure the result is an empty string with a non-nil error. This
// holds even though the log messages say "considered as unencrypted".
func (c *KmsCipherManager) TryDecryptBase64(input string, org string) (output string, err error) {
	if !IsEncrypted(input) {
		return input, nil
	}

	text, mode, padding, parseErr := GetCiphertext(input)
	if parseErr != nil {
		log.Errorf("Get ciphertext of [%v] failed: [%v], considered as unencrypted!", input, parseErr)
		return "", parseErr
	}

	raw, decodeErr := base64.StdEncoding.DecodeString(text)
	if decodeErr != nil {
		log.Errorf("Decode base64 of [%v] failed: [%v], considered as unencrypted!", text, decodeErr)
		return "", decodeErr
	}

	block := c.getAesCipher(org)
	if block == nil {
		return "", fmt.Errorf("failed to get decryption key for org: %s", org)
	}

	plain, decryptErr := block.Decrypt(raw, mode, padding)
	if decryptErr != nil {
		log.Errorf("Decrypt of [%v] failed: [%v], considered as unencrypted!", raw, decryptErr)
		return "", decryptErr
	}
	return string(plain), nil
}
// EncryptBase64 encrypts input with org's AES cipher using the given mode and
// padding. It returns the base64-encoded ciphertext prefixed with
// "{AES/<mode>/<padding>}", for example
// "{AES/ECB/PKCS5Padding}2jX3V3dQ5xB9C9Zl9sqyo8pmkvVP10rkEVPVhmnLHw4=".
//
// If no cipher can be obtained for org, input is returned unchanged with a nil
// error. If encryption itself fails, the result is an empty string and the error.
func (c *KmsCipherManager) EncryptBase64(input string, org string, mode cipher.Mode, padding cipher.Padding) (output string, err error) {
	block := c.getAesCipher(org)
	if block == nil {
		// TODO: make sure this logic is expected
		// No key/cipher for this org: treat the org as unencrypted.
		return input, nil
	}

	sealed, encErr := block.Encrypt([]byte(input), mode, padding)
	if encErr != nil {
		return "", encErr
	}

	encoded := base64.StdEncoding.EncodeToString(sealed)
	return fmt.Sprintf("{%s/%s/%s}%s", EncryptAes, mode, padding, encoded), nil
}
// IsEncrypted reports whether input has the "{ALGORITHM/MODE/PADDING}ciphertext"
// shape matched by RegexpEncrypted. The algorithm name is not checked here;
// GetCiphertext does that.
func IsEncrypted(input string) (encrypted bool) {
	// MatchString avoids the string-to-[]byte copy that Match([]byte(input)) makes.
	return RegexpEncrypted.MatchString(input)
}
// GetCiphertext splits an encrypted value "{ALGORITHM/MODE/PADDING}ciphertext"
// into the ciphertext (still base64-encoded), the mode and the padding.
// ALGORITHM must equal EncryptAes, and MODE and PADDING must be non-empty.
// On error every other result is left at its zero value.
func GetCiphertext(input string) (ciphertext string, mode cipher.Mode, padding cipher.Padding, err error) {
	parts := strings.SplitN(input, "}", 2)
	if len(parts) != 2 {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	// TrimPrefix removes exactly one "{". TrimLeft treats "{" as a cutset and
	// would silently accept inputs such as "{{AES/...}".
	fields := strings.Split(strings.TrimPrefix(parts[0], "{"), "/")
	if len(fields) != 3 {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	// encryption algorithm
	if fields[0] != EncryptAes {
		err = fmt.Errorf("unsupported algorithm for GetCiphertext: %v", fields[0])
		return
	}
	// Reject an empty mode or padding (e.g. "{AES//}") here rather than
	// handing it on to the cipher.
	if fields[1] == "" || fields[2] == "" {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	ciphertext = parts[1]
	mode = cipher.Mode(fields[1])
	padding = cipher.Padding(fields[2])
	return
}
// KeyErrorResponse is an error body returned by the key server. It is decoded
// by parseErrorResponse from JSON or XML. retrieveKey compares Code against
// errorCodeNoKey.
// NOTE(review): there are only json tags, so encoding/xml presumably matches
// elements by Go field name ("Code", "Message"). Confirm that the server's XML
// error bodies use that casing.
type KeyErrorResponse struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
// parseErrorResponse reads and closes res.Body and decodes it into a
// KeyErrorResponse according to the response's Content-Type.
//
// JSON (typeJson) and XML (typeXml or application/xml) bodies are supported.
// Media-type parameters such as "; charset=UTF-8" are ignored. Any other or
// missing content type yields an error that carries the raw body. Whenever an
// error is returned, the *KeyErrorResponse is nil.
func parseErrorResponse(res *http.Response) (*KeyErrorResponse, error) {
	defer res.Body.Close()
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	contentType := res.Header.Get(headerContentType)
	// ParseMediaType lower-cases the type and strips parameters, which an
	// exact string comparison would reject. Its error is deliberately ignored:
	// an unparseable header yields "" and falls into the default case below.
	mediaType, _, _ := mime.ParseMediaType(contentType)

	ret := &KeyErrorResponse{}
	switch mediaType {
	case typeJson:
		err = json.Unmarshal(body, ret)
	case typeXml, "application/xml":
		err = xml.Unmarshal(body, ret)
	default:
		return nil, fmt.Errorf("unknown error (content type %q): %v", contentType, string(body))
	}
	if err != nil {
		return nil, err
	}
	return ret, nil
}