blob: dcd3eec2fea685af41c7a7c61f13b7dcf9040f4a [file]
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"encoding/base64"
"fmt"
"github.com/apid/apid-core/cipher"
"io/ioutil"
"net/http"
"regexp"
"strings"
"sync"
)
const (
	// RegEncrypted matches strings that begin with a "{X/Y/Z}" prefix, where
	// each of X, Y and Z is one or more ASCII letters or digits, followed by
	// at least one more (non-newline) character.
	RegEncrypted = `^\{[0-9A-Za-z]+/[0-9A-Za-z]+/[0-9A-Za-z]+\}.`

	// retrieveEncryptKeyPath is the URL requested by retrieveKey.
	// NOTE(review): empty in this file — confirm where the real endpoint is set.
	retrieveEncryptKeyPath = ""

	// EncryptAes is the algorithm name written into, and accepted from, the
	// "{ALGORITHM/MODE/PADDING}" prefix.
	EncryptAes = "AES"
)

// RegexpEncrypted is RegEncrypted compiled once at package initialization.
var RegexpEncrypted = regexp.MustCompile(RegEncrypted)
// KmsCipherManager caches, per organization, an encryption key and the AES
// cipher built from it. On the first request for an organization the key is
// fetched over HTTP (see retrieveKey) and both are cached.
//
// NOTE(review): the maps, the mutex and the client must all be non-nil before
// use (writing to a nil map or locking a nil *sync.RWMutex panics); no
// constructor is visible here — confirm one initializes every field.
type KmsCipherManager struct {
	// org-level key map {organization: key}
	key map[string][]byte
	// org-level AesCipher map {organization: AesCipher}
	aes map[string]*cipher.AesCipher
	// mutex guards reads and writes of key and aes.
	mutex *sync.RWMutex
	// client sends the key-retrieval requests made by retrieveKey.
	client *http.Client
}
// retrieveKey fetches the encryption key for org and returns it decoded from
// standard base64.
//
// The whole response body is read and base64-decoded. A non-200 status, a
// body read failure or an invalid base64 payload is reported as an error,
// and no key is returned in that case.
//
// NOTE(review): the request goes to retrieveEncryptKeyPath, and org is not
// sent with the request — org only appears in error messages. Confirm how
// the key service identifies the organization.
func (c *KmsCipherManager) retrieveKey(org string) (key []byte, err error) {
	req, err := http.NewRequest(http.MethodGet, retrieveEncryptKeyPath, nil)
	if err != nil {
		return nil, err
	}
	res, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every path, including non-200 responses, so the
	// underlying connection is released.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("retrieve encryption key failed for org [%v] with status: %v", org, res.Status)
	}

	key64, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("read encryption key response for org [%v]: %v", org, err)
	}

	key, err = base64.StdEncoding.DecodeString(string(key64))
	if err != nil {
		// Do not hand back a partially decoded key alongside the error.
		return nil, err
	}
	return key, nil
}
// getAesCipher returns the cached AES cipher for org, building and caching it
// on first use.
//
// On a cache miss the key is fetched with retrieveKey and turned into a
// cipher with cipher.CreateAesCipher. The key and the cipher are cached
// together, and only after both steps succeed, so a key that CreateAesCipher
// rejects is never left behind in c.key. Failures are logged and returned.
func (c *KmsCipherManager) getAesCipher(org string) (*cipher.AesCipher, error) {
	// Fast path: the cipher is already cached.
	c.mutex.RLock()
	a := c.aes[org]
	c.mutex.RUnlock()
	if a != nil {
		return a, nil
	}

	// Fetch and build without holding the lock, so lookups for other orgs are
	// not blocked by the network call. Concurrent misses for the same org may
	// each fetch a key; the last one to store wins.
	key, err := c.retrieveKey(org)
	if err != nil {
		log.Errorf("getAesCipher error for org [%v] when retrieveKey: %v", org, err)
		return nil, err
	}
	a, err = cipher.CreateAesCipher(key)
	if err != nil {
		log.Errorf("getAesCipher error for org [%v] when CreateAesCipher: %v", org, err)
		return nil, err
	}

	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.key[org] = key
	c.aes[org] = a
	return a, nil
}
// TryDecryptBase64 returns the plaintext of input when input is encrypted,
// and returns input itself, unchanged, when it is not.
//
// Whether input counts as encrypted is decided by IsEncrypted. An encrypted
// input is "{ALGORITHM/MODE/PADDING}" followed by the standard-base64
// ciphertext, e.g.
// "{AES/ECB/PKCS5Padding}2jX3V3dQ5xB9C9Zl9sqyo8pmkvVP10rkEVPVhmnLHw4=".
// Any other format is passed through untouched.
//
// For an encrypted input, the cipher for org comes from getAesCipher. If
// parsing the prefix, base64 decoding, obtaining the cipher, or decrypting
// fails, the result is an empty output and the failure as err.
//
// NOTE(review): several failure logs say "considered as unencrypted!", yet
// those paths return an empty output plus a non-nil error rather than the
// original input — confirm which contract callers rely on.
func (c *KmsCipherManager) TryDecryptBase64(input string, org string) (output string, err error) {
	if !IsEncrypted(input) {
		return input, nil
	}

	ciphertext, mode, padding, parseErr := GetCiphertext(input)
	if parseErr != nil {
		log.Errorf("Get ciphertext of [%v] failed: [%v], considered as unencrypted!", input, parseErr)
		return "", parseErr
	}

	raw, decodeErr := base64.StdEncoding.DecodeString(ciphertext)
	if decodeErr != nil {
		log.Errorf("Decode base64 of [%v] failed: [%v], considered as unencrypted!", ciphertext, decodeErr)
		return "", decodeErr
	}

	aesCipher, lookupErr := c.getAesCipher(org)
	if lookupErr != nil {
		// getAesCipher logs its own failures.
		return "", lookupErr
	}

	plain, decryptErr := aesCipher.Decrypt(raw, mode, padding)
	if decryptErr != nil {
		log.Errorf("Decrypt of [%v] failed: [%v], considered as unencrypted!", raw, decryptErr)
		return "", decryptErr
	}
	return string(plain), nil
}
// EncryptBase64 encrypts input with org's AES cipher using the given mode and
// padding.
//
// The result is "{AES/<mode>/<padding>}" followed by the standard-base64
// ciphertext, for example
// "{AES/ECB/PKCS5Padding}2jX3V3dQ5xB9C9Zl9sqyo8pmkvVP10rkEVPVhmnLHw4=".
// If the cipher cannot be obtained or encryption fails, output is "" and err
// holds the failure.
func (c *KmsCipherManager) EncryptBase64(input string, org string, mode cipher.Mode, padding cipher.Padding) (output string, err error) {
	aesCipher, err := c.getAesCipher(org)
	if err != nil {
		return "", err
	}
	sealed, err := aesCipher.Encrypt([]byte(input), mode, padding)
	if err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(sealed)
	return fmt.Sprintf("{%s/%s/%s}%s", EncryptAes, mode, padding, encoded), nil
}
// IsEncrypted reports whether input matches RegexpEncrypted. That means it
// starts with a "{X/Y/Z}" prefix, where each part is one or more ASCII letters
// or digits, followed by at least one more non-newline character.
//
// This is a purely syntactic check. It does not validate the algorithm, mode,
// padding or base64 payload; GetCiphertext and TryDecryptBase64 do that.
//
// TODO: make sure this regex has no false positive for all possible inputs
func IsEncrypted(input string) (encrypted bool) {
	// MatchString avoids copying input into a []byte just to run the match.
	return RegexpEncrypted.MatchString(input)
}
// GetCiphertext splits an encrypted value of the form
// "{ALGORITHM/MODE/PADDING}ciphertext" into its parts.
//
// ALGORITHM must equal EncryptAes case-insensitively. MODE and PADDING are
// upper-cased and returned as cipher.Mode and cipher.Padding; they are not
// validated here. The ciphertext is returned exactly as it follows the first
// "}" (still base64-encoded).
//
// An error is returned when input does not start with a single "{", has no
// closing "}", the prefix does not have exactly three "/"-separated parts,
// or the algorithm is not AES. On error all other results are zero values.
func GetCiphertext(input string) (ciphertext string, mode cipher.Mode, padding cipher.Padding, err error) {
	// Require exactly one leading "{"; TrimLeft would also accept a missing
	// brace or strip several of them.
	if !strings.HasPrefix(input, "{") {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	parts := strings.SplitN(strings.TrimPrefix(input, "{"), "}", 2)
	if len(parts) != 2 {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	spec := strings.Split(parts[0], "/")
	if len(spec) != 3 {
		err = fmt.Errorf("invalid input for GetCiphertext: %v", input)
		return
	}
	// encryption algorithm
	if strings.ToUpper(spec[0]) != EncryptAes {
		err = fmt.Errorf("unsupported algorithm for GetCiphertext: %v", spec[0])
		return
	}
	// Populate results only once every check has passed.
	ciphertext = parts[1]
	mode = cipher.Mode(strings.ToUpper(spec[1]))
	padding = cipher.Padding(strings.ToUpper(spec[2]))
	return
}