txAWS-0.2.3/0000775000175000017500000000000011741312025014206 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/README0000664000175000017500000000074611741311335015100 0ustar oubiwannoubiwann00000000000000Dependencies ------------ * Python * The twisted python package (python-twisted on Debian or similar systems) * The dateutil python package (python-dateutil on Debian or similar systems) * lxml (only when using txaws.wsdl) Things present here ------------------- * The txaws python package. (No installer at the moment) * bin/aws-status, a GUI status program for aws resources. License ------- txAWS is open source software, MIT License. See the LICENSE file for more details. txAWS-0.2.3/bin/0000775000175000017500000000000011741312025014756 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/bin/txaws-list-buckets0000775000175000017500000000176311741311335020473 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): print "\nBuckets:" for bucket in results: print "\t%s (created on %s)" % (bucket.name, bucket.creation_date) print "Total buckets: %s\n" % len(list(results)) return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.list_buckets() d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). 
sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-head-object0000775000175000017500000000213111741311335020215 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from pprint import pprint from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): pprint(results) return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." sys.exit(1) if options.object_name is None: print "Error Message: An object name is required." sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.head_object(options.bucket, options.object_name) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-delete-bucket0000775000175000017500000000166111741311335020574 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." 
sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.delete_bucket(options.bucket) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-create-bucket0000775000175000017500000000166111741311335020575 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.create_bucket(options.bucket) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). 
sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-delete-object0000775000175000017500000000207711741311335020567 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): print results return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." sys.exit(1) if options.object_name is None: print "Error Message: An object name is required." sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.delete_object(options.bucket, options.object_name) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). sys.exit(reactor.run()) txAWS-0.2.3/bin/aws-status0000775000175000017500000000036111741311335017022 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python # Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. 
import sys from txaws.client.gui.gtk import main sys.exit(main(sys.argv)) txAWS-0.2.3/bin/txaws-put-object0000775000175000017500000000252111741311335020127 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import os import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." sys.exit(1) filename = options.object_filename if filename: options.object_name = os.path.basename(filename) try: options.object_data = open(filename).read() except Exception, error: print error sys.exit(1) elif options.object_name is None: print "Error Message: An object name is required." sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.put_object( options.bucket, options.object_name, options.object_data, options.content_type) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). 
sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-get-object0000775000175000017500000000207411741311335020101 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(results): print results return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." sys.exit(1) if options.object_name is None: print "Error Message: An object name is required." sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.get_object(options.bucket, options.object_name) d.addCallback(printResults) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-get-bucket0000775000175000017500000000223411741311335020106 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python """ %prog [options] """ import sys from txaws.credentials import AWSCredentials from txaws.script import parse_options from txaws.service import AWSServiceRegion from txaws.reactor import reactor def printResults(listing, bucket): print "Contents of '%s' bucket:" % bucket for item in listing.contents: print "\t%s (last modified on %s)" % (item.key, item.modification_date) print "Total items: %s\n" % len(listing.contents) return 0 def printError(error): print error.value return 1 def finish(return_code): reactor.stop(exitStatus=return_code) options, args = parse_options(__doc__.strip()) if options.bucket is None: print "Error Message: A bucket name is required." 
sys.exit(1) creds = AWSCredentials(options.access_key, options.secret_key) region = AWSServiceRegion( creds=creds, region=options.region, s3_uri=options.url) client = region.get_s3_client() d = client.get_bucket(options.bucket) d.addCallback(printResults, options.bucket) d.addErrback(printError) d.addCallback(finish) # We use a custom reactor so that we can return the exit status from # reactor.run(). sys.exit(reactor.run()) txAWS-0.2.3/bin/txaws-discover0000775000175000017500000000046411741311335017675 0ustar oubiwannoubiwann00000000000000#!/usr/bin/env python # Copyright (C) 2010 Jamu Kakar # Licenced under the txaws licence available at /LICENSE in the txaws source. import os import sys if os.path.isdir("txaws"): sys.path.insert(0, ".") from txaws.client.discover.entry_point import main sys.exit(main(sys.argv)) txAWS-0.2.3/wsdl/0000775000175000017500000000000011741312025015157 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/wsdl/2009-11-30.ec2.wsdl0000664000175000017500000052547211741311335017575 0ustar oubiwannoubiwann00000000000000 txAWS-0.2.3/setup.py0000664000175000017500000000333011741311335015722 0ustar oubiwannoubiwann00000000000000from distutils.core import setup from glob import glob import os from txaws import version # If setuptools is present, use it to find_packages(), and also # declare our dependency on python-dateutil. extra_setup_args = {} try: import setuptools from setuptools import find_packages extra_setup_args['install_requires'] = ['python-dateutil<2.0', 'twisted'] except ImportError: def find_packages(): """ Compatibility wrapper. Taken from storm setup.py. 
""" packages = [] for directory, subdirectories, files in os.walk("txaws"): if '__init__.py' in files: packages.append(directory.replace(os.sep, '.')) return packages long_description = """ Twisted-based Asynchronous Libraries for Amazon Web Services and Eucalyptus private clouds This project's goal is to have a complete Twisted API representing the spectrum of Amazon's web services as well as support for Eucalyptus clouds. """ setup( name="txAWS", version=version.txaws, description="Async library for EC2 and Eucalyptus", author="txAWS Developers", author_email="txaws-discuss@lists.launchpad.net", url="https://launchpad.net/txaws", license="MIT", packages=find_packages(), scripts=glob("./bin/*"), long_description=long_description, classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: System Administrators", "Intended Audience :: Information Technology", "Programming Language :: Python", "Topic :: Database", "Topic :: Internet :: WWW/HTTP", "License :: OSI Approved :: MIT License", ], **extra_setup_args ) txAWS-0.2.3/PKG-INFO0000664000175000017500000000157711741312025015315 0ustar oubiwannoubiwann00000000000000Metadata-Version: 1.1 Name: txAWS Version: 0.2.3 Summary: Async library for EC2 and Eucalyptus Home-page: https://launchpad.net/txaws Author: txAWS Developers Author-email: txaws-discuss@lists.launchpad.net License: MIT Description: Twisted-based Asynchronous Libraries for Amazon Web Services and Eucalyptus private clouds This project's goal is to have a complete Twisted API representing the spectrum of Amazon's web services as well as support for Eucalyptus clouds. 
Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: Intended Audience :: System Administrators Classifier: Intended Audience :: Information Technology Classifier: Programming Language :: Python Classifier: Topic :: Database Classifier: Topic :: Internet :: WWW/HTTP Classifier: License :: OSI Approved :: MIT License txAWS-0.2.3/txaws/0000775000175000017500000000000011741312025015354 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/testing/0000775000175000017500000000000011741312025017031 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/testing/__init__.py0000664000175000017500000000000011741311335021133 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/testing/service.py0000664000175000017500000000275311741311335021055 0ustar oubiwannoubiwann00000000000000from txaws.credentials import AWSCredentials from txaws.service import AWSServiceEndpoint from txaws.testing.ec2 import FakeEC2Client class FakeAWSServiceRegion(object): key_material = "" def __init__(self, access_key="", secret_key="", uri="", ec2_client_factory=None, keypairs=None, security_groups=None, instances=None, volumes=None, snapshots=None, availability_zones=None): self.access_key = access_key self.secret_key = secret_key self.uri = uri self.ec2_client = None if not ec2_client_factory: ec2_client_factory = FakeEC2Client self.ec2_client_factory = ec2_client_factory self.keypairs = keypairs self.security_groups = security_groups self.instances = instances self.volumes = volumes self.snapshots = snapshots self.availability_zones = availability_zones def get_ec2_client(self, *args, **kwds): creds = AWSCredentials(access_key=self.access_key, secret_key=self.secret_key) endpoint = AWSServiceEndpoint(uri=self.uri) self.ec2_client = self.ec2_client_factory( creds, endpoint, instances=self.instances, keypairs=self.keypairs, volumes=self.volumes, key_material=self.key_material, security_groups=self.security_groups, 
snapshots=self.snapshots, availability_zones=self.availability_zones) return self.ec2_client txAWS-0.2.3/txaws/testing/ec2.py0000664000175000017500000001146211741311335020063 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Licenced under the txaws licence available at /LICENSE in the txaws source. from datetime import datetime from twisted.internet.defer import succeed, fail from twisted.python.failure import Failure from twisted.web.error import Error from txaws.ec2.model import Keypair, SecurityGroup class FakeEC2Client(object): def __init__(self, creds, endpoint, instances=None, keypairs=None, volumes=None, key_material="", security_groups=None, snapshots=None, addresses=None, availability_zones=None, query_factory=None, parser=None): self.creds = creds self.endpoint = endpoint self.query_factory = query_factory self.parser = parser self.instances = instances or [] self.keypairs = keypairs or [] self.keypairs_deleted = [] self.volumes = volumes or [] self.volumes_deleted = [] self.key_material = key_material self.security_groups = security_groups or [] self.security_groups_deleted = [] self.snapshots = snapshots or [] self.snapshots_deleted = [] self.addresses = addresses or [] self.availability_zones = availability_zones or [] def describe_instances(self, *instances): return succeed(self.instances) def run_instances(self, image_id, min_count, max_count, security_groups=None, key_name=None, instance_type=None, user_data=None, availability_zone=None, kernel_id=None, ramdisk_id=None): return succeed(self.instances) def terminate_instances(self, *instance_ids): result = [(instance.instance_id, instance.instance_state, u"shutting-down") for instance in self.instances] return succeed(result) def describe_keypairs(self): return succeed(self.keypairs) def create_keypair(self, name): keypair = Keypair(name, "fingerprint", self.key_material) return succeed(keypair) def delete_keypair(self, name): 
self.keypairs_deleted.append(name) return succeed(True) def describe_security_groups(self, names=None): return succeed(self.security_groups) def create_security_group(self, name, description): self.security_groups.append(SecurityGroup(name, description)) return succeed(True) def delete_security_group(self, name): self.security_groups_deleted.append(name) return succeed(True) def describe_volumes(self, *volume_ids): return succeed(self.volumes) def create_volume(self, availability_zone, size=None, snapshot_id=None): return succeed(self.volumes[0]) def attach_volume(self, volume_id, instance_id, device): return succeed({"status": u"attaching", "attach_time": datetime(2007, 6, 6, 11, 10, 00)}) def delete_volume(self, volume_id): self.volumes_deleted.append(volume_id) return succeed(True) def describe_snapshots(self, *snapshot_ids): return succeed(self.snapshots) def create_snapshot(self, volume_id): return succeed(self.snapshots[0]) def delete_snapshot(self, volume_id): self.snapshots_deleted.append(volume_id) return succeed(True) def authorize_group_permission(self, group_name, source_group_name, source_group_owner_id): return succeed(True) def revoke_group_permission(self, group_name, source_group_name, source_group_owner_id): return succeed(True) def authorize_ip_permission(self, group_name, protocol, from_port, to_port, cidr_ip): return succeed(True) def revoke_ip_permission(self, group_name, protocol, from_port, to_port, cidr_ip): return succeed(True) def describe_addresses(self, *addresses): return succeed(self.addresses) def allocate_address(self): return succeed(self.addresses[0][0]) def release_address(self, address): return succeed(True) def associate_address(self, instance_id, address): return succeed(True) def disassociate_address(self, address): return succeed(True) def describe_availability_zones(self, *names): return succeed(self.availability_zones) class FakePageGetter(object): def __init__(self, status, payload): self.status = status self.payload = 
payload def get_page(self, url, *args, **kwds): return succeed(self.payload) def get_page_with_exception(self, url, *args, **kwds): try: raise Error(self.status, "There's been an error", self.payload) except: failure = Failure() return fail(failure) txAWS-0.2.3/txaws/testing/base.py0000664000175000017500000000142711741311335020324 0ustar oubiwannoubiwann00000000000000import os from twisted.trial.unittest import TestCase class TXAWSTestCase(TestCase): """Support for isolation of txaws tests.""" def setUp(self): TestCase.setUp(self) self._stash_environ() def _stash_environ(self): self.orig_environ = dict(os.environ) self.addCleanup(self._restore_environ) if "AWS_ACCESS_KEY_ID" in os.environ: del os.environ["AWS_ACCESS_KEY_ID"] if "AWS_SECRET_ACCESS_KEY" in os.environ: del os.environ["AWS_SECRET_ACCESS_KEY"] if "AWS_ENDPOINT" in os.environ: del os.environ["AWS_ENDPOINT"] def _restore_environ(self): for key in set(os.environ) - set(self.orig_environ): del os.environ[key] os.environ.update(self.orig_environ) txAWS-0.2.3/txaws/testing/payload.py0000664000175000017500000007607611741311335021057 0ustar oubiwannoubiwann00000000000000from txaws import version sample_required_describe_instances_result = """\ 52b4c730-f29f-498d-94c1-91efb75994cc r-cf24b1a6 123456789012 default i-abcdef01 ami-12345678 16 running domU-12-31-39-03-15-11.compute-1.internal\ ec2-75-101-245-65.compute-1.amazonaws.com 10.0.0.1 75.101.245.65 c1.xlarge 2009-04-27T02:23:18.000Z us-east-1c """ % (version.ec2_api,) sample_describe_instances_result = """\ 52b4c730-f29f-498d-94c1-91efb75994cc r-cf24b1a6 123456789012 default i-abcdef01 ami-12345678 16 running domU-12-31-39-03-15-11.compute-1.internal\ ec2-75-101-245-65.compute-1.amazonaws.com 10.0.0.1 75.101.245.65 keyname 0 774F4FF8 c1.xlarge 2009-04-27T02:23:18.000Z us-east-1c aki-b51cf9dc ari-b31cf9da """ % (version.ec2_api,) sample_run_instances_result = """\ r-47a5402e 495219933132 default i-2ba64342 ami-60a54009 0 pending example-key-name 0 m1.small 
2007-08-07T11:51:50.000Z us-east-1b i-2bc64242 ami-60a54009 0 pending example-key-name 1 m1.small 2007-08-07T11:51:50.000Z us-east-1b i-2be64332 ami-60a54009 0 pending example-key-name 2 m1.small 2007-08-07T11:51:50.000Z us-east-1b """ % (version.ec2_api,) sample_terminate_instances_result = """\ i-1234 32 shutting-down 16 running i-5678 32 shutting-down 32 shutting-down """ % (version.ec2_api,) sample_describe_security_groups_with_openstack = """\ 7d4e4dbd-0a33-4d3a-864a-b5ce0f1c9cbf 22 tcp 0.0.0.0/0 22 WebServers UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM WebServers Web servers UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM """ % (version.ec2_api,) sample_describe_security_groups_result = """\ 52b4c730-f29f-498d-94c1-91efb75994cc UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM WebServers Web Servers tcp 80 80 0.0.0.0/0 """ % (version.ec2_api,) sample_describe_security_groups_multiple_result = """\ 52b4c730-f29f-498d-94c1-91efb75994cc UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM MessageServers Message Servers tcp 80 80 0.0.0.0/0 UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM WebServers Web Servers tcp 80 80 0.0.0.0/0 tcp 0 65535 group-user-id group-name1 udp 0 65535 group-user-id group-name1 icmp -1 -1 group-user-id group-name1 tcp 0 65535 group-user-id group-name2 udp 0 65535 group-user-id group-name2 icmp -1 -1 group-user-id group-name2 """ % (version.ec2_api,) sample_describe_security_groups_multiple_groups = """\ 481987ac-07e2-4f34-99b9-38bcce029ce9 170743011661 web/ssh Web and SSH access icmp -1 -1 170723411662 default 175723011368 test1 tcp 1 65535 170723411662 default 175723011368 test1 udp 1 65535 170723411662 default 175723011368 test1 tcp 22 22 0.0.0.0/0 tcp 80 80 0.0.0.0/0 """ % (version.ec2_api,) sample_create_security_group = """\ true """ % (version.ec2_api,) sample_duplicate_create_security_group_result = """\ InvalidGroup.Duplicate The security group 'group1' already exists. 
89c977b5-22da-4c68-9148-9e0ebce5f68e """ sample_invalid_create_security_group_result = """\ InvalidGroup.Reserved Specified group name is a reserved name. 89c977b5-22da-4c68-9148-9e0ebce5f68e """ sample_delete_security_group = """\ true """ % (version.ec2_api,) sample_delete_security_group_failure = """\ InvalidGroup.InUse Group groupID1:GroupReferredTo is used by groups: \ groupID2:UsingGroup 9a6df05f-9c27-47aa-81d8-6619689210cc """ sample_authorize_security_group = """\ true """ % (version.ec2_api,) sample_revoke_security_group = """\ true """ % (version.ec2_api,) sample_describe_volumes_result = """\ vol-4282672b 800 in-use 2008-05-07T11:51:50.000Z us-east-1a snap-12345678 vol-4282672b i-6058a509 800 /dev/sdh attached 2008-05-07T12:51:50.000Z """ % (version.ec2_api,) sample_describe_snapshots_result = """\ snap-78a54011 vol-4d826724 pending 2008-05-07T12:51:50.000Z 80%% """ % (version.ec2_api,) sample_create_volume_result = """\ vol-4d826724 800 creating 2008-05-07T11:51:50.000Z us-east-1a """ % (version.ec2_api,) sample_delete_volume_result = """\ true """ % (version.ec2_api,) sample_create_snapshot_result = """\ snap-78a54011 vol-4d826724 pending 2008-05-07T12:51:50.000Z """ % (version.ec2_api,) sample_delete_snapshot_result = """\ true """ % (version.ec2_api,) sample_attach_volume_result = """\ vol-4d826724 i-6058a509 /dev/sdh attaching 2008-05-07T11:51:50.000Z """ % (version.ec2_api,) sample_ec2_error_message = """\ Error.Code Message for Error.Code 0ef9fc37-6230-4d81-b2e6-1b36277d4247 """ sample_ec2_error_messages = """\ Error.Code1 Message for Error.Code1 Error.Code2 Message for Error.Code2 0ef9fc37-6230-4d81-b2e6-1b36277d4247 """ sample_single_describe_keypairs_result = """\ gsg-keypair 1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:\ ca:9f:f5:f1:6f """ % (version.ec2_api,) sample_multiple_describe_keypairs_result = """\ gsg-keypair-1 1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:\ ca:9f:f5:f1:6f gsg-keypair-2 1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:\ 
ca:9f:f5:f1:70 """ % (version.ec2_api,) sample_create_keypair_result = """\ example-key-name 1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:\ ca:9f:f5:f1:6f -----BEGIN RSA PRIVATE KEY----- MIIEoQIBAAKCAQBuLFg5ujHrtm1jnutSuoO8Xe56LlT+HM8v/xkaa39EstM3/aFxTHgElQiJLChp HungXQ29VTc8rc1bW0lkdi23OH5eqkMHGhvEwqa0HWASUMll4o3o/IX+0f2UcPoKCOVUR+jx71Sg 5AU52EQfanIn3ZQ8lFW7Edp5a3q4DhjGlUKToHVbicL5E+g45zfB95wIyywWZfeW/UUF3LpGZyq/ ebIUlq1qTbHkLbCC2r7RTn8vpQWp47BGVYGtGSBMpTRP5hnbzzuqj3itkiLHjU39S2sJCJ0TrJx5 i8BygR4s3mHKBj8l+ePQxG1kGbF6R4yg6sECmXn17MRQVXODNHZbAgMBAAECggEAY1tsiUsIwDl5 91CXirkYGuVfLyLflXenxfI50mDFms/mumTqloHO7tr0oriHDR5K7wMcY/YY5YkcXNo7mvUVD1pM ZNUJs7rw9gZRTrf7LylaJ58kOcyajw8TsC4e4LPbFaHwS1d6K8rXh64o6WgW4SrsB6ICmr1kGQI7 3wcfgt5ecIu4TZf0OE9IHjn+2eRlsrjBdeORi7KiUNC/pAG23I6MdDOFEQRcCSigCj+4/mciFUSA SWS4dMbrpb9FNSIcf9dcLxVM7/6KxgJNfZc9XWzUw77Jg8x92Zd0fVhHOux5IZC+UvSKWB4dyfcI tE8C3p9bbU9VGyY5vLCAiIb4qQKBgQDLiO24GXrIkswF32YtBBMuVgLGCwU9h9HlO9mKAc2m8Cm1 jUE5IpzRjTedc9I2qiIMUTwtgnw42auSCzbUeYMURPtDqyQ7p6AjMujp9EPemcSVOK9vXYL0Ptco xW9MC0dtV6iPkCN7gOqiZXPRKaFbWADp16p8UAIvS/a5XXk5jwKBgQCKkpHi2EISh1uRkhxljyWC iDCiK6JBRsMvpLbc0v5dKwP5alo1fmdR5PJaV2qvZSj5CYNpMAy1/EDNTY5OSIJU+0KFmQbyhsbm rdLNLDL4+TcnT7c62/aH01ohYaf/VCbRhtLlBfqGoQc7+sAc8vmKkesnF7CqCEKDyF/dhrxYdQKB gC0iZzzNAapayz1+JcVTwwEid6j9JqNXbBc+Z2YwMi+T0Fv/P/hwkX/ypeOXnIUcw0Ih/YtGBVAC DQbsz7LcY1HqXiHKYNWNvXgwwO+oiChjxvEkSdsTTIfnK4VSCvU9BxDbQHjdiNDJbL6oar92UN7V rBYvChJZF7LvUH4YmVpHAoGAbZ2X7XvoeEO+uZ58/BGKOIGHByHBDiXtzMhdJr15HTYjxK7OgTZm gK+8zp4L9IbvLGDMJO8vft32XPEWuvI8twCzFH+CsWLQADZMZKSsBasOZ/h1FwhdMgCMcY+Qlzd4 JZKjTSu3i7vhvx6RzdSedXEMNTZWN4qlIx3kR5aHcukCgYA9T+Zrvm1F0seQPbLknn7EqhXIjBaT P8TTvW/6bdPi23ExzxZn7KOdrfclYRph1LHMpAONv/x2xALIf91UB+v5ohy1oDoasL0gij1houRe 2ERKKdwz0ZL9SWq6VTdhr/5G994CK72fy5WhyERbDjUIdHaK3M849JJuf8cSrvSb4g== -----END RSA PRIVATE KEY----- """ % (version.ec2_api,) sample_delete_keypair_true_result = """\ true """ % (version.ec2_api,) sample_delete_keypair_false_result = """\ false """ % (version.ec2_api,) 
sample_delete_keypair_no_result = """\ """ % (version.ec2_api,) sample_duplicate_keypair_result = """\ InvalidKeyPair.Duplicate The key pair 'key1' already exists. 89c977b5-22da-4c68-9148-9e0ebce5f68e """ sample_import_keypair_result = """\ example-key-name 1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:\ ca:9f:f5:f1:6f """ % (version.ec2_api,) sample_allocate_address_result = """\ 67.202.55.255 """ % (version.ec2_api,) sample_release_address_result = """\ true """ % (version.ec2_api,) sample_associate_address_result = """\ true """ % (version.ec2_api,) sample_disassociate_address_result = """\ true """ % (version.ec2_api,) sample_describe_addresses_result = """\ i-28a64341 67.202.55.255 67.202.55.233 """ % (version.ec2_api,) sample_describe_availability_zones_single_result = """\ us-east-1a available """ % (version.ec2_api,) sample_describe_availability_zones_multiple_results = """\ us-east-1a available us-east-1b available us-east-1c available """ % (version.ec2_api,) sample_invalid_client_token_result = """\ InvalidClientTokenId The AWS Access Key Id you provided does not exist in our\ records. 47bfd77d-78d6-446d-be0d-f7621795dded """ sample_restricted_resource_result = """\ AuthFailure Unauthorized attempt to access restricted resource a99e832e-e6e0-416a-9a35-81798ea521b4 """ sample_server_internal_error_result = """\ InternalError We encountered an internal error. Please try again. 
A2A7E5395E27DFBB f691zulHNsUqonsZkjhILnvWwD3ZnmOM4ObM1wXTc6xuS3GzPmjArp8QC/sGsn6K\ """ sample_list_buckets_result = """\ bcaf1ffd86f41caff1a493dc2ad8c2c281e37522a640e161ca5fb16fd081034f webfile quotes 2006-02-03T16:45:09.000Z samples 2006-02-03T16:41:58.000Z """ % (version.s3_api,) sample_get_bucket_result = """\ mybucket N Ned 40 false Nelson 2006-01-01T12:00:00.000Z "828ef3fdfa96f00ad9f27c383fc9ac7f" 5 STANDARD bcaf1ffd86f41caff1a493dc2ad8c2c281e37522a640e161ca5fb16fd081034f webfile Neo 2006-01-01T12:00:00.000Z "828ef3fdfa96f00ad9f27c383fc9ac7f" 4 STANDARD bcaf1ffd86f41caff1a493dc2ad8c2c281e37522a640e161ca5fb16fd081034f webfile """ % (version.s3_api,) sample_get_bucket_location_result = """\ EU\ """ sample_request_payment = """\ Requester """ sample_s3_signature_mismatch = """\ SignatureDoesNotMatch The request signature we calculated does not match the signature\ you provided. Check your key and signing method. 47 45 54 0a 31 42 32 4d 32 59 38 41 73 67 54 70 67 41 6d\ 59 37 50 68 43 66 67 3d 3d 0a 0a 54 68 75 2c 20 30 35 20 4e 6f 76 20 32 30\ 30 39 20 32 31 3a 33 33 3a 32 39 20 47 4d 54 0a 2f AB9216C8640751B2 sAPBpmFdsOsgUUwtSLsiT6KIwP1mPbmrYY0xUoahzJE263qmABkTaqzGhHddgOq5\ ltowhdrbjaQ8dQc9VS5MxzJfsPJZi0BZHEzJC3r9pzU= GET\n1B2M2Y8AsgTpgAmY7PhCfg==\n\nThu, 05 Nov 2009 21:33:29\ GMT\n/ SOMEKEYID """ sample_s3_invalid_access_key_result = """\ InvalidAccessKeyId The AWS Access Key Id you provided does not exist in our records.\ 0223AD81A94821CE HAw5g9P1VkN8ztgLKFTK20CY5LmCfTwXcSths1O7UQV6NuJx2P4tmFnpuOsziwOE\ SOMEKEYID """ sample_access_control_policy_result = """\ 8a6925ce4adf588a4f21c32aa37900beef baz@example.net 8a6925ce4adf588a4f21c32aa379004fef foo@example.net FULL_CONTROL 8a6925ce4adf588a4f21c32aa37900feed bar@example.net READ """ sample_s3_get_bucket_lifecycle_result = """\ 30-day-log-deletion-rule logs Enabled 30 """ sample_s3_get_bucket_lifecycle_multiple_rules_result = """\ 30-day-log-deletion-rule logs Enabled 30 another-id another-logs Disabled 37 """ 
sample_s3_get_bucket_website_result = """\ index.html 404.html """ sample_s3_get_bucket_website_no_error_result = """\ index.html """ sample_s3_get_bucket_notification_result = """\ """ sample_s3_get_bucket_notification_with_topic_result = """\ arn:aws:sns:us-east-1:123456789012:myTopic s3:ReducedRedundancyLostObject """ sample_s3_get_bucket_versioning_result = """\ """ sample_s3_get_bucket_versioning_enabled_result = """\ Enabled """ sample_s3_get_bucket_versioning_mfa_disabled_result = """\ Enabled Disabled """ txAWS-0.2.3/txaws/exception.py0000664000175000017500000000731011741311335017730 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from twisted.web.error import Error from txaws.util import XML class AWSError(Error): """ A base class for txAWS errors. """ def __init__(self, xml_bytes, status, message=None, response=None): super(AWSError, self).__init__(status, message, response) if not xml_bytes: raise ValueError("XML cannot be empty.") self.original = xml_bytes self.errors = [] self.request_id = "" self.host_id = "" self.parse() def __str__(self): return self._get_error_message_string() def __repr__(self): return "<%s object with %s>" % ( self.__class__.__name__, self._get_error_code_string()) def _set_request_id(self, tree): request_id_node = tree.find(".//RequestID") if hasattr(request_id_node, "text"): text = request_id_node.text if text: self.request_id = text def _set_host_id(self, tree): host_id = tree.find(".//HostID") if hasattr(host_id, "text"): text = host_id.text if text: self.host_id = text def _get_error_code_string(self): count = len(self.errors) error_code = self.get_error_codes() if count > 1: return "Error count: %s" % error_code else: return "Error code: %s" % error_code def _get_error_message_string(self): count = len(self.errors) error_message = self.get_error_messages() if count > 1: return "%s." 
% error_message else: return "Error Message: %s" % error_message def _node_to_dict(self, node): data = {} for child in node: if child.tag and child.text: data[child.tag] = child.text return data def _check_for_html(self, tree): if tree.tag == "html": message = "Could not parse HTML in the response." raise AWSResponseParseError(message) def _set_400_error(self, tree): """ This method needs to be implemented by subclasses. """ def _set_500_error(self, tree): self._set_request_id(tree) self._set_host_id(tree) data = self._node_to_dict(tree) if data: self.errors.append(data) def parse(self, xml_bytes=""): if not xml_bytes: xml_bytes = self.original self.original = xml_bytes tree = XML(xml_bytes.strip()) self._check_for_html(tree) self._set_request_id(tree) if self.status: status = int(self.status) else: status = 400 if status >= 500: self._set_500_error(tree) else: self._set_400_error(tree) def has_error(self, errorString): for error in self.errors: if errorString in error.values(): return True return False def get_error_codes(self): count = len(self.errors) if count > 1: return count elif count == 0: return else: return self.errors[0]["Code"] def get_error_messages(self): count = len(self.errors) if count > 1: return "Multiple EC2 Errors" elif count == 0: return "Empty error list" else: return self.errors[0]["Message"] class AWSResponseParseError(Exception): """ txAWS was unable to parse the server response. """ class CertsNotFoundError(Exception): """ txAWS was not able to find any SSL certificates. """ txAWS-0.2.3/txaws/credentials.py0000664000175000017500000000312111741311335020223 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. 
"""Credentials for accessing AWS services.""" import os from txaws.util import hmac_sha256, hmac_sha1 __all__ = ["AWSCredentials"] ENV_ACCESS_KEY = "AWS_ACCESS_KEY_ID" ENV_SECRET_KEY = "AWS_SECRET_ACCESS_KEY" class AWSCredentials(object): """Create an AWSCredentials object. @param access_key: The access key to use. If None the environment variable AWS_ACCESS_KEY_ID is consulted. @param secret_key: The secret key to use. If None the environment variable AWS_SECRET_ACCESS_KEY is consulted. """ def __init__(self, access_key="", secret_key=""): self.access_key = access_key self.secret_key = secret_key # perform checks for access key if not self.access_key: self.access_key = os.environ.get(ENV_ACCESS_KEY) if not self.access_key: raise ValueError("Could not find %s" % ENV_ACCESS_KEY) # perform checks for secret key if not self.secret_key: self.secret_key = os.environ.get(ENV_SECRET_KEY) if not self.secret_key: raise ValueError("Could not find %s" % ENV_SECRET_KEY) def sign(self, bytes, hash_type="sha256"): """Sign some bytes.""" if hash_type == "sha256": return hmac_sha256(self.secret_key, bytes) elif hash_type == "sha1": return hmac_sha1(self.secret_key, bytes) else: raise RuntimeError("Unsupported hash type: '%s'" % hash_type) txAWS-0.2.3/txaws/__init__.py0000664000175000017500000000000011741311335017456 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/util.py0000664000175000017500000000520411741311335016707 0ustar oubiwannoubiwann00000000000000"""Generally useful utilities for AWS web services not specific to a service. New things in this module should be of relevance to more than one of Amazon's services. """ from base64 import b64encode from hashlib import sha1, md5, sha256 import hmac from urlparse import urlparse, urlunparse import time # Import XMLTreeBuilder from somewhere; here in one place to prevent # duplication. 
try: from xml.etree.ElementTree import XMLTreeBuilder except ImportError: from elementtree.ElementTree import XMLTreeBuilder __all__ = ["hmac_sha1", "hmac_sha256", "iso8601time", "calculate_md5", "XML"] def calculate_md5(data): digest = md5(data).digest() return b64encode(digest) def hmac_sha1(secret, data): digest = hmac.new(secret, data, sha1).digest() return b64encode(digest) def hmac_sha256(secret, data): digest = hmac.new(secret, data, sha256).digest() return b64encode(digest) def iso8601time(time_tuple): """Format time_tuple as a ISO8601 time string. :param time_tuple: Either None, to use the current time, or a tuple tuple. """ if time_tuple: return time.strftime("%Y-%m-%dT%H:%M:%SZ", time_tuple) else: return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) class NamespaceFixXmlTreeBuilder(XMLTreeBuilder): def _fixname(self, key): if "}" in key: key = key.split("}", 1)[1] return key def XML(text): parser = NamespaceFixXmlTreeBuilder() parser.feed(text) return parser.close() def parse(url, defaultPort=True): """ Split the given URL into the scheme, host, port, and path. @type url: C{str} @param url: An URL to parse. @type defaultPort: C{bool} @param defaultPort: Whether to return the default port associated with the scheme in the given url, when the url doesn't specify one. @return: A four-tuple of the scheme, host, port, and path of the URL. All of these are C{str} instances except for port, which is an C{int}. 
""" url = url.strip() parsed = urlparse(url) scheme = parsed[0] path = urlunparse(("", "") + parsed[2:]) host = parsed[1] if ":" in host: host, port = host.split(":") try: port = int(port) except ValueError: # A non-numeric port was given, it will be replaced with # an appropriate default value if defaultPort is True port = None else: port = None if port is None and defaultPort: if scheme == "https": port = 443 else: port = 80 if path == "": path = "/" return (str(scheme), str(host), port, str(path)) txAWS-0.2.3/txaws/version.py0000664000175000017500000000007511741311335017420 0ustar oubiwannoubiwann00000000000000txaws = "0.2.3" ec2_api = "2009-11-30" s3_api = "2006-03-01" txAWS-0.2.3/txaws/s3/0000775000175000017500000000000011741312025015701 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/s3/exception.py0000664000175000017500000000130211741311335020250 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from txaws.exception import AWSError class S3Error(AWSError): """ A error class providing custom methods on S3 errors. """ def _set_400_error(self, tree): if tree.tag.lower() == "error": data = self._node_to_dict(tree) if data: self.errors.append(data) def get_error_code(self, *args, **kwargs): return super(S3Error, self).get_error_codes(*args, **kwargs) def get_error_message(self, *args, **kwargs): return super(S3Error, self).get_error_messages(*args, **kwargs) txAWS-0.2.3/txaws/s3/__init__.py0000664000175000017500000000000011741311335020003 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/s3/client.py0000664000175000017500000004761011741311335017544 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2008 Tristan Seligmann # Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Copyright (C) 2012 New Dream Network (DreamHost) # Licenced under the txaws licence available at /LICENSE in the txaws source. 
""" Client wrapper for Amazon's Simple Storage Service. API stability: unstable. Various API-incompatible changes are planned in order to expose missing functionality in this wrapper. """ import mimetypes from twisted.web.http import datetimeToString from dateutil.parser import parse as parseTime from txaws.client.base import BaseClient, BaseQuery, error_wrapper from txaws.s3.acls import AccessControlPolicy from txaws.s3.model import ( Bucket, BucketItem, BucketListing, ItemOwner, LifecycleConfiguration, LifecycleConfigurationRule, NotificationConfiguration, RequestPayment, VersioningConfiguration, WebsiteConfiguration) from txaws.s3.exception import S3Error from txaws.service import AWSServiceEndpoint, S3_ENDPOINT from txaws.util import XML, calculate_md5 def s3_error_wrapper(error): error_wrapper(error, S3Error) class URLContext(object): """ The hosts and the paths that form an S3 endpoint change depending upon the context in which they are called. While S3 supports bucket names in the host name, we use the convention of providing it in the path so that using IP addresses and alternative implementations of S3 actually works (e.g. Walrus). 
""" def __init__(self, service_endpoint, bucket="", object_name=""): self.endpoint = service_endpoint self.bucket = bucket self.object_name = object_name def get_host(self): return self.endpoint.get_host() def get_path(self): path = "/" if self.bucket is not None: path += self.bucket if self.bucket is not None and self.object_name: if not self.object_name.startswith("/"): path += "/" path += self.object_name elif self.bucket is not None and not path.endswith("/"): path += "/" return path def get_url(self): if self.endpoint.port is not None: return "%s://%s:%d%s" % ( self.endpoint.scheme, self.get_host(), self.endpoint.port, self.get_path()) else: return "%s://%s%s" % ( self.endpoint.scheme, self.get_host(), self.get_path()) class S3Client(BaseClient): """A client for S3.""" def __init__(self, creds=None, endpoint=None, query_factory=None): if query_factory is None: query_factory = Query super(S3Client, self).__init__(creds, endpoint, query_factory) def list_buckets(self): """ List all buckets. Returns a list of all the buckets owned by the authenticated sender of the request. """ query = self.query_factory( action="GET", creds=self.creds, endpoint=self.endpoint) d = query.submit() return d.addCallback(self._parse_list_buckets) def _parse_list_buckets(self, xml_bytes): """ Parse XML bucket list response. """ root = XML(xml_bytes) buckets = [] for bucket_data in root.find("Buckets"): name = bucket_data.findtext("Name") date_text = bucket_data.findtext("CreationDate") date_time = parseTime(date_text) bucket = Bucket(name, date_time) buckets.append(bucket) return buckets def create_bucket(self, bucket): """ Create a new bucket. """ query = self.query_factory( action="PUT", creds=self.creds, endpoint=self.endpoint, bucket=bucket) return query.submit() def delete_bucket(self, bucket): """ Delete a bucket. The bucket must be empty before it can be deleted. 
""" query = self.query_factory( action="DELETE", creds=self.creds, endpoint=self.endpoint, bucket=bucket) return query.submit() def get_bucket(self, bucket): """ Get a list of all the objects in a bucket. """ query = self.query_factory( action="GET", creds=self.creds, endpoint=self.endpoint, bucket=bucket) d = query.submit() return d.addCallback(self._parse_get_bucket) def _parse_get_bucket(self, xml_bytes): root = XML(xml_bytes) name = root.findtext("Name") prefix = root.findtext("Prefix") marker = root.findtext("Marker") max_keys = root.findtext("MaxKeys") is_truncated = root.findtext("IsTruncated") contents = [] for content_data in root.findall("Contents"): key = content_data.findtext("Key") date_text = content_data.findtext("LastModified") modification_date = parseTime(date_text) etag = content_data.findtext("ETag") size = content_data.findtext("Size") storage_class = content_data.findtext("StorageClass") owner_id = content_data.findtext("Owner/ID") owner_display_name = content_data.findtext("Owner/DisplayName") owner = ItemOwner(owner_id, owner_display_name) content_item = BucketItem(key, modification_date, etag, size, storage_class, owner) contents.append(content_item) common_prefixes = [] for prefix_data in root.findall("CommonPrefixes"): common_prefixes.append(prefix_data.text) return BucketListing(name, prefix, marker, max_keys, is_truncated, contents, common_prefixes) def get_bucket_location(self, bucket): """ Get the location (region) of a bucket. @param bucket: The name of the bucket. @return: A C{Deferred} that will fire with the bucket's region. 
""" query = self.query_factory(action="GET", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name="?location") d = query.submit() return d.addCallback(self._parse_bucket_location) def _parse_bucket_location(self, xml_bytes): """Parse a C{LocationConstraint} XML document.""" root = XML(xml_bytes) return root.text or "" def get_bucket_lifecycle(self, bucket): """ Get the lifecycle configuration of a bucket. @param bucket: The name of the bucket. @return: A C{Deferred} that will fire with the bucket's lifecycle configuration. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?lifecycle') return query.submit().addCallback(self._parse_lifecycle_config) def _parse_lifecycle_config(self, xml_bytes): """Parse a C{LifecycleConfiguration} XML document.""" root = XML(xml_bytes) rules = [] for content_data in root.findall("Rule"): id = content_data.findtext("ID") prefix = content_data.findtext("Prefix") status = content_data.findtext("Status") expiration = int(content_data.findtext("Expiration/Days")) rules.append( LifecycleConfigurationRule(id, prefix, status, expiration)) return LifecycleConfiguration(rules) def get_bucket_website_config(self, bucket): """ Get the website configuration of a bucket. @param bucket: The name of the bucket. @return: A C{Deferred} that will fire with the bucket's website configuration. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?website') return query.submit().addCallback(self._parse_website_config) def _parse_website_config(self, xml_bytes): """Parse a C{WebsiteConfiguration} XML document.""" root = XML(xml_bytes) index_suffix = root.findtext("IndexDocument/Suffix") error_key = root.findtext("ErrorDocument/Key") return WebsiteConfiguration(index_suffix, error_key) def get_bucket_notification_config(self, bucket): """ Get the notification configuration of a bucket. 
@param bucket: The name of the bucket. @return: A C{Deferred} that will request the bucket's notification configuration. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?notification') return query.submit().addCallback(self._parse_notification_config) def _parse_notification_config(self, xml_bytes): """Parse a C{NotificationConfiguration} XML document.""" root = XML(xml_bytes) topic = root.findtext("TopicConfiguration/Topic") event = root.findtext("TopicConfiguration/Event") return NotificationConfiguration(topic, event) def get_bucket_versioning_config(self, bucket): """ Get the versioning configuration of a bucket. @param bucket: The name of the bucket. @return: A C{Deferred} that will request the bucket's versioning configuration. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?versioning') return query.submit().addCallback(self._parse_versioning_config) def _parse_versioning_config(self, xml_bytes): """Parse a C{VersioningConfiguration} XML document.""" root = XML(xml_bytes) mfa_delete = root.findtext("MfaDelete") status = root.findtext("Status") return VersioningConfiguration(mfa_delete=mfa_delete, status=status) def get_bucket_acl(self, bucket): """ Get the access control policy for a bucket. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?acl') return query.submit().addCallback(self._parse_acl) def put_bucket_acl(self, bucket, access_control_policy): """ Set access control policy on a bucket. """ data = access_control_policy.to_xml() query = self.query_factory( action='PUT', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='?acl', data=data) return query.submit().addCallback(self._parse_acl) def _parse_acl(self, xml_bytes): """ Parse an C{AccessControlPolicy} XML document and convert it into an L{AccessControlPolicy} instance. 
""" return AccessControlPolicy.from_xml(xml_bytes) def put_object(self, bucket, object_name, data, content_type=None, metadata={}, amz_headers={}): """ Put an object in a bucket. An existing object with the same name will be replaced. @param bucket: The name of the bucket. @param object: The name of the object. @param data: The data to write. @param content_type: The type of data being written. @param metadata: A C{dict} used to build C{x-amz-meta-*} headers. @param amz_headers: A C{dict} used to build C{x-amz-*} headers. @return: A C{Deferred} that will fire with the result of request. """ query = self.query_factory( action="PUT", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata, amz_headers=amz_headers) return query.submit() def copy_object(self, source_bucket, source_object_name, dest_bucket=None, dest_object_name=None, metadata={}, amz_headers={}): """ Copy an object stored in S3 from a source bucket to a destination bucket. @param source_bucket: The S3 bucket to copy the object from. @param source_object_name: The name of the object to copy. @param dest_bucket: Optionally, the S3 bucket to copy the object to. Defaults to C{source_bucket}. @param dest_object_name: Optionally, the name of the new object. Defaults to C{source_object_name}. @param metadata: A C{dict} used to build C{x-amz-meta-*} headers. @param amz_headers: A C{dict} used to build C{x-amz-*} headers. @return: A C{Deferred} that will fire with the result of request. 
""" dest_bucket = dest_bucket or source_bucket dest_object_name = dest_object_name or source_object_name amz_headers["copy-source"] = "/%s/%s" % (source_bucket, source_object_name) query = self.query_factory( action="PUT", creds=self.creds, endpoint=self.endpoint, bucket=dest_bucket, object_name=dest_object_name, metadata=metadata, amz_headers=amz_headers) return query.submit() def get_object(self, bucket, object_name): """ Get an object from a bucket. """ query = self.query_factory( action="GET", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name=object_name) return query.submit() def head_object(self, bucket, object_name): """ Retrieve object metadata only. """ query = self.query_factory( action="HEAD", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name=object_name) d = query.submit() return d.addCallback(query.get_response_headers) def delete_object(self, bucket, object_name): """ Delete an object from a bucket. Once deleted, there is no method to restore or undelete an object. """ query = self.query_factory( action="DELETE", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name=object_name) return query.submit() def put_object_acl(self, bucket, object_name, access_control_policy): """ Set access control policy on an object. """ data = access_control_policy.to_xml() query = self.query_factory( action='PUT', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='%s?acl' % object_name, data=data) return query.submit().addCallback(self._parse_acl) def get_object_acl(self, bucket, object_name): """ Get the access control policy for an object. """ query = self.query_factory( action='GET', creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name='%s?acl' % object_name) return query.submit().addCallback(self._parse_acl) def put_request_payment(self, bucket, payer): """ Set request payment configuration on bucket to payer. @param bucket: The name of the bucket. 
@param payer: The name of the payer. @return: A C{Deferred} that will fire with the result of the request. """ data = RequestPayment(payer).to_xml() query = self.query_factory( action="PUT", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name="?requestPayment", data=data) return query.submit() def get_request_payment(self, bucket): """ Get the request payment configuration on a bucket. @param bucket: The name of the bucket. @return: A C{Deferred} that will fire with the name of the payer. """ query = self.query_factory( action="GET", creds=self.creds, endpoint=self.endpoint, bucket=bucket, object_name="?requestPayment") return query.submit().addCallback(self._parse_get_request_payment) def _parse_get_request_payment(self, xml_bytes): """ Parse a C{RequestPaymentConfiguration} XML document and extract the payer. """ return RequestPayment.from_xml(xml_bytes).payer class Query(BaseQuery): """A query for submission to the S3 service.""" def __init__(self, bucket=None, object_name=None, data="", content_type=None, metadata={}, amz_headers={}, *args, **kwargs): super(Query, self).__init__(*args, **kwargs) self.bucket = bucket self.object_name = object_name self.data = data self.content_type = content_type self.metadata = metadata self.amz_headers = amz_headers self.date = datetimeToString() if not self.endpoint or not self.endpoint.host: self.endpoint = AWSServiceEndpoint(S3_ENDPOINT) self.endpoint.set_method(self.action) def set_content_type(self): """ Set the content type based on the file extension used in the object name. """ if self.object_name and not self.content_type: # XXX nothing is currently done with the encoding... we may # need to in the future self.content_type, encoding = mimetypes.guess_type( self.object_name, strict=False) def get_headers(self): """ Build the list of headers needed in order to perform S3 operations. 
""" headers = {"Content-Length": len(self.data), "Content-MD5": calculate_md5(self.data), "Date": self.date} for key, value in self.metadata.iteritems(): headers["x-amz-meta-" + key] = value for key, value in self.amz_headers.iteritems(): headers["x-amz-" + key] = value # Before we check if the content type is set, let's see if we can set # it by guessing the the mimetype. self.set_content_type() if self.content_type is not None: headers["Content-Type"] = self.content_type if self.creds is not None: signature = self.sign(headers) headers["Authorization"] = "AWS %s:%s" % ( self.creds.access_key, signature) return headers def get_canonicalized_amz_headers(self, headers): """ Get the headers defined by Amazon S3. """ headers = [ (name.lower(), value) for name, value in headers.iteritems() if name.lower().startswith("x-amz-")] headers.sort() # XXX missing spec implementation: # 1) txAWS doesn't currently combine headers with the same name # 2) txAWS doesn't currently unfold long headers return "".join("%s:%s\n" % (name, value) for name, value in headers) def get_canonicalized_resource(self): """ Get an S3 resource path. """ path = "/" if self.bucket is not None: path += self.bucket if self.bucket is not None and self.object_name: if not self.object_name.startswith("/"): path += "/" path += self.object_name elif self.bucket is not None and not path.endswith("/"): path += "/" return path def sign(self, headers): """Sign this query using its built in credentials.""" text = (self.action + "\n" + headers.get("Content-MD5", "") + "\n" + headers.get("Content-Type", "") + "\n" + headers.get("Date", "") + "\n" + self.get_canonicalized_amz_headers(headers) + self.get_canonicalized_resource()) return self.creds.sign(text, hash_type="sha1") def submit(self, url_context=None): """Submit this query. 
@return: A deferred from get_page """ if not url_context: url_context = URLContext( self.endpoint, self.bucket, self.object_name) d = self.get_page( url_context.get_url(), method=self.action, postdata=self.data, headers=self.get_headers()) return d.addErrback(s3_error_wrapper) txAWS-0.2.3/txaws/s3/acls.py0000664000175000017500000001050411741311335017200 0ustar oubiwannoubiwann00000000000000from txaws.util import XML PERMISSIONS = ("FULL_CONTROL", "WRITE", "WRITE_ACP", "READ", "READ_ACP") class XMLMixin(object): def to_xml(self): return "".join(self._to_xml()) class AccessControlPolicy(XMLMixin): def __init__(self, owner=None, access_control_list=()): self.owner = owner self.access_control_list = access_control_list def _to_xml(self, buffer=None): if buffer is None: buffer = [] buffer.append("\n") if self.owner: self.owner._to_xml(buffer=buffer, indent=1) buffer.append(" \n") for grant in self.access_control_list: grant._to_xml(buffer=buffer, indent=2) buffer.append(" \n" "") return buffer @classmethod def from_xml(cls, xml_bytes): root = XML(xml_bytes) owner_node = root.find("Owner") owner = Owner(owner_node.findtext("ID"), owner_node.findtext("DisplayName")) acl_node = root.find("AccessControlList") acl = [] for grant_node in acl_node.findall("Grant"): grantee_node = grant_node.find("Grantee") grantee = Grantee(grantee_node.findtext("ID"), grantee_node.findtext("DisplayName")) permission = grant_node.findtext("Permission") acl.append(Grant(grantee, permission)) return cls(owner, acl) class Grant(XMLMixin): def __init__(self, grantee, permission=None): self.grantee = grantee self.permission = permission def _set_permission(self, perm): if perm not in PERMISSIONS: raise ValueError("Invalid permission '%s'. 
Must be one of %s" % (perm, ",".join(PERMISSIONS))) self._permission = perm def _get_permission(self): return self._permission permission = property(_get_permission, _set_permission) def _to_xml(self, buffer=None, indent=0): if buffer is None: buffer = [] ws = " " * (indent * 2) buffer.append(ws + "\n") if self.grantee: self.grantee._to_xml(buffer, indent + 1) if self.permission: buffer.append("%s %s\n" % ( ws, self.permission)) buffer.append(ws + "\n") return buffer class Owner(XMLMixin): def __init__(self, id, display_name): self.id = id self.display_name = display_name def _to_xml(self, buffer=None, indent=0): if buffer is None: buffer = [] ws = " " * (indent * 2) buffer.append("%s\n" "%s %s\n" "%s %s\n" "%s\n" % (ws, ws, self.id, ws, self.display_name, ws)) return buffer class Grantee(XMLMixin): def __init__(self, id="", display_name="", email_address="", uri=""): if id or display_name: msg = "Both 'id' and 'display_name' must be provided." if not (id and display_name): raise ValueError(msg) self.id = id self.display_name = display_name self.email_address = email_address self.uri = uri def _to_xml(self, buffer=None, indent=0): if buffer is None: buffer = [] ws = " " * (indent * 2) if self.id and self.display_name: xsi_type = "CanonicalUser" value = ("%s %s\n" "%s %s\n" % ( ws, self.id, ws, self.display_name)) elif self.email_address: xsi_type = "AmazonCustomerByEmail" value = "%s %s\n" % ( ws, self.email_address) elif self.uri: xsi_type = "Group" value = "%s %s\n" % (ws, self.uri) buffer.append("%s\n' "%s%s\n" % (ws, xsi_type, value, ws)) return buffer txAWS-0.2.3/txaws/s3/model.py0000664000175000017500000001040411741311335017355 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Copyright (C) 2011 Drew Smathers # Copyright (C) 2012 New Dream Network (DreamHost) # Licenced under the txaws licence available at /LICENSE in the txaws source. 
from txaws.util import XML class Bucket(object): """ An Amazon S3 storage bucket. """ def __init__(self, name, creation_date): self.name = name self.creation_date = creation_date class ItemOwner(object): """ The owner of a content item. """ def __init__(self, id, display_name): self.id = id self.display_name = display_name class BucketItem(object): """ The contents of an Amazon S3 bucket. """ def __init__(self, key, modification_date, etag, size, storage_class, owner=None): self.key = key self.modification_date = modification_date self.etag = etag self.size = size self.storage_class = storage_class self.owner = owner class BucketListing(object): """ A mapping for the data in a bucket listing. """ def __init__(self, name, prefix, marker, max_keys, is_truncated, contents=None, common_prefixes=None): self.name = name self.prefix = prefix self.marker = marker self.max_keys = max_keys self.is_truncated = is_truncated self.contents = contents self.common_prefixes = common_prefixes class LifecycleConfiguration(object): """ Returns the lifecycle configuration information set on the bucket. """ def __init__(self, rules): self.rules = rules class LifecycleConfigurationRule(object): """ Container for elements that describe a lifecycle rule. """ def __init__(self, id, prefix, status, expiration): self.id = id self.prefix = prefix self.status = status self.expiration = expiration class WebsiteConfiguration(object): """ A mapping for the data in a bucket website configuration. """ def __init__(self, index_suffix, error_key=None): self.index_suffix = index_suffix self.error_key = error_key class NotificationConfiguration(object): """ A mapping for the data in a bucket notification configuration. """ def __init__(self, topic=None, event=None): self.topic = topic self.event = event class VersioningConfiguration(object): """ Container for the bucket versioning configuration. 
According to Amazon: C{MfaDelete}: This element is only returned if the bucket has been configured with C{MfaDelete}. If the bucket has never been so configured, this element is not returned. The possible values are None, "Disabled" or "Enabled". C{Status}: If the bucket has never been so configured, this element is not returned. The possible values are None, "Suspended" or "Enabled". """ def __init__(self, mfa_delete=None, status=None): self.mfa_delete = mfa_delete self.status = status class FileChunk(object): """ An Amazon S3 file chunk. S3 returns file chunks, 10 MB at a time, until the entire file is returned. These chunks need to be assembled once they are all returned. """ class RequestPayment(object): """ A payment request. @param payer: One of 'Requester' or 'BucketOwner'. """ payer_choices = ("Requester", "BucketOwner") def __init__(self, payer): if payer not in self.payer_choices: raise ValueError("Invalid value for payer: `%s`. Must be one of " "%s." % (payer, ",".join(self.payer_choices))) self.payer = payer def to_xml(self): """ Convert this request into a C{RequestPaymentConfiguration} XML document. """ return ("\n' " %s\n" "" % self.payer) @classmethod def from_xml(cls, xml_bytes): """ Create an instance from a C{RequestPaymentConfiguration} XML document. 
""" root = XML(xml_bytes) return cls(root.findtext("Payer")) txAWS-0.2.3/txaws/s3/tests/0000775000175000017500000000000011741312025017043 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/s3/tests/test_acls.py0000664000175000017500000001143011741311335021400 0ustar oubiwannoubiwann00000000000000from twisted.trial.unittest import TestCase from txaws.testing import payload from txaws.s3 import acls class ACLTestCase(TestCase): def test_owner_to_xml(self): owner = acls.Owner(id='8a6925ce4adf588a4f21c32aa379004fef', display_name='BucketOwnersEmail@amazon.com') xml_bytes = owner.to_xml() self.assertEquals(xml_bytes, """\ 8a6925ce4adf588a4f21c32aa379004fef BucketOwnersEmail@amazon.com """) def test_grantee_canonical_missing_parameter(self): self.assertRaises( ValueError, acls.Grantee, {'id': '8a6925ce4adf588a4f21c32aa379004fef'}) self.assertRaises( ValueError, acls.Grantee, {'display_name': 'BucketOwnersEmail@amazon.com'}) def test_grantee_canonical_to_xml(self): grantee = acls.Grantee(id='8a6925ce4adf588a4f21c32aa379004fef', display_name='BucketOwnersEmail@amazon.com') xml_bytes = grantee.to_xml() self.assertEquals(xml_bytes, """\ 8a6925ce4adf588a4f21c32aa379004fef BucketOwnersEmail@amazon.com """) def test_grantee_email_to_xml(self): grantee = acls.Grantee(email_address="BucketOwnersEmail@amazon.com") xml_bytes = grantee.to_xml() self.assertEquals(xml_bytes, """\ BucketOwnersEmail@amazon.com """) def test_grantee_uri_to_xml(self): grantee = acls.Grantee( uri='http://acs.amazonaws.com/groups/global/AuthenticatedUsers') xml_bytes = grantee.to_xml() self.assertEquals(xml_bytes, """\ http://acs.amazonaws.com/groups/global/AuthenticatedUsers """) def test_grant_to_xml(self): grantee = acls.Grantee(id='8a6925ce4adf588a4f21c32aa379004fef', display_name='BucketOwnersEmail@amazon.com') grant = acls.Grant(grantee, 'FULL_CONTROL') xml_bytes = grant.to_xml() self.assertEquals(xml_bytes, """\ 8a6925ce4adf588a4f21c32aa379004fef BucketOwnersEmail@amazon.com FULL_CONTROL """) def 
test_access_control_policy_to_xml(self): grantee = acls.Grantee(id='8a6925ce4adf588a4f21c32aa379004fef', display_name='foo@example.net') grant1 = acls.Grant(grantee, 'FULL_CONTROL') grantee = acls.Grantee(id='8a6925ce4adf588a4f21c32aa37900feed', display_name='bar@example.net') grant2 = acls.Grant(grantee, 'READ') owner = acls.Owner(id='8a6925ce4adf588a4f21c32aa37900beef', display_name='baz@example.net') acp = acls.AccessControlPolicy(owner=owner, access_control_list=[grant1, grant2]) xml_bytes = acp.to_xml() self.assertEquals(xml_bytes, payload.sample_access_control_policy_result) def test_permission_enum(self): grantee = acls.Grantee(id='8a6925ce4adf588a4f21c32aa379004fef', display_name='BucketOwnersEmail@amazon.com') acls.Grant(grantee, 'FULL_CONTROL') acls.Grant(grantee, 'WRITE') acls.Grant(grantee, 'WRITE_ACP') acls.Grant(grantee, 'READ') acls.Grant(grantee, 'READ_ACP') self.assertRaises(ValueError, acls.Grant, grantee, 'GO_HOG_WILD') def test_from_xml(self): policy = acls.AccessControlPolicy.from_xml( payload.sample_access_control_policy_result) self.assertEquals(policy.owner.id, '8a6925ce4adf588a4f21c32aa37900beef') self.assertEquals(policy.owner.display_name, 'baz@example.net') self.assertEquals(len(policy.access_control_list), 2) grant1 = policy.access_control_list[0] self.assertEquals(grant1.grantee.id, '8a6925ce4adf588a4f21c32aa379004fef') self.assertEquals(grant1.grantee.display_name, 'foo@example.net') self.assertEquals(grant1.permission, 'FULL_CONTROL') grant2 = policy.access_control_list[1] self.assertEquals(grant2.grantee.id, '8a6925ce4adf588a4f21c32aa37900feed') self.assertEquals(grant2.grantee.display_name, 'bar@example.net') self.assertEquals(grant2.permission, 'READ') txAWS-0.2.3/txaws/s3/tests/__init__.py0000664000175000017500000000000011741311335021145 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/s3/tests/test_exception.py0000664000175000017500000000452111741311335022457 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 
Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from twisted.trial.unittest import TestCase from txaws.s3.exception import S3Error from txaws.testing import payload from txaws.util import XML REQUEST_ID = "0ef9fc37-6230-4d81-b2e6-1b36277d4247" class S3ErrorTestCase(TestCase): def test_set_400_error(self): xml = "12" error = S3Error("", 400) error._set_400_error(XML(xml)) self.assertEquals(error.errors[0]["Code"], "1") self.assertEquals(error.errors[0]["Message"], "2") def test_get_error_code(self): error = S3Error(payload.sample_s3_invalid_access_key_result, 400) self.assertEquals(error.get_error_code(), "InvalidAccessKeyId") def test_get_error_message(self): error = S3Error(payload.sample_s3_invalid_access_key_result, 400) self.assertEquals( error.get_error_message(), ("The AWS Access Key Id you provided does not exist in our " "records.")) def test_error_count(self): error = S3Error(payload.sample_s3_invalid_access_key_result, 400) self.assertEquals(len(error.errors), 1) def test_error_repr(self): error = S3Error(payload.sample_s3_invalid_access_key_result, 400) self.assertEquals( repr(error), "") def test_signature_mismatch_result(self): error = S3Error(payload.sample_s3_signature_mismatch, 400) self.assertEquals( error.get_error_messages(), ("The request signature we calculated does not match the " "signature you provided. Check your key and signing method.")) def test_invalid_access_key_result(self): error = S3Error(payload.sample_s3_invalid_access_key_result, 400) self.assertEquals( error.get_error_messages(), ("The AWS Access Key Id you provided does not exist in our " "records.")) def test_internal_error_result(self): error = S3Error(payload.sample_server_internal_error_result, 400) self.assertEquals( error.get_error_messages(), "We encountered an internal error. 
Please try again.") txAWS-0.2.3/txaws/s3/tests/test_client.py0000664000175000017500000014022711741311335021743 0ustar oubiwannoubiwann00000000000000from twisted.internet.defer import succeed from txaws.credentials import AWSCredentials try: from txaws.s3 import client except ImportError: s3clientSkip = ("S3Client couldn't be imported (perhaps because dateutil, " "on which it depends, isn't present)") else: s3clientSkip = None from txaws.s3.acls import AccessControlPolicy from txaws.s3.model import RequestPayment from txaws.service import AWSServiceEndpoint from txaws.testing import payload from txaws.testing.base import TXAWSTestCase from txaws.util import calculate_md5 class URLContextTestCase(TXAWSTestCase): endpoint = AWSServiceEndpoint("https://s3.amazonaws.com/") def test_get_host_with_no_bucket(self): url_context = client.URLContext(self.endpoint) self.assertEquals(url_context.get_host(), "s3.amazonaws.com") def test_get_host_with_bucket(self): url_context = client.URLContext(self.endpoint, "mystuff") self.assertEquals(url_context.get_host(), "s3.amazonaws.com") def test_get_path_with_no_bucket(self): url_context = client.URLContext(self.endpoint) self.assertEquals(url_context.get_path(), "/") def test_get_path_with_bucket(self): url_context = client.URLContext(self.endpoint, bucket="mystuff") self.assertEquals(url_context.get_path(), "/mystuff/") def test_get_path_with_bucket_and_object(self): url_context = client.URLContext( self.endpoint, bucket="mystuff", object_name="/images/thing.jpg") self.assertEquals(url_context.get_host(), "s3.amazonaws.com") self.assertEquals(url_context.get_path(), "/mystuff/images/thing.jpg") def test_get_path_with_bucket_and_object_without_slash(self): url_context = client.URLContext( self.endpoint, bucket="mystuff", object_name="images/thing.jpg") self.assertEquals(url_context.get_host(), "s3.amazonaws.com") self.assertEquals(url_context.get_path(), "/mystuff/images/thing.jpg") def test_get_url_with_custom_endpoint(self): 
endpoint = AWSServiceEndpoint("http://localhost/") url_context = client.URLContext(endpoint) self.assertEquals(url_context.endpoint.get_uri(), "http://localhost/") self.assertEquals(url_context.get_url(), "http://localhost/") def test_get_uri_with_endpoint_bucket_and_object(self): endpoint = AWSServiceEndpoint("http://localhost/") url_context = client.URLContext( endpoint, bucket="mydocs", object_name="notes.txt") self.assertEquals( url_context.get_url(), "http://localhost/mydocs/notes.txt") def test_custom_port_endpoint(self): test_uri = 'http://0.0.0.0:12345/' endpoint = AWSServiceEndpoint(uri=test_uri) self.assertEquals(endpoint.port, 12345) self.assertEquals(endpoint.scheme, 'http') context = client.URLContext(service_endpoint=endpoint, bucket="foo", object_name="bar") self.assertEquals(context.get_host(), '0.0.0.0') self.assertEquals(context.get_url(), test_uri + 'foo/bar') def test_custom_port_endpoint_https(self): test_uri = 'https://0.0.0.0:12345/' endpoint = AWSServiceEndpoint(uri=test_uri) self.assertEquals(endpoint.port, 12345) self.assertEquals(endpoint.scheme, 'https') context = client.URLContext(service_endpoint=endpoint, bucket="foo", object_name="bar") self.assertEquals(context.get_host(), '0.0.0.0') self.assertEquals(context.get_url(), test_uri + 'foo/bar') URLContextTestCase.skip = s3clientSkip class S3ClientTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials( access_key="accessKey", secret_key="secretKey") self.endpoint = AWSServiceEndpoint() def test_list_buckets(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint): super(StubQuery, query).__init__( action=action, creds=creds) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, None) self.assertEqual(query.object_name, None) self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query): return 
succeed(payload.sample_list_buckets_result) def check_list_buckets(results): bucket1, bucket2 = results self.assertEquals(bucket1.name, "quotes") self.assertEquals( bucket1.creation_date.timetuple(), (2006, 2, 3, 16, 45, 9, 4, 34, 0)) self.assertEquals(bucket2.name, "samples") self.assertEquals( bucket2.creation_date.timetuple(), (2006, 2, 3, 16, 41, 58, 4, 34, 0)) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.list_buckets() return d.addCallback(check_list_buckets) def test_create_bucket(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket) self.assertEquals(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, None) self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.create_bucket("mybucket") def test_get_bucket(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, None) self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(payload.sample_get_bucket_result) def check_results(listing): self.assertEquals(listing.name, "mybucket") self.assertEquals(listing.prefix, "N") self.assertEquals(listing.marker, "Ned") self.assertEquals(listing.max_keys, "40") self.assertEquals(listing.is_truncated, "false") 
self.assertEquals(len(listing.contents), 2) content1 = listing.contents[0] self.assertEquals(content1.key, "Nelson") self.assertEquals( content1.modification_date.timetuple(), (2006, 1, 1, 12, 0, 0, 6, 1, 0)) self.assertEquals( content1.etag, '"828ef3fdfa96f00ad9f27c383fc9ac7f"') self.assertEquals(content1.size, "5") self.assertEquals(content1.storage_class, "STANDARD") owner = content1.owner self.assertEquals(owner.id, "bcaf1ffd86f41caff1a493dc2ad8c2c281e37522a640e16" "1ca5fb16fd081034f") self.assertEquals(owner.display_name, "webfile") creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket("mybucket") return d.addCallback(check_results) def test_get_bucket_location(self): """ L{S3Client.get_bucket_location} creates a L{Query} to get a bucket's location. It parses the returned C{LocationConstraint} XML document and returns a C{Deferred} that requests the bucket's location constraint. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?location") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload.sample_get_bucket_location_result) def check_results(location_constraint): self.assertEquals(location_constraint, "EU") creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_location("mybucket") return d.addCallback(check_results) def test_get_bucket_lifecycle_multiple_rules(self): """ L{S3Client.get_bucket_lifecycle} creates a L{Query} to get a bucket's lifecycle. 
It parses the returned C{LifecycleConfiguration} XML document and returns a C{Deferred} that requests the bucket's lifecycle configuration with multiple rules. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?lifecycle") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload. sample_s3_get_bucket_lifecycle_multiple_rules_result) def check_results(lifecycle_config): self.assertTrue(len(lifecycle_config.rules) == 2) rule = lifecycle_config.rules[1] self.assertEquals(rule.id, 'another-id') self.assertEquals(rule.prefix, 'another-logs') self.assertEquals(rule.status, 'Disabled') self.assertEquals(rule.expiration, 37) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_lifecycle("mybucket") return d.addCallback(check_results) def test_get_bucket_lifecycle(self): """ L{S3Client.get_bucket_lifecycle} creates a L{Query} to get a bucket's lifecycle. It parses the returned C{LifecycleConfiguration} XML document and returns a C{Deferred} that requests the bucket's lifecycle configuration. 
""" class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?lifecycle") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload.sample_s3_get_bucket_lifecycle_result) def check_results(lifecycle_config): rule = lifecycle_config.rules[0] self.assertEquals(rule.id, '30-day-log-deletion-rule') self.assertEquals(rule.prefix, 'logs') self.assertEquals(rule.status, 'Enabled') self.assertEquals(rule.expiration, 30) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_lifecycle("mybucket") return d.addCallback(check_results) def test_get_bucket_website_config(self): """ L{S3Client.get_bucket_website_config} creates a L{Query} to get a bucket's website configurtion. It parses the returned C{WebsiteConfiguration} XML document and returns a C{Deferred} that requests the bucket's website configuration. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?website") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload. 
sample_s3_get_bucket_website_no_error_result) def check_results(website_config): self.assertEquals(website_config.index_suffix, "index.html") self.assertEquals(website_config.error_key, None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_website_config("mybucket") return d.addCallback(check_results) def test_get_bucket_website_config_with_error_doc(self): """ L{S3Client.get_bucket_website_config} creates a L{Query} to get a bucket's website configurtion. It parses the returned C{WebsiteConfiguration} XML document and returns a C{Deferred} that requests the bucket's website configuration with the error document. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?website") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload.sample_s3_get_bucket_website_result) def check_results(website_config): self.assertEquals(website_config.index_suffix, "index.html") self.assertEquals(website_config.error_key, "404.html") creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_website_config("mybucket") return d.addCallback(check_results) def test_get_bucket_notification_config(self): """ L{S3Client.get_bucket_notification_config} creates a L{Query} to get a bucket's notification configuration. It parses the returned C{NotificationConfiguration} XML document and returns a C{Deferred} that requests the bucket's notification configuration. 
""" class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?notification") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload. sample_s3_get_bucket_notification_result) def check_results(notification_config): self.assertEquals(notification_config.topic, None) self.assertEquals(notification_config.event, None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_notification_config("mybucket") return d.addCallback(check_results) def test_get_bucket_notification_config_with_topic(self): """ L{S3Client.get_bucket_notification_config} creates a L{Query} to get a bucket's notification configuration. It parses the returned C{NotificationConfiguration} XML document and returns a C{Deferred} that requests the bucket's notification configuration with a topic. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?notification") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed( payload. 
sample_s3_get_bucket_notification_with_topic_result) def check_results(notification_config): self.assertEquals(notification_config.topic, "arn:aws:sns:us-east-1:123456789012:myTopic") self.assertEquals(notification_config.event, "s3:ReducedRedundancyLostObject") creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_notification_config("mybucket") return d.addCallback(check_results) def test_get_bucket_versioning_config(self): """ L{S3Client.get_bucket_versioning_configuration} creates a L{Query} to get a bucket's versioning status. It parses the returned C{VersioningConfiguration} XML document and returns a C{Deferred} that requests the bucket's versioning configuration. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?versioning") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload.sample_s3_get_bucket_versioning_result) def check_results(versioning_config): self.assertEquals(versioning_config.status, None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_versioning_config("mybucket") return d.addCallback(check_results) def test_get_bucket_versioning_config_enabled(self): """ L{S3Client.get_bucket_versioning_config} creates a L{Query} to get a bucket's versioning configuration. It parses the returned C{VersioningConfiguration} XML document and returns a C{Deferred} that requests the bucket's versioning configuration that has a enabled C{Status}. 
""" class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?versioning") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed(payload. sample_s3_get_bucket_versioning_enabled_result) def check_results(versioning_config): self.assertEquals(versioning_config.status, 'Enabled') creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_versioning_config("mybucket") return d.addCallback(check_results) def test_get_bucket_versioning_config_mfa_disabled(self): """ L{S3Client.get_bucket_versioning_config} creates a L{Query} to get a bucket's versioning configuration. It parses the returned C{VersioningConfiguration} XML document and returns a C{Deferred} that requests the bucket's versioning configuration that has a disabled C{MfaDelete}. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?versioning") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) self.assertEqual(query.amz_headers, {}) def submit(query, url_context=None): return succeed( payload. 
sample_s3_get_bucket_versioning_mfa_disabled_result) def check_results(versioning_config): self.assertEquals(versioning_config.mfa_delete, 'Disabled') creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) d = s3.get_bucket_versioning_config("mybucket") return d.addCallback(check_results) def test_delete_bucket(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket) self.assertEquals(action, "DELETE") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, None) self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.delete_bucket("mybucket") def test_put_bucket_acl(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=""): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name, data=data) self.assertEquals(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?acl") self.assertEqual(query.data, payload.sample_access_control_policy_result) self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(payload.sample_access_control_policy_result) def check_result(result): self.assert_(isinstance(result, AccessControlPolicy)) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) policy = AccessControlPolicy.from_xml( payload.sample_access_control_policy_result) return s3.put_bucket_acl("mybucket", policy).addCallback(check_result) def 
test_get_bucket_acl(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=""): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name, data=data) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?acl") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(payload.sample_access_control_policy_result) def check_result(result): self.assert_(isinstance(result, AccessControlPolicy)) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.get_bucket_acl("mybucket").addCallback(check_result) def test_put_request_payment(self): """ L{S3Client.put_request_payment} creates a L{Query} to set payment information. An C{RequestPaymentConfiguration} XML document is built and sent to the endpoint and a C{Deferred} is returned that fires with the results of the request. 
""" class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata) self.assertEqual(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?requestPayment") xml = ("\n' " Requester\n" "") self.assertEqual(query.data, xml) self.assertEqual(query.metadata, None) def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.put_request_payment("mybucket", "Requester") def test_get_request_payment(self): """ L{S3Client.get_request_payment} creates a L{Query} to get payment information. It parses the returned C{RequestPaymentConfiguration} XML document and returns a C{Deferred} that fires with the payer's name. 
""" class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata) self.assertEqual(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "?requestPayment") self.assertEqual(query.metadata, None) def submit(query): return succeed(payload.sample_request_payment) def check_request_payment(result): self.assertEquals(result, "Requester") creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) deferred = s3.get_request_payment("mybucket") return deferred.addCallback(check_request_payment) def test_put_object(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None, amz_headers=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata, amz_headers=amz_headers) self.assertEqual(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "objectname") self.assertEqual(query.data, "some data") self.assertEqual(query.content_type, "text/plain") self.assertEqual(query.metadata, {"key": "some meta data"}) self.assertEqual(query.amz_headers, {"acl": "public-read"}) def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.put_object("mybucket", "objectname", "some data", content_type="text/plain", metadata={"key": "some meta data"}, amz_headers={"acl": "public-read"}) def 
test_copy_object(self): """ L{S3Client.copy_object} creates a L{Query} to copy an object from one bucket to another. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None, amz_headers=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata, amz_headers=amz_headers) self.assertEqual(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "newbucket") self.assertEqual(query.object_name, "newobjectname") self.assertEqual(query.data, None) self.assertEqual(query.content_type, None) self.assertEqual(query.metadata, {"key": "some meta data"}) self.assertEqual(query.amz_headers, {"copy-source": "/mybucket/objectname"}) def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.copy_object("mybucket", "objectname", "newbucket", "newobjectname", metadata={"key": "some meta data"}) def test_get_object(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None, amz_headers=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata) self.assertEqual(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "objectname") def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.get_object("mybucket", "objectname") def test_head_object(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, 
bucket=None, object_name=None, data=None, content_type=None, metadata=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata) self.assertEqual(action, "HEAD") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "objectname") def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.head_object("mybucket", "objectname") def test_delete_object(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=None, content_type=None, metadata=None): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket, object_name=object_name, data=data, content_type=content_type, metadata=metadata) self.assertEqual(action, "DELETE") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "objectname") def submit(query): return succeed(None) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) return s3.delete_object("mybucket", "objectname") def test_put_object_acl(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=""): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name, data=data) self.assertEquals(action, "PUT") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "myobject?acl") self.assertEqual(query.data, payload.sample_access_control_policy_result) self.assertEqual(query.metadata, {}) self.assertEqual(query.metadata, {}) def submit(query, 
url_context=None): return succeed(payload.sample_access_control_policy_result) def check_result(result): self.assert_(isinstance(result, AccessControlPolicy)) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) policy = AccessControlPolicy.from_xml( payload.sample_access_control_policy_result) deferred = s3.put_object_acl("mybucket", "myobject", policy) return deferred.addCallback(check_result) def test_get_object_acl(self): class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket=None, object_name=None, data=""): super(StubQuery, query).__init__(action=action, creds=creds, bucket=bucket, object_name=object_name, data=data) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(query.bucket, "mybucket") self.assertEqual(query.object_name, "myobject?acl") self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query, url_context=None): return succeed(payload.sample_access_control_policy_result) def check_result(result): self.assert_(isinstance(result, AccessControlPolicy)) creds = AWSCredentials("foo", "bar") s3 = client.S3Client(creds, query_factory=StubQuery) deferred = s3.get_object_acl("mybucket", "myobject") return deferred.addCallback(check_result) S3ClientTestCase.skip = s3clientSkip class QueryTestCase(TXAWSTestCase): creds = AWSCredentials(access_key="fookeyid", secret_key="barsecretkey") endpoint = AWSServiceEndpoint("https://choopy.s3.amazonaws.com/") def test_default_creation(self): query = client.Query(action="PUT") self.assertEquals(query.bucket, None) self.assertEquals(query.object_name, None) self.assertEquals(query.data, "") self.assertEquals(query.content_type, None) self.assertEquals(query.metadata, {}) def test_default_endpoint(self): query = client.Query(action="PUT") self.assertEquals(self.endpoint.host, "choopy.s3.amazonaws.com") self.assertEquals(query.endpoint.host, 
"s3.amazonaws.com") self.assertEquals(self.endpoint.method, "GET") self.assertEquals(query.endpoint.method, "PUT") def test_set_content_type_no_object_name(self): query = client.Query(action="PUT") query.set_content_type() self.assertEquals(query.content_type, None) def test_set_content_type(self): query = client.Query(action="PUT", object_name="advicedog.jpg") query.set_content_type() self.assertEquals(query.content_type, "image/jpeg") def test_set_content_type_with_content_type_already_set(self): query = client.Query( action="PUT", object_name="data.txt", content_type="text/csv") query.set_content_type() self.assertNotEquals(query.content_type, "text/plain") self.assertEquals(query.content_type, "text/csv") def test_get_headers(self): query = client.Query( action="GET", creds=self.creds, bucket="mystuff", object_name="/images/thing.jpg") headers = query.get_headers() self.assertEquals(headers.get("Content-Type"), "image/jpeg") self.assertEquals(headers.get("Content-Length"), 0) self.assertEquals( headers.get("Content-MD5"), "1B2M2Y8AsgTpgAmY7PhCfg==") self.assertTrue(len(headers.get("Date")) > 25) self.assertTrue( headers.get("Authorization").startswith("AWS fookeyid:")) self.assertTrue(len(headers.get("Authorization")) > 40) def test_get_headers_with_data(self): query = client.Query( action="GET", creds=self.creds, bucket="mystuff", object_name="/images/thing.jpg", data="BINARY IMAGE DATA") headers = query.get_headers() self.assertEquals(headers.get("Content-Type"), "image/jpeg") self.assertEquals(headers.get("Content-Length"), 17) self.assertTrue(len(headers.get("Date")) > 25) self.assertTrue( headers.get("Authorization").startswith("AWS fookeyid:")) self.assertTrue(len(headers.get("Authorization")) > 40) def test_get_canonicalized_amz_headers(self): query = client.Query( action="SomeThing", metadata={"a": 1, "b": 2, "c": 3}) headers = query.get_headers() self.assertEquals( sorted(headers.keys()), ["Content-Length", "Content-MD5", "Date", "x-amz-meta-a", 
"x-amz-meta-b", "x-amz-meta-c"]) amz_headers = query.get_canonicalized_amz_headers(headers) self.assertEquals( amz_headers, "x-amz-meta-a:1\nx-amz-meta-b:2\nx-amz-meta-c:3\n") def test_get_canonicalized_resource(self): query = client.Query(action="PUT", bucket="images") result = query.get_canonicalized_resource() self.assertEquals(result, "/images/") def test_get_canonicalized_resource_with_object_name(self): query = client.Query( action="PUT", bucket="images", object_name="advicedog.jpg") result = query.get_canonicalized_resource() self.assertEquals(result, "/images/advicedog.jpg") def test_get_canonicalized_resource_with_slashed_object_name(self): query = client.Query( action="PUT", bucket="images", object_name="/advicedog.jpg") result = query.get_canonicalized_resource() self.assertEquals(result, "/images/advicedog.jpg") def test_sign(self): query = client.Query(action="PUT", creds=self.creds) signed = query.sign({}) self.assertEquals(signed, "H6UJCNHizzXZCGPl7wM6nL6tQdo=") def test_object_query(self): """ Test that a request addressing an object is created correctly. """ DATA = "objectData" DIGEST = "zhdB6gwvocWv/ourYUWMxA==" request = client.Query( action="PUT", bucket="somebucket", object_name="object/name/here", data=DATA, content_type="text/plain", metadata={"foo": "bar"}, amz_headers={"acl": "public-read"}, creds=self.creds, endpoint=self.endpoint) request.sign = lambda headers: "TESTINGSIG=" self.assertEqual(request.action, "PUT") headers = request.get_headers() self.assertNotEqual(headers.pop("Date"), "") self.assertEqual(headers, {"Authorization": "AWS fookeyid:TESTINGSIG=", "Content-Type": "text/plain", "Content-Length": len(DATA), "Content-MD5": DIGEST, "x-amz-meta-foo": "bar", "x-amz-acl": "public-read"}) self.assertEqual(request.data, "objectData") def test_bucket_query(self): """ Test that a request addressing a bucket is created correctly. 
""" DIGEST = "1B2M2Y8AsgTpgAmY7PhCfg==" query = client.Query( action="GET", bucket="somebucket", creds=self.creds, endpoint=self.endpoint) query.sign = lambda headers: "TESTINGSIG=" self.assertEqual(query.action, "GET") headers = query.get_headers() self.assertNotEqual(headers.pop("Date"), "") self.assertEqual( headers, { "Authorization": "AWS fookeyid:TESTINGSIG=", "Content-Length": 0, "Content-MD5": DIGEST}) self.assertEqual(query.data, "") def test_submit(self): """ Submitting the request should invoke getPage correctly. """ class StubQuery(client.Query): def __init__(query, action, creds, endpoint, bucket): super(StubQuery, query).__init__( action=action, creds=creds, bucket=bucket) self.assertEquals(action, "GET") self.assertEqual(creds.access_key, "fookeyid") self.assertEqual(creds.secret_key, "barsecretkey") self.assertEqual(query.bucket, "somebucket") self.assertEqual(query.object_name, None) self.assertEqual(query.data, "") self.assertEqual(query.metadata, {}) def submit(query): return succeed("") query = StubQuery(action="GET", creds=self.creds, endpoint=self.endpoint, bucket="somebucket") return query.submit() def test_authentication(self): query = client.Query( action="GET", creds=self.creds, endpoint=self.endpoint) query.sign = lambda headers: "TESTINGSIG=" query.date = "Wed, 28 Mar 2007 01:29:59 +0000" headers = query.get_headers() self.assertEqual( headers["Authorization"], "AWS fookeyid:TESTINGSIG=") QueryTestCase.skip = s3clientSkip class MiscellaneousTestCase(TXAWSTestCase): def test_content_md5(self): self.assertEqual(calculate_md5("somedata"), "rvr3UC1SmUw7AZV2NqPN0g==") def test_request_payment_enum(self): """ Only 'Requester' or 'BucketOwner' may be provided when a L{RequestPayment} is instantiated. 
""" RequestPayment("Requester") RequestPayment("BucketOwner") self.assertRaises(ValueError, RequestPayment, "Bob") txAWS-0.2.3/txaws/script.py0000664000175000017500000000333411741311335017240 0ustar oubiwannoubiwann00000000000000from optparse import OptionParser from txaws import meta from txaws import version # XXX Once we start adding script that require conflicting options, we'll need # multiple parsers and option dispatching... def parse_options(usage): parser = OptionParser(usage, version="%s %s" % ( meta.display_name, version.txaws)) parser.add_option( "-a", "--access-key", dest="access_key", help="access key ID") parser.add_option( "-s", "--secret-key", dest="secret_key", help="access secret key") parser.add_option( "-r", "--region", dest="region", help="US or EU (valid for AWS only)") parser.add_option( "-U", "--url", dest="url", help="service URL/endpoint") parser.add_option( "-b", "--bucket", dest="bucket", help="name of the bucket") parser.add_option( "-o", "--object-name", dest="object_name", help="name of the object") parser.add_option( "-d", "--object-data", dest="object_data", help="content data of the object") parser.add_option( "--object-file", dest="object_filename", help=("the path to the file that will be saved as an object; if " "provided, the --object-name and --object-data options are " "not necessary")) parser.add_option( "-c", "--content-type", dest="content_type", help="content type of the object") options, args = parser.parse_args() if not (options.access_key and options.secret_key): parser.error( "both the access key ID and the secret key must be supplied") region = options.region if region and region.upper() not in ["US", "EU"]: parser.error("region must be one of 'US' or 'EU'") return (options, args) txAWS-0.2.3/txaws/meta.py0000664000175000017500000000037711741311335016666 0ustar oubiwannoubiwann00000000000000display_name = "txAWS" library_name = "txaws" author = "txAWS Developers" author_email = "txaws-dev@lists.launchpad.net" 
# Licence, project home page and one-line description of the txAWS package.
license = "MIT"
url = "http://launchpad.net/txaws"
description = """ Twisted-based Asynchronous Libraries for Amazon Web Services """
txAWS-0.2.3/txaws/regions.py0000664000175000017500000000533311741311335017403 0ustar oubiwannoubiwann00000000000000
# Copyright (C) 2009 Duncan McGreggor
# Copyright (C) 2009 Robert Collins
# Copyright (C) 2012 New Dream Network, LLC (DreamHost)
# Licenced under the txaws licence available at /LICENSE in the txaws source.

# Region and endpoint constants for the EC2 and S3 services.
#
# The "new" constants are lists of dicts with exactly two keys:
#   "region"   -- a human-readable region name;
#   "endpoint" -- the base URL of the service in that region.
# Aggregates such as EC2_US and EC2_ALL_REGIONS are built with list
# concatenation, so each is a new list object (the dicts inside are shared).
# EC2_EU and EC2_SOUTH_AMERICA, by contrast, are plain aliases.

# Names exported by ``from txaws.regions import *``.
# NOTE(review): EC2_ENDPOINT_US, EC2_ENDPOINT_EU, EC2_US, EC2_EU,
# EC2_SOUTH_AMERICA, S3_ENDPOINT and the S3_* region lists are not listed.
# Confirm whether leaving them out of the star-export is intentional.
__all__ = ["REGION_US", "REGION_EU", "EC2_US_EAST", "EC2_US_WEST",
           "EC2_ASIA_PACIFIC", "EC2_EU_WEST", "EC2_SOUTH_AMERICA_EAST",
           "EC2_ALL_REGIONS"]

# These old EC2 variable names are maintained for backwards compatibility.
# The EC2_ENDPOINT_* URLs use the "<region>.ec2.amazonaws.com" host form and
# end with a trailing slash.  The newer lists below use
# "ec2.<region>.amazonaws.com" with no trailing slash.
REGION_US = "US"
REGION_EU = "EU"
EC2_ENDPOINT_US = "https://us-east-1.ec2.amazonaws.com/"
EC2_ENDPOINT_EU = "https://eu-west-1.ec2.amazonaws.com/"

# These are the new EC2 variables: one list per geographic area.
EC2_US_EAST = [
    {"region": "US East (Northern Virginia) Region",
     "endpoint": "https://ec2.us-east-1.amazonaws.com"}]
EC2_US_WEST = [
    {"region": "US West (Oregon) Region",
     "endpoint": "https://ec2.us-west-2.amazonaws.com"},
    {"region": "US West (Northern California) Region",
     "endpoint": "https://ec2.us-west-1.amazonaws.com"}]
# All US regions, US East first and then US West (a new concatenated list).
EC2_US = EC2_US_EAST + EC2_US_WEST
EC2_ASIA_PACIFIC = [
    {"region": "Asia Pacific (Singapore) Region",
     "endpoint": "https://ec2.ap-southeast-1.amazonaws.com"},
    {"region": "Asia Pacific (Tokyo) Region",
     "endpoint": "https://ec2.ap-northeast-1.amazonaws.com"}]
EC2_EU_WEST = [
    {"region": "EU (Ireland) Region",
     "endpoint": "https://ec2.eu-west-1.amazonaws.com"}]
# Alias, not a copy: EC2_EU is the same list object as EC2_EU_WEST.
EC2_EU = EC2_EU_WEST
EC2_SOUTH_AMERICA_EAST = [
    {"region": "South America (Sao Paulo) Region",
     "endpoint": "https://ec2.sa-east-1.amazonaws.com"}]
# Alias, not a copy: same list object as EC2_SOUTH_AMERICA_EAST.
EC2_SOUTH_AMERICA = EC2_SOUTH_AMERICA_EAST
# Every EC2 region above, ordered US, Asia Pacific, EU, South America.
EC2_ALL_REGIONS = EC2_US + EC2_ASIA_PACIFIC + EC2_EU + EC2_SOUTH_AMERICA

# This old S3 variable is maintained for backwards compatibility.
# Note the trailing slash, which the newer S3 entries do not have.
S3_ENDPOINT = "https://s3.amazonaws.com/"

# These are the new S3 variables, in the same {"region", "endpoint"} format.
# NOTE(review): the South America (Sao Paulo) S3 entry defined after this
# comment uses "s3-sa-east-1.amazonaws.com" with no "https://" scheme.  Every
# other endpoint in this module has one, so this looks like a bug; confirm
# it and fix it at that definition.
S3_US_DEFAULT = [ {"region": "US Standard *", "endpoint": "https://s3.amazonaws.com"}] S3_US_WEST = [ {"region": "US West (Oregon) Region", "endpoint": "https://s3-us-west-2.amazonaws.com"}, {"region": "US West (Northern California) Region", "endpoint": "https://s3-us-west-1.amazonaws.com"}] S3_ASIA_PACIFIC = [ {"region": "Asia Pacific (Singapore) Region", "endpoint": "https://s3-ap-southeast-1.amazonaws.com"}, {"region": "Asia Pacific (Tokyo) Region", "endpoint": "https://s3-ap-northeast-1.amazonaws.com"}] S3_US = S3_US_DEFAULT + S3_US_WEST S3_EU_WEST = [ {"region": "EU (Ireland) Region", "endpoint": "https://s3-eu-west-1.amazonaws.com"}] S3_EU = S3_EU_WEST S3_SOUTH_AMERICA_EAST = [ {"region": "South America (Sao Paulo) Region", "endpoint": "s3-sa-east-1.amazonaws.com"}] S3_SOUTH_AMERICA = S3_SOUTH_AMERICA_EAST S3_ALL_REGIONS = S3_US + S3_ASIA_PACIFIC + S3_EU + S3_SOUTH_AMERICA txAWS-0.2.3/txaws/server/0000775000175000017500000000000011741312025016662 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/exception.py0000664000175000017500000000262711741311335021244 0ustar oubiwannoubiwann00000000000000class APIError(Exception): """Raised while handling an API request. @param status: The HTTP status code the response will be set to. @param code: A machine-parsable textual code for the error. @param message: A human-readable description of the error. @param response: The full body of the response to be sent to the client, if C{None} it will be generated from C{code} and C{message}. See also L{API.dump_error}. 
""" def __init__(self, status, code=None, message=None, response=None): super(APIError, self).__init__(message) self.status = int(status) self.code = code self.message = message self.response = response if self.response is None: if self.code is None or self.message is None: raise RuntimeError("If the response is not specified, code " "and status must both be set.") else: if self.code is not None or self.message is not None: raise RuntimeError("If the full response payload is passed, " "code and message must not be set.") def __str__(self): # This avoids an exception when twisted logger logs the message, as it # currently doesn't support unicode. if self.message is not None: return self.message.encode("ascii", "replace") return "" txAWS-0.2.3/txaws/server/__init__.py0000664000175000017500000000000011741311335020764 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/schema.py0000664000175000017500000004152211741311335020503 0ustar oubiwannoubiwann00000000000000from datetime import datetime from operator import itemgetter from dateutil.tz import tzutc from dateutil.parser import parse from txaws.server.exception import APIError class SchemaError(APIError): """Raised when failing to extract or bundle L{Parameter}s.""" def __init__(self, message): code = self.__class__.__name__[:-len("Error")] super(SchemaError, self).__init__(400, code=code, message=message) class MissingParameterError(SchemaError): """Raised when a parameter is missing. @param name: The name of the missing parameter. """ def __init__(self, name): message = "The request must contain the parameter %s" % name super(MissingParameterError, self).__init__(message) class InvalidParameterValueError(SchemaError): """Raised when the value of a parameter is invalid.""" class InvalidParameterCombinationError(SchemaError): """ Raised when there is more than one parameter with the same name, when this isn't explicitly allowed for. @param name: The name of the missing parameter. 
""" def __init__(self, name): message = "The parameter '%s' may only be specified once." % name super(InvalidParameterCombinationError, self).__init__(message) class UnknownParameterError(SchemaError): """Raised when a parameter to extract is unknown.""" def __init__(self, name): message = "The parameter %s is not recognized" % name super(UnknownParameterError, self).__init__(message) class Parameter(object): """A single parameter in an HTTP request. @param name: A name for the key of the parameter, as specified in a request. For example, a single parameter would be specified simply as 'GroupName'. If more than one group name was accepted, it would be specified as 'GroupName.n'. A more complex example is 'IpPermissions.n.Groups.m.GroupName'. @param optional: If C{True} the parameter may not be present. @param default: A default value for the parameter, if not present. @param min: Minimum value for a parameter. @param max: Maximum value for a parameter. @param allow_none: Whether the parameter may be C{None}. @param validator: A callable to validate the parameter, returning a bool. """ def __init__(self, name, optional=False, default=None, min=None, max=None, allow_none=False, validator=None): self.name = name self.optional = optional self.default = default self.min = min self.max = max self.allow_none = allow_none self.validator = validator def coerce(self, value): """Coerce a single value according to this parameter's settings. @param value: A L{str}, or L{None}. If L{None} is passed - meaning no value is avalable at all, not even the empty string - and this parameter is optional, L{self.default} will be returned. 
""" if value is None: if self.optional: return self.default else: value = "" if value == "": if not self.allow_none: raise MissingParameterError(self.name) return self.default try: self._check_range(value) parsed = self.parse(value) if self.validator and not self.validator(parsed): raise ValueError(value) return parsed except ValueError: try: value = value.decode("utf-8") message = "Invalid %s value %s" % (self.kind, value) except UnicodeDecodeError: message = "Invalid %s value" % self.kind raise InvalidParameterValueError(message) def _check_range(self, value): """Check that the given C{value} is in the expected range.""" if self.min is None and self.max is None: return measure = self.measure(value) prefix = "Value (%s) for parameter %s is invalid. %s" if self.min is not None and measure < self.min: message = prefix % (value, self.name, self.lower_than_min_template % self.min) raise InvalidParameterValueError(message) if self.max is not None and measure > self.max: message = prefix % (value, self.name, self.greater_than_max_template % self.max) raise InvalidParameterValueError(message) def parse(self, value): """ Parse a single parameter value coverting it to the appropriate type. """ raise NotImplementedError() def format(self, value): """ Format a single parameter value in a way suitable for an HTTP request. """ raise NotImplementedError() def measure(self, value): """ Return an C{int} providing a measure for C{value}, used for C{range}. """ raise NotImplementedError() class Unicode(Parameter): """A parameter that must be a C{unicode}.""" kind = "unicode" lower_than_min_template = "Length must be at least %s." greater_than_max_template = "Length exceeds maximum of %s." 
def parse(self, value): return value.decode("utf-8") def format(self, value): return value.encode("utf-8") def measure(self, value): return len(value) class RawStr(Parameter): """A parameter that must be a C{str}.""" kind = "raw string" def parse(self, value): return str(value) def format(self, value): return value class Integer(Parameter): """A parameter that must be a positive C{int}.""" kind = "integer" lower_than_min_template = "Value must be at least %s." greater_than_max_template = "Value exceeds maximum of %s." def __init__(self, name, optional=False, default=None, min=0, max=None, allow_none=False, validator=None): super(Integer, self).__init__(name, optional, default, min, max, allow_none, validator) def parse(self, value): return int(value) def format(self, value): return str(value) def measure(self, value): return int(value) class Bool(Parameter): """A parameter that must be a C{bool}.""" kind = "boolean" def parse(self, value): if value == "true": return True if value == "false": return False raise ValueError() def format(self, value): if value: return "true" else: return "false" class Enum(Parameter): """A parameter with enumerated values. @param name: The name of the parameter, as specified in a request. @param optional: If C{True} the parameter may not be present. @param default: A default value for the parameter, if not present. @param mapping: A mapping of accepted values to the values that will be returned by C{parse}. 
""" kind = "enum" def __init__(self, name, mapping, optional=False, default=None): super(Enum, self).__init__(name, optional=optional, default=default) self.mapping = mapping self.reverse = dict((value, key) for key, value in mapping.iteritems()) def parse(self, value): try: return self.mapping[value] except KeyError: raise ValueError() def format(self, value): return self.reverse[value] class Date(Parameter): """A parameter that must be a valid ISO 8601 formatted date.""" kind = "date" def parse(self, value): return parse(value).replace(tzinfo=tzutc()) def format(self, value): # Convert value to UTC. tt = value.utctimetuple() utc_value = datetime( tt.tm_year, tt.tm_mon, tt.tm_mday, tt.tm_hour, tt.tm_min, tt.tm_sec) return datetime.strftime(utc_value, "%Y-%m-%dT%H:%M:%SZ") class Arguments(object): """Arguments parsed from a request.""" def __init__(self, tree): """Initialize a new L{Arguments} instance. @param tree: The C{dict}-based structure of the L{Argument}instance to create. """ for key, value in tree.iteritems(): self.__dict__[key] = self._wrap(value) def __iter__(self): """Returns an iterator yielding C{(name, value)} tuples.""" return self.__dict__.iteritems() def __getitem__(self, index): """Return the argument value with the given L{index}.""" return self.__dict__[index] def __len__(self): """Return the number of arguments.""" return len(self.__dict__) def _wrap(self, value): """Wrap the given L{tree} with L{Arguments} as necessary. @param tree: A {dict}, containing L{dict}s and/or leaf values, nested arbitrarily deep. 
""" if isinstance(value, dict): if any(isinstance(name, int) for name in value.keys()): if not all(isinstance(name, int) for name in value.keys()): raise RuntimeError("Integer and non-integer keys: %r" % value.keys()) items = sorted(value.iteritems(), key=itemgetter(0)) return [self._wrap(value) for (name, value) in items] else: return Arguments(value) else: return value class Schema(object): """ The schema that the arguments of an HTTP request must be compliant with. """ def __init__(self, *parameters): """Initialize a new L{Schema} instance. Any number of L{Parameter} instances can be passed. The parameter path is used as the target in L{Schema.extract} and L{Schema.bundle}. For example:: schema = Schema(Unicode('Name')) means that the result of L{Schema.extract} would have a C{Name} attribute. Similarly, L{Schema.bundle} would look for a C{Name} attribute. A more complex example:: schema = Schema(Unicode('Name.#')) means that the result of L{Schema.extract} would have a C{Name} attribute, which would itself contain a list of names. Similarly, L{Schema.bundle} would look for a C{Name} attribute. """ self._parameters = dict( (self._get_template(parameter.name), parameter) for parameter in parameters) def extract(self, params): """Extract parameters from a raw C{dict} according to this schema. @param params: The raw parameters to parse. @return: An L{Arguments} object holding the extracted arguments. @raises UnknownParameterError: If C{params} contains keys that this schema doesn't know about. """ tree = {} rest = {} # Extract from the given arguments and parse according to the # corresponding parameters. for name, value in params.iteritems(): template = self._get_template(name) parameter = self._parameters.get(template) if template.endswith(".#") and parameter is None: # If we were unable to find a direct match for a template that # allows multiple values. Let's attempt to find it without the # multiple value marker which Amazon allows. 
For example if the # template is 'PublicIp', then a single key 'PublicIp.1' is # allowed. parameter = self._parameters.get(template[:-2]) if parameter is not None: name = name[:-2] # At this point, we have a template that doesn't have the .# # marker to indicate multiple values. We don't allow multiple # "single" values for the same element. if name in tree.keys(): raise InvalidParameterCombinationError(name) if parameter is None: rest[name] = value else: self._set_value(tree, name, parameter.coerce(value)) # Ensure that the tree arguments are consistent with constraints # defined in the schema. for template, parameter in self._parameters.iteritems(): self._ensure_tree(tree, parameter, *template.split(".")) return Arguments(tree), rest def bundle(self, *arguments, **extra): """Bundle the given arguments in a C{dict} with EC2-style format. @param arguments: L{Arguments} instances to bundle. Keys in later objects will override those in earlier objects. @param extra: Any number of additional parameters. These will override similarly named arguments in L{arguments}. """ params = {} for argument in arguments: self._flatten(params, argument) self._flatten(params, extra) for name, value in params.iteritems(): parameter = self._parameters.get(self._get_template(name)) if parameter is None: raise RuntimeError("Parameter '%s' not in schema" % name) else: if value is None: params[name] = "" else: params[name] = parameter.format(value) return params def _get_template(self, key): """Return the canonical template for a given parameter key. For example:: 'Child.1.Name.2' becomes:: 'Child.#.Name.#' """ parts = key.split(".") for index, part in enumerate(parts[1::2]): parts[index * 2 + 1] = "#" return ".".join(parts) def _set_value(self, tree, path, value): """Set C{value} at C{path} in the given C{tree}. For example:: tree = {} _set_value(tree, 'foo.1.bar.2', True) results in C{tree} becoming:: {'foo': {1: {'bar': {2: True}}}} @param tree: A L{dict}. @param path: A L{str}. 
@param value: The value to set. Can be anything. """ nodes = [] for index, node in enumerate(path.split(".")): if index % 2: # Nodes with odd indexes must be non-negative integers try: node = int(node) except ValueError: raise UnknownParameterError(path) if node < 0: raise UnknownParameterError(path) nodes.append(node) for node in nodes[:-1]: tree = tree.setdefault(node, {}) tree[nodes[-1]] = value def _ensure_tree(self, tree, parameter, node, *nodes): """Check that C{node} exists in C{tree} and is followed by C{nodes}. C{node} and C{nodes} should correspond to a template path (i.e. where there are no absolute indexes, but C{#} instead). """ if node == "#": if len(nodes) == 0: if len(tree.keys()) == 0 and not parameter.optional: raise MissingParameterError(parameter.name) else: for subtree in tree.itervalues(): self._ensure_tree(subtree, parameter, *nodes) else: if len(nodes) == 0: if node not in tree.keys(): # No value for this parameter is present, if it's not # optional nor allow_none is set, the call below will # raise a MissingParameterError tree[node] = parameter.coerce(None) else: if node not in tree.keys(): tree[node] = {} self._ensure_tree(tree[node], parameter, *nodes) def _flatten(self, params, tree, path=""): """ For every element in L{tree}, set C{path} to C{value} in the given L{params} dictionary. @param params: A L{dict} which will be populated. @param tree: A structure made up of L{Argument}s, L{list}s, L{dict}s and leaf values. """ if isinstance(tree, Arguments): for name, value in tree: self._flatten(params, value, "%s.%s" % (path, name)) elif isinstance(tree, dict): for name, value in tree.iteritems(): self._flatten(params, value, "%s.%s" % (path, name)) elif isinstance(tree, list): for index, value in enumerate(tree): self._flatten(params, value, "%s.%d" % (path, index + 1)) elif tree is not None: params[path.lstrip(".")] = tree else: # None is discarded. pass def extend(self, *schema_items): """ Add any number of schema items to a new schema. 
""" parameters = self._parameters.values() for item in schema_items: if isinstance(item, Parameter): parameters.append(item) else: raise TypeError("Illegal argument %s" % item) return Schema(*parameters) txAWS-0.2.3/txaws/server/resource.py0000664000175000017500000002557211741311335021101 0ustar oubiwannoubiwann00000000000000from datetime import datetime, timedelta from uuid import uuid4 from dateutil.tz import tzutc from twisted.python import log from twisted.python.reflect import safe_str from twisted.internet.defer import maybeDeferred from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from txaws.ec2.client import Signature from txaws.service import AWSServiceEndpoint from txaws.credentials import AWSCredentials from txaws.server.schema import ( Schema, Unicode, Integer, Enum, RawStr, Date) from txaws.server.exception import APIError from txaws.server.call import Call class QueryAPI(Resource): """Base class for EC2-like query APIs. @param registry: The L{Registry} to use to look up L{Method}s for handling the API requests. @param path: Optionally, the actual resource path the clients are using when sending HTTP requests to this API, to take into account when validating the signature. This can differ from the one in the HTTP request we're processing in case the service sits behind a reverse proxy, like Apache. For this works to work you have to make sure that 'path + path_of_the_rewritten_request' equals the resource path that clients are sending the request to. The following class variables must be defined by sub-classes: @ivar signature_versions: A list of allowed values for 'SignatureVersion'. @cvar content_type: The content type to set the 'Content-Type' header to. 
""" isLeaf = True time_format = "%Y-%m-%dT%H:%M:%SZ" schema = Schema( Unicode("Action"), RawStr("AWSAccessKeyId"), Date("Timestamp", optional=True), Date("Expires", optional=True), RawStr("Version", optional=True), Enum("SignatureMethod", {"HmacSHA256": "sha256", "HmacSHA1": "sha1"}, optional=True, default="HmacSHA256"), Unicode("Signature"), Integer("SignatureVersion", optional=True, default=2)) def __init__(self, registry=None, path=None): Resource.__init__(self) self.path = path self.registry = registry def get_method(self, call, *args, **kwargs): """Return the L{Method} instance to invoke for the given L{Call}. @param args: Positional arguments to pass to the method constructor. @param kwargs: Keyword arguments to pass to the method constructor. """ method_class = self.registry.get(call.action, call.version) method = method_class(*args, **kwargs) if not method.is_available(): raise APIError(400, "InvalidAction", "The action %s is not " "valid for this web service." % call.action) else: return method def get_principal(self, access_key): """Return a principal object by access key. The returned object must have C{access_key} and C{secret_key} attributes and if the authentication succeeds, it will be passed to the created L{Call}. """ raise NotImplemented("Must be implemented by subclasses") def handle(self, request): """Handle an HTTP request for executing an API call. This method authenticates the request checking its signature, and then calls the C{execute} method, passing it a L{Call} object set with the principal for the authenticated user and the generic parameters extracted from the request. @param request: The L{HTTPRequest} to handle. 
""" request.id = str(uuid4()) deferred = maybeDeferred(self._validate, request) deferred.addCallback(self.execute) def write_response(response): request.setHeader("Content-Length", str(len(response))) request.setHeader("Content-Type", self.content_type) request.write(response) request.finish() return response def write_error(failure): if failure.check(APIError): status = failure.value.status # Don't log the stack traces for 4xx responses. if status < 400 or status >= 500: log.err(failure) else: log.msg("status: %s message: %s" % ( status, safe_str(failure.value))) bytes = failure.value.response if bytes is None: bytes = self.dump_error(failure.value, request) else: log.err(failure) bytes = safe_str(failure.value) status = 500 request.setResponseCode(status) request.write(bytes) request.finish() deferred.addCallback(write_response) deferred.addErrback(write_error) return deferred def dump_error(self, error, request): """Serialize an error generating the response to send to the client. @param error: The L{APIError} to format. @param request: The request that generated the error. """ raise NotImplementedError("Must be implemented by subclass.") def dump_result(self, result): """Serialize the result of the method invokation. @param result: The L{Method} result to serialize. """ return result def authorize(self, method, call): """Authorize to invoke the given L{Method} with the given L{Call}.""" def execute(self, call): """Execute an API L{Call}. At this point the request has been authenticated and C{call.principal} is set with the L{Principal} for the L{User} requesting the call. @return: The response to write in the request for the given L{Call}. @raises: An L{APIError} in case the execution fails, sporting an error message the HTTP status code to return. 
""" method = self.get_method(call) deferred = maybeDeferred(self.authorize, method, call) deferred.addCallback(lambda _: method.invoke(call)) return deferred.addCallback(self.dump_result) def get_utc_time(self): """Return a C{datetime} object with the current time in UTC.""" return datetime.now(tzutc()) def _validate(self, request): """Validate an L{HTTPRequest} before executing it. The following conditions are checked: - The request contains all the generic parameters. - The action specified in the request is a supported one. - The signature mechanism is a supported one. - The provided signature matches the one calculated using the locally stored secret access key for the user. - The signature hasn't expired. @return: The validated L{Call}, set with its default arguments and the the principal of the accessing L{User}. """ params = dict((k, v[-1]) for k, v in request.args.iteritems()) args, rest = self.schema.extract(params) self._validate_generic_parameters(args) def create_call(principal): self._validate_principal(principal, args) self._validate_signature(request, principal, args, params) return Call(raw_params=rest, principal=principal, action=args.Action, version=args.Version, id=request.id) deferred = maybeDeferred(self.get_principal, args.AWSAccessKeyId) deferred.addCallback(create_call) return deferred def _validate_generic_parameters(self, args): """Validate the generic request parameters. @param args: Parsed schema arguments. @raises APIError: In the following cases: - Action is not included in C{self.actions} - SignatureVersion is not included in C{self.signature_versions} - Expires and Timestamp are present - Expires is before the current time - Timestamp is older than 15 minutes. """ utc_now = self.get_utc_time() if getattr(self, "actions", None) is not None: # Check the deprecated 'actions' attribute if not args.Action in self.actions: raise APIError(400, "InvalidAction", "The action %s is not " "valid for this web service." 
% args.Action) else: self.registry.check(args.Action, args.Version) if not args.SignatureVersion in self.signature_versions: raise APIError(403, "InvalidSignature", "SignatureVersion '%s' " "not supported" % args.SignatureVersion) if args.Expires and args.Timestamp: raise APIError(400, "InvalidParameterCombination", "The parameter Timestamp cannot be used with " "the parameter Expires") if args.Expires and args.Expires < utc_now: raise APIError(400, "RequestExpired", "Request has expired. Expires date is %s" % ( args.Expires.strftime(self.time_format))) if args.Timestamp and args.Timestamp + timedelta(minutes=15) < utc_now: raise APIError(400, "RequestExpired", "Request has expired. Timestamp date is %s" % ( args.Timestamp.strftime(self.time_format))) def _validate_principal(self, principal, args): """Validate the principal.""" if principal is None: raise APIError(401, "AuthFailure", "No user with access key '%s'" % args.AWSAccessKeyId) def _validate_signature(self, request, principal, args, params): """Validate the signature.""" creds = AWSCredentials(principal.access_key, principal.secret_key) endpoint = AWSServiceEndpoint() endpoint.set_method(request.method) endpoint.set_canonical_host(request.getHeader("Host")) path = request.path if self.path is not None: path = "%s/%s" % (self.path.rstrip("/"), path.lstrip("/")) endpoint.set_path(path) params.pop("Signature") signature = Signature(creds, endpoint, params) if signature.compute() != args.Signature: raise APIError(403, "SignatureDoesNotMatch", "The request signature we calculated does not " "match the signature you provided. 
Check your " "key and signing method.") def get_status_text(self): """Get the text to return when a status check is made.""" return "Query API Service" def render_GET(self, request): """Handle a GET request.""" if not request.args: request.setHeader("Content-Type", "text/plain") return self.get_status_text() else: self.handle(request) return NOT_DONE_YET render_POST = render_GET txAWS-0.2.3/txaws/server/registry.py0000664000175000017500000000427011741311335021112 0ustar oubiwannoubiwann00000000000000from txaws.server.exception import APIError class Registry(object): """Register API L{Method}s. for handling specific actions and versions""" def __init__(self): self._by_action = {} def add(self, method_class, action, version=None): """Add a method class to the regitry. @param method_class: The method class to add @param action: The action that the method class can handle @param version: The version that the method class can handle """ by_version = self._by_action.setdefault(action, {}) if version in by_version: raise RuntimeError("A method was already registered for action" " %s in version %s" % (action, version)) by_version[version] = method_class def check(self, action, version=None): """Check if the given action is supported in the given version. @raises APIError: If there's no method class registered for handling the given action or version. """ if action not in self._by_action: raise APIError(400, "InvalidAction", "The action %s is not valid " "for this web service." 
% action) by_version = self._by_action[action] if None not in by_version: # There's no catch-all method, let's try the version-specific one if version not in by_version: raise APIError(400, "InvalidVersion", "Invalid API version.") def get(self, action, version=None): """Get the method class handing the given action and version.""" by_version = self._by_action[action] if version in by_version: return by_version[version] else: return by_version[None] def scan(self, module, onerror=None, ignore=None): """Scan the given module object for L{Method}s and register them.""" from venusian import Scanner scanner = Scanner(registry=self) kwargs = {"onerror": onerror, "categories": ["method"]} if ignore is not None: # Only pass it if specified, for backward compatibility kwargs["ignore"] = ignore scanner.scan(module, **kwargs) txAWS-0.2.3/txaws/server/method.py0000664000175000017500000000324611741311335020524 0ustar oubiwannoubiwann00000000000000def method(method_class): """Decorator to use to mark an API method. When invoking L{Registry.scan} the classes marked with this decorator will be added to the registry. @param method_class: The L{Method} class to register. """ def callback(scanner, name, method_class): if method_class.actions is not None: actions = method_class.actions else: actions = [name] if method_class.versions is not None: versions = method_class.versions else: versions = [None] for action in actions: for version in versions: scanner.registry.add(method_class, action=action, version=version) from venusian import attach attach(method_class, callback, category="method") return method_class class Method(object): """Handle a single HTTP request to an API resource. @cvar actions: List of actions that the Method can handle, if C{None} the class name will be used as only supported action. @cvar versions: List of versions that the Method can handle, if C{None} all versions will be supported. 
""" actions = None versions = None def invoke(self, call): """Invoke this method for executing the given C{call}.""" raise NotImplemented("Sub-classes have to implement the invoke method") def is_available(self): """Return a boolean indicating wether this method is available. Override this to dynamically decide at run-time whether specific methods are available or not. """ return True txAWS-0.2.3/txaws/server/tests/0000775000175000017500000000000011741312025020024 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/test_registry.py0000664000175000017500000001050411741311335023310 0ustar oubiwannoubiwann00000000000000from twisted.trial.unittest import TestCase from txaws.server.method import Method from txaws.server.registry import Registry from txaws.server.exception import APIError try: from txaws.server.tests.fixtures import ( has_venusian, importerror, amodule) from txaws.server.tests.fixtures.amodule import TestMethod from txaws.server.tests.fixtures.importerror.amodule import ( TestMethod as testmethod) no_class_decorators = False except SyntaxError: no_class_decorators = True has_venusian = False class RegistryTestCase(TestCase): if no_class_decorators: skip = ("Your version of Python doesn't seem to support class " "decorators.") def setUp(self): super(RegistryTestCase, self).setUp() self.registry = Registry() def test_add(self): """ L{MethodRegistry.add} registers a method class for the given action and version. """ self.registry.add(TestMethod, "test", "1.0") self.registry.add(TestMethod, "test", "2.0") self.registry.check("test", "1.0") self.registry.check("test", "2.0") self.assertIdentical(TestMethod, self.registry.get("test", "1.0")) self.assertIdentical(TestMethod, self.registry.get("test", "2.0")) def test_add_duplicate_method(self): """ L{MethodRegistry.add} fails if a method class for the given action and version was already registered. 
""" class TestMethod2(Method): pass self.registry.add(TestMethod, "test", "1.0") self.assertRaises(RuntimeError, self.registry.add, TestMethod2, "test", "1.0") def test_get(self): """ L{MethodRegistry.get} returns the method class registered for the given action and version. """ class TestMethod2(Method): pass self.registry.add(TestMethod, "test", "1.0") self.registry.add(TestMethod, "test", "2.0") self.registry.add(TestMethod2, "test", "3.0") self.assertIdentical(TestMethod, self.registry.get("test", "1.0")) self.assertIdentical(TestMethod, self.registry.get("test", "2.0")) self.assertIdentical(TestMethod2, self.registry.get("test", "3.0")) def test_check_with_missing_action(self): """ L{MethodRegistry.get} fails if the given action is not registered. """ error = self.assertRaises(APIError, self.registry.check, "boom", "1.0") self.assertEqual(400, error.status) self.assertEqual("InvalidAction", error.code) self.assertEqual("The action boom is not valid for this web service.", error.message) def test_check_with_missing_version(self): """ L{MethodRegistry.get} fails if the given action is not registered. """ self.registry.add(TestMethod, "test", "1.0") error = self.assertRaises(APIError, self.registry.check, "test", "2.0") self.assertEqual(400, error.status) self.assertEqual("InvalidVersion", error.code) self.assertEqual("Invalid API version.", error.message) def test_scan(self): """ L{MethodRegistry.scan} registers the L{Method}s decorated with L{api}. """ self.registry.scan(amodule) self.assertIdentical(TestMethod, self.registry.get("TestMethod", None)) def test_scan_raises_error_on_importerror(self): """ L{MethodRegistry.scan} raises an error by default when an error happens and there is no onerror callback is passed. """ self.assertRaises(ImportError, self.registry.scan, importerror) def test_scan_swallows_with_onerror(self): """ L{MethodRegistry.scan} accepts an onerror callback that can be used to deal with scanning errors. 
""" swallowed = [] def swallow(error): swallowed.append(error) self.registry.scan(importerror, onerror=swallow) self.assertEqual(1, len(swallowed)) self.assertEqual(testmethod, self.registry.get("TestMethod")) if not has_venusian: test_scan.skip = "venusian module not available" test_scan_raises_error_on_importerror.skip = ( "venusian module not available") test_scan_swallows_with_onerror.skip = "venusian module not available" txAWS-0.2.3/txaws/server/tests/__init__.py0000664000175000017500000000000011741311335022126 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/test_call.py0000664000175000017500000000057211741311335022357 0ustar oubiwannoubiwann00000000000000from twisted.trial.unittest import TestCase from txaws.server.call import Call class CallTestCase(TestCase): def test_default_version(self): """ If no version is explicitly requested, C{version} is set to 2009-11-30, which is the earliest version we support. """ call = Call() self.assertEqual(call.version, "2009-11-30") txAWS-0.2.3/txaws/server/tests/fixtures/0000775000175000017500000000000011741312025021675 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/fixtures/__init__.py0000664000175000017500000000026611741311335024015 0ustar oubiwannoubiwann00000000000000try: import venusian except ImportError: method = lambda function: function has_venusian = False else: from txaws.server.method import method has_venusian = True txAWS-0.2.3/txaws/server/tests/fixtures/amodule.py0000664000175000017500000000020311741311335023673 0ustar oubiwannoubiwann00000000000000from txaws.server.tests.fixtures import method from txaws.server.method import Method @method class TestMethod(Method): pass txAWS-0.2.3/txaws/server/tests/fixtures/importerror/0000775000175000017500000000000011741312025024261 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/fixtures/importerror/__init__.py0000664000175000017500000000000011741311335026363 0ustar 
oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/fixtures/importerror/amodule.py0000664000175000017500000000020311741311335026257 0ustar oubiwannoubiwann00000000000000from txaws.server.method import Method from txaws.server.tests.fixtures import method @method class TestMethod(Method): pass txAWS-0.2.3/txaws/server/tests/fixtures/importerror/submodule/0000775000175000017500000000000011741312025026260 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/fixtures/importerror/submodule/__init__.py0000664000175000017500000000000011741311335030362 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/server/tests/fixtures/importerror/submodule/will_raise_import_error.py0000664000175000017500000000002411741311335033566 0ustar oubiwannoubiwann00000000000000import doesnt.exist txAWS-0.2.3/txaws/server/tests/test_resource.py0000664000175000017500000006142711741311335023301 0ustar oubiwannoubiwann00000000000000from cStringIO import StringIO from datetime import datetime from dateutil.tz import tzutc try: import json except ImportError: import simplejson as json from twisted.trial.unittest import TestCase from twisted.python.reflect import safe_str from txaws.credentials import AWSCredentials from txaws.service import AWSServiceEndpoint from txaws.ec2.client import Query from txaws.server.method import Method from txaws.server.registry import Registry from txaws.server.resource import QueryAPI from txaws.server.exception import APIError class FakeRequest(object): def __init__(self, params, endpoint): self.params = params self.endpoint = endpoint self.written = StringIO() self.finished = False self.code = None self.headers = {"Host": endpoint.get_canonical_host()} @property def args(self): return dict((key, [value]) for key, value in self.params.iteritems()) @property def method(self): return self.endpoint.method @property def path(self): return self.endpoint.path def write(self, content): assert isinstance(content, str), "Only strings should be 
written" self.written.write(content) def finish(self): if self.code is None: self.code = 200 self.finished = True def setResponseCode(self, code): self.code = code def setHeader(self, key, value): self.headers[key] = value def getHeader(self, key): return self.headers.get(key) @property def response(self): return self.written.getvalue() class TestMethod(Method): def invoke(self, call): return "data" class TestPrincipal(object): def __init__(self, creds): self.creds = creds @property def access_key(self): return self.creds.access_key @property def secret_key(self): return self.creds.secret_key class TestQueryAPI(QueryAPI): signature_versions = (1, 2) content_type = "text/plain" def __init__(self, *args, **kwargs): QueryAPI.__init__(self, *args, **kwargs) self.principal = None def get_principal(self, access_key): if self.principal and self.principal.access_key == access_key: return self.principal def dump_error(self, error, request): return str("%s - %s" % (error.code, safe_str(error.message))) class QueryAPITestCase(TestCase): def setUp(self): super(QueryAPITestCase, self).setUp() self.registry = Registry() self.registry.add(TestMethod, action="SomeAction", version=None) self.api = TestQueryAPI(registry=self.registry) def test_handle(self): """ L{QueryAPI.handle} forwards valid requests to L{QueryAPI.execute}. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertTrue(request.finished) self.assertEqual("data", request.response) self.assertEqual("4", request.headers["Content-Length"]) self.assertEqual("text/plain", request.headers["Content-Type"]) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_dump_result(self): """ L{QueryAPI.handle} serializes the action result with C{dump_result}. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", json.loads(request.response)) self.api.dump_result = json.dumps self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_deprecated_actions(self): """ L{QueryAPI.handle} supports the legacy 'actions' attribute. """ self.api.actions = ["SomeAction"] creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", request.response) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_pass_params_to_call(self): """ L{QueryAPI.handle} creates a L{Call} object with the correct parameters. 
""" self.registry.add(TestMethod, "SomeAction", "1.2.3") creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"Foo": "bar", "Version": "1.2.3"}) query.sign() request = FakeRequest(query.params, endpoint) def execute(call): self.assertEqual({"Foo": "bar"}, call.get_raw_params()) self.assertIdentical(self.api.principal, call.principal) self.assertEqual("SomeAction", call.action) self.assertEqual("1.2.3", call.version) self.assertEqual(request.id, call.id) return "ok" def check(ignored): self.assertEqual("ok", request.response) self.assertEqual(200, request.code) self.api.execute = execute self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_ensures_version_is_str(self): """ L{QueryAPI.schema} coerces the Version parameter to a str, in order to let URLs built with it be str, as required by urllib.quote in python 2.7. """ self.registry.add(TestMethod, "SomeAction", "1.2.3") creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"Version": u"1.2.3"}) query.sign() request = FakeRequest(query.params, endpoint) def execute(call): self.assertEqual("1.2.3", call.version) self.assertIsInstance(call.version, str) return "ok" self.api.execute = execute self.api.principal = TestPrincipal(creds) return self.api.handle(request) def test_handle_empty_request(self): """ If an empty request is received a message describing the API is returned. 
""" endpoint = AWSServiceEndpoint("http://uri") request = FakeRequest({}, endpoint) self.assertEqual("Query API Service", self.api.render(request)) self.assertEqual("text/plain", request.headers["Content-Type"]) self.assertEqual(None, request.code) def test_handle_with_signature_version_1(self): """SignatureVersion 1 is supported as well.""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"SignatureVersion": "1"}) query.sign() request = FakeRequest(query.params, endpoint) def check(ignore): self.assertEqual("data", request.response) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_signature_sha1(self): """ The C{HmacSHA1} signature method is supported, in which case the signing using sha1 instead. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign(hash_type="sha1") request = FakeRequest(query.params, endpoint) def check(ignore): self.assertEqual("data", request.response) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_unsupported_version(self): """If signature versions is not supported an error is raised.""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("InvalidSignature - SignatureVersion '2' " "not supported", request.response) self.assertEqual(403, request.code) self.api.signature_versions = (1,) return self.api.handle(request).addCallback(check) def 
test_handle_with_internal_error(self): """ If an unknown error occurs while handling the request, L{QueryAPI.handle} responds with HTTP status 500. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) self.api.execute = lambda call: 1 / 0 def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(1, len(errors)) self.assertTrue(request.finished) self.assertEqual("integer division or modulo by zero", request.response) self.assertEqual(500, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_500_api_error(self): """ If an L{APIError} is raised with a status code superior or equal to 500, the error is logged on the server side. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def fail_execute(call): raise APIError(500, response="oops") self.api.execute = fail_execute def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(1, len(errors)) self.assertTrue(request.finished) self.assertEqual("oops", request.response) self.assertEqual(500, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_parameter_error(self): """ If an error occurs while parsing the parameters, L{QueryAPI.handle} responds with HTTP status 400. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() query.params.pop("Action") request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("MissingParameter - The request must contain " "the parameter Action", request.response) self.assertEqual(400, request.code) return self.api.handle(request).addCallback(check) def test_handle_unicode_api_error(self): """ If an L{APIError} contains a unicode message, L{QueryAPI} is able to protect itself from it. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def fail_execute(call): raise APIError(400, code="LangError", message=u"\N{HIRAGANA LETTER A}dvanced") self.api.execute = fail_execute def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertTrue(request.finished) self.assertTrue(request.response.startswith("LangError")) self.assertEqual(400, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_unicode_error(self): """ If an arbitrary error raised by an API method contains a unicode message, L{QueryAPI} is able to protect itself from it. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def fail_execute(call): raise ValueError(u"\N{HIRAGANA LETTER A}dvanced") self.api.execute = fail_execute def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(1, len(errors)) self.assertTrue(request.finished) self.assertIn("ValueError", request.response) self.assertEqual(500, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_unsupported_action(self): """Only actions registered in the L{Registry} are supported.""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="FooBar", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("InvalidAction - The action FooBar is not valid" " for this web service.", request.response) self.assertEqual(400, request.code) return self.api.handle(request).addCallback(check) def test_handle_non_available_method(self): """Only actions registered in the L{Registry} are supported.""" class NonAvailableMethod(Method): def is_available(self): return False self.registry.add(NonAvailableMethod, action="CantDoIt") creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="CantDoIt", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("InvalidAction - The action CantDoIt is not " "valid for this web service.", request.response) self.assertEqual(400, request.code) self.api.principal = TestPrincipal(creds) return 
self.api.handle(request).addCallback(check) def test_handle_with_deprecated_actions_and_unsupported_action(self): """ If the deprecated L{QueryAPI.actions} attribute is set, it will be used for looking up supported actions. """ self.api.actions = ["SomeAction"] creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="FooBar", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("InvalidAction - The action FooBar is not valid" " for this web service.", request.response) self.assertEqual(400, request.code) return self.api.handle(request).addCallback(check) def test_handle_with_non_existing_user(self): """ If no L{Principal} can be found with the given access key ID, L{QueryAPI.handle} responds with HTTP status 400. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("AuthFailure - No user with access key 'access'", request.response) self.assertEqual(401, request.code) return self.api.handle(request).addCallback(check) def test_handle_with_wrong_signature(self): """ If the signature in the request doesn't match the one calculated with the locally stored secret access key, and error is returned. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() query.params["Signature"] = "wrong" request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual("SignatureDoesNotMatch - The request signature " "we calculated does not match the signature you " "provided. Check your key and signing method.", request.response) self.assertEqual(403, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_timestamp_and_expires(self): """ If the request contains both Expires and Timestamp parameters, an error is returned. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"Timestamp": "2010-01-01T12:00:00Z", "Expires": "2010-01-01T12:00:00Z"}) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual( "InvalidParameterCombination - The parameter Timestamp" " cannot be used with the parameter Expires", request.response) self.assertEqual(400, request.code) return self.api.handle(request).addCallback(check) def test_handle_with_non_expired_signature(self): """ If the request contains an Expires parameter with a time that is after the current time, everything is fine. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"Expires": "2010-01-01T12:00:00Z"}) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", request.response) self.assertEqual(200, request.code) now = datetime(2009, 12, 31, tzinfo=tzutc()) self.api.get_utc_time = lambda: now self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_expired_signature(self): """ If the request contains an Expires parameter with a time that is before the current time, an error is returned. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri") query = Query(action="SomeAction", creds=creds, endpoint=endpoint, other_params={"Expires": "2010-01-01T12:00:00Z"}) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): errors = self.flushLoggedErrors() self.assertEquals(0, len(errors)) self.assertEqual( "RequestExpired - Request has expired. Expires date is" " 2010-01-01T12:00:00Z", request.response) self.assertEqual(400, request.code) now = datetime(2010, 1, 1, 12, 0, 1, tzinfo=tzutc()) self.api.get_utc_time = lambda: now return self.api.handle(request).addCallback(check) def test_handle_with_post_method(self): """ L{QueryAPI.handle} forwards valid requests using the HTTP POST method to L{QueryAPI.execute}. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://uri", method="POST") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", request.response) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_port_number(self): """ If the request Host header includes a port number, it's included in the text that get signed when checking the signature. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://endpoint:1234") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", request.response) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_endpoint_with_terminating_slash(self): """ Check signature should handle a URI with a terminating slash. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://endpoint/") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) def check(ignored): self.assertEqual("data", request.response) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) return self.api.handle(request).addCallback(check) def test_handle_with_custom_path(self): """ If L{QueryAPI.path} is not C{None} it will be used in place of the HTTP request path when calculating the signature. 
""" creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://endpoint/path/") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) # Simulate a request rewrite, like apache would do request.endpoint.path = "/" def check(ignored): self.assertTrue(request.finished) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) self.api.path = "/path/" return self.api.handle(request).addCallback(check) def test_handle_with_custom_path_and_rest(self): """ If L{QueryAPI.path} is not C{None} it will be used in place of the HTTP request path when calculating the signature. The rest of the path is appended as for the HTTP request. """ creds = AWSCredentials("access", "secret") endpoint = AWSServiceEndpoint("http://endpoint/path/rest") query = Query(action="SomeAction", creds=creds, endpoint=endpoint) query.sign() request = FakeRequest(query.params, endpoint) # Simulate a request rewrite, like apache would do request.endpoint.path = "/rest" def check(ignored): self.assertTrue(request.finished) self.assertEqual(200, request.code) self.api.principal = TestPrincipal(creds) self.api.path = "/path/" return self.api.handle(request).addCallback(check) txAWS-0.2.3/txaws/server/tests/test_exception.py0000664000175000017500000000403611741311335023441 0ustar oubiwannoubiwann00000000000000# -*- coding: utf-8 -*- from unittest import TestCase from txaws.server.exception import APIError class APIErrorTestCase(TestCase): def test_with_no_parameters(self): """ The L{APIError} constructor must be passed either a code/message pair or a full response payload. """ self.assertRaises(RuntimeError, APIError, 400) def test_with_response_and_code(self): """ If the L{APIError} constructor is passed a full response payload, it can't be passed an error code. 
""" self.assertRaises(RuntimeError, APIError, 400, code="FooBar", response="foo bar") def test_with_response_and_message(self): """ If the L{APIError} constructor is passed a full response payload, it can't be passed an error code. """ self.assertRaises(RuntimeError, APIError, 400, message="Foo Bar", response="foo bar") def test_with_code_and_no_message(self): """ If the L{APIError} constructor is passed an error code, it must be passed an error message as well. """ self.assertRaises(RuntimeError, APIError, 400, code="FooBar") def test_with_message_and_no_code(self): """ If the L{APIError} constructor is passed an error message, it must be passed an error code as well. """ self.assertRaises(RuntimeError, APIError, 400, message="Foo Bar") def test_with_string_status(self): """ The L{APIError} constructor can be passed a C{str} as status code, and it will be converted to C{intp}. """ error = APIError("200", response="noes") self.assertEqual(200, error.status) def test_with_unicode_message(self): """ L{APIError} will convert message to plain ASCII if converted to string. 
""" error = APIError(400, code="APIError", message=u"cittá") self.assertEqual(u"cittá", error.message) self.assertEqual("citt?", str(error)) txAWS-0.2.3/txaws/server/tests/test_schema.py0000664000175000017500000005605011741311335022706 0ustar oubiwannoubiwann00000000000000# -*- coding: utf-8 -*- from datetime import datetime from dateutil.tz import tzutc, tzoffset from twisted.trial.unittest import TestCase from txaws.server.exception import APIError from txaws.server.schema import ( Arguments, Bool, Date, Enum, Integer, Parameter, RawStr, Schema, Unicode) class ArgumentsTestCase(TestCase): def test_instantiate_empty(self): """Creating an L{Arguments} object.""" arguments = Arguments({}) self.assertEqual({}, arguments.__dict__) def test_instantiate_non_empty(self): """Creating an L{Arguments} object with some arguments.""" arguments = Arguments({"foo": 123, "bar": 456}) self.assertEqual(123, arguments.foo) self.assertEqual(456, arguments.bar) def test_iterate(self): """L{Arguments} returns an iterator with both keys and values.""" arguments = Arguments({"foo": 123, "bar": 456}) self.assertEqual([("foo", 123), ("bar", 456)], list(arguments)) def test_getitem(self): """Values can be looked up using C{[index]} notation.""" arguments = Arguments({1: "a", 2: "b", "foo": "bar"}) self.assertEqual("b", arguments[2]) self.assertEqual("bar", arguments["foo"]) def test_getitem_error(self): """L{KeyError} is raised when the argument is not found.""" arguments = Arguments({}) self.assertRaises(KeyError, arguments.__getitem__, 1) def test_len(self): """C{len()} can be used with an L{Arguments} instance.""" self.assertEqual(0, len(Arguments({}))) self.assertEqual(1, len(Arguments({1: 2}))) def test_nested_data(self): """L{Arguments} can cope fine with nested data structures.""" arguments = Arguments({"foo": Arguments({"bar": "egg"})}) self.assertEqual("egg", arguments.foo.bar) def test_nested_data_with_numbers(self): """L{Arguments} can cope fine with list items.""" arguments = 
Arguments({"foo": {1: "egg"}}) self.assertEqual("egg", arguments.foo[0]) class ParameterTestCase(TestCase): def test_coerce(self): """ L{Parameter.coerce} coerces a request argument with a single value. """ parameter = Parameter("Test") parameter.parse = lambda value: value self.assertEqual("foo", parameter.coerce("foo")) def test_coerce_with_optional(self): """L{Parameter.coerce} returns C{None} if the parameter is optional.""" parameter = Parameter("Test", optional=True) self.assertEqual(None, parameter.coerce(None)) def test_coerce_with_required(self): """ L{Parameter.coerce} raises an L{APIError} if the parameter is required but not present in the request. """ parameter = Parameter("Test") error = self.assertRaises(APIError, parameter.coerce, None) self.assertEqual(400, error.status) self.assertEqual("MissingParameter", error.code) self.assertEqual("The request must contain the parameter Test", error.message) def test_coerce_with_default(self): """ L{Parameter.coerce} returns F{Parameter.default} if the parameter is optional and not present in the request. """ parameter = Parameter("Test", optional=True, default=123) self.assertEqual(123, parameter.coerce(None)) def test_coerce_with_parameter_error(self): """ L{Parameter.coerce} raises an L{APIError} if an invalid value is passed as request argument. """ parameter = Parameter("Test") parameter.parse = lambda value: int(value) parameter.kind = "integer" error = self.assertRaises(APIError, parameter.coerce, "foo") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertEqual("Invalid integer value foo", error.message) def test_coerce_with_parameter_error_unicode(self): """ L{Parameter.coerce} raises an L{APIError} if an invalid value is passed as request argument and parameter value is unicode. 
""" parameter = Parameter("Test") parameter.parse = lambda value: int(value) parameter.kind = "integer" error = self.assertRaises(APIError, parameter.coerce, "citt\xc3\xa1") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertEqual(u"Invalid integer value cittá", error.message) def test_coerce_with_empty_strings(self): """ L{Parameter.coerce} returns C{None} if the value is an empty string and C{allow_none} is C{True}. """ parameter = Parameter("Test", allow_none=True) self.assertEqual(None, parameter.coerce("")) def test_coerce_with_empty_strings_error(self): """ L{Parameter.coerce} raises an error if the value is an empty string and C{allow_none} is not C{True}. """ parameter = Parameter("Test") error = self.assertRaises(APIError, parameter.coerce, "") self.assertEqual(400, error.status) self.assertEqual("MissingParameter", error.code) self.assertEqual("The request must contain the parameter Test", error.message) def test_coerce_with_min(self): """ L{Parameter.coerce} raises an error if the given value is lower than the lower bound. """ parameter = Parameter("Test", min=50) parameter.measure = lambda value: int(value) parameter.lower_than_min_template = "Please give me at least %s" error = self.assertRaises(APIError, parameter.coerce, "4") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertEqual("Value (4) for parameter Test is invalid. " "Please give me at least 50", error.message) def test_coerce_with_max(self): """ L{Parameter.coerce} raises an error if the given value is greater than the upper bound. 
""" parameter = Parameter("Test", max=3) parameter.measure = lambda value: len(value) parameter.greater_than_max_template = "%s should be enough for anybody" error = self.assertRaises(APIError, parameter.coerce, "longish") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertEqual("Value (longish) for parameter Test is invalid. " "3 should be enough for anybody", error.message) def test_validator_invalid(self): """ L{Parameter.coerce} raises an error if the validator returns False. """ parameter = Parameter("Test", validator=lambda _: False) parameter.parse = lambda value: value parameter.kind = "test_parameter" error = self.assertRaises(APIError, parameter.coerce, "foo") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertEqual("Invalid test_parameter value foo", error.message) def test_validator_valid(self): """ L{Parameter.coerce} returns the correct value if validator returns True. """ parameter = Parameter("Test", validator=lambda _: True) parameter.parse = lambda value: value parameter.kind = "test_parameter" self.assertEqual("foo", parameter.coerce("foo")) class UnicodeTestCase(TestCase): def test_parse(self): """L{Unicode.parse} converts the given raw C{value} to C{unicode}.""" parameter = Unicode("Test") self.assertEqual(u"foo", parameter.parse("foo")) def test_parse_unicode(self): """L{Unicode.parse} works with unicode input.""" parameter = Unicode("Test") self.assertEqual(u"cittá", parameter.parse("citt\xc3\xa1")) def test_format(self): """L{Unicode.format} encodes the given C{unicode} with utf-8.""" parameter = Unicode("Test") value = parameter.format(u"fo\N{TAGBANWA LETTER SA}") self.assertEqual("fo\xe1\x9d\xb0", value) self.assertTrue(isinstance(value, str)) def test_min_and_max(self): """The L{Unicode} parameter properly supports ranges.""" parameter = Unicode("Test", min=2, max=4) error = self.assertRaises(APIError, parameter.coerce, "a") 
self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertIn("Length must be at least 2.", error.message) error = self.assertRaises(APIError, parameter.coerce, "abcde") self.assertIn("Length exceeds maximum of 4.", error.message) self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) def test_invalid_unicode(self): """ The L{Unicode} parameter returns an error with invalid unicode data. """ parameter = Unicode("Test") error = self.assertRaises(APIError, parameter.coerce, "Test\x95Error") self.assertIn(u"Invalid unicode value", error.message) self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) class RawStrTestCase(TestCase): def test_parse(self): """L{RawStr.parse} checks that the given raw C{value} is a string.""" parameter = RawStr("Test") self.assertEqual("foo", parameter.parse("foo")) def test_format(self): """L{RawStr.format} simply returns the given string.""" parameter = RawStr("Test") value = parameter.format("foo") self.assertEqual("foo", value) self.assertTrue(isinstance(value, str)) class IntegerTestCase(TestCase): def test_parse(self): """L{Integer.parse} converts the given raw C{value} to C{int}.""" parameter = Integer("Test") self.assertEqual(123, parameter.parse("123")) def test_parse_with_negative(self): """L{Integer.parse} converts the given raw C{value} to C{int}.""" parameter = Integer("Test") error = self.assertRaises(APIError, parameter.coerce, "-1") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertIn("Value must be at least 0.", error.message) def test_format(self): """L{Integer.format} converts the given integer to a string.""" parameter = Integer("Test") self.assertEqual("123", parameter.format(123)) def test_min_and_max(self): """The L{Integer} parameter properly supports ranges.""" parameter = Integer("Test", min=2, max=4) error = self.assertRaises(APIError, parameter.coerce, 
"1") self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertIn("Value must be at least 2.", error.message) error = self.assertRaises(APIError, parameter.coerce, "5") self.assertIn("Value exceeds maximum of 4.", error.message) self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) def test_non_integer_string(self): """ The L{Integer} parameter raises an L{APIError} when passed non-int values (in this case, a string). """ garbage = "blah" parameter = Integer("Test") error = self.assertRaises(APIError, parameter.coerce, garbage) self.assertEqual(400, error.status) self.assertEqual("InvalidParameterValue", error.code) self.assertIn("Invalid integer value %s" % garbage, error.message) class BoolTestCase(TestCase): def test_parse(self): """L{Bool.parse} converts 'true' to C{True}.""" parameter = Bool("Test") self.assertEqual(True, parameter.parse("true")) def test_parse_with_false(self): """L{Bool.parse} converts 'false' to C{False}.""" parameter = Bool("Test") self.assertEqual(False, parameter.parse("false")) def test_parse_with_error(self): """ L{Bool.parse} raises C{ValueError} if the given value is neither 'true' or 'false'. """ parameter = Bool("Test") self.assertRaises(ValueError, parameter.parse, "0") def test_format(self): """L{Bool.format} converts the given boolean to either '0' or '1'.""" parameter = Bool("Test") self.assertEqual("true", parameter.format(True)) self.assertEqual("false", parameter.format(False)) class EnumTestCase(TestCase): def test_parse(self): """L{Enum.parse} accepts a map for translating values.""" parameter = Enum("Test", {"foo": "bar"}) self.assertEqual("bar", parameter.parse("foo")) def test_parse_with_error(self): """ L{Bool.parse} raises C{ValueError} if the given value is not present in the mapping. 
""" parameter = Enum("Test", {}) self.assertRaises(ValueError, parameter.parse, "bar") def test_format(self): """L{Enum.format} converts back the given value to the original map.""" parameter = Enum("Test", {"foo": "bar"}) self.assertEqual("foo", parameter.format("bar")) class DateTestCase(TestCase): def test_parse(self): """L{Date.parse checks that the given raw C{value} is a date/time.""" parameter = Date("Test") date = datetime(2010, 9, 15, 23, 59, 59, tzinfo=tzutc()) self.assertEqual(date, parameter.parse("2010-09-15T23:59:59Z")) def test_format(self): """ L{Date.format} returns a string representation of the given datetime instance. """ parameter = Date("Test") date = datetime(2010, 9, 15, 23, 59, 59, tzinfo=tzoffset('UTC', 120 * 60)) self.assertEqual("2010-09-15T21:59:59Z", parameter.format(date)) class SchemaTestCase(TestCase): def test_extract(self): """ L{Schema.extract} returns an L{Argument} object whose attributes are the arguments extracted from the given C{request}, as specified. """ schema = Schema(Unicode("name")) arguments, _ = schema.extract({"name": "value"}) self.assertEqual("value", arguments.name) def test_extract_with_rest(self): """ L{Schema.extract} stores unknown parameters in the 'rest' return dictionary. 
""" schema = Schema() _, rest = schema.extract({"name": "value"}) self.assertEqual(rest, {"name": "value"}) def test_extract_with_many_arguments(self): """L{Schema.extract} can handle multiple parameters.""" schema = Schema(Unicode("name"), Integer("count")) arguments, _ = schema.extract({"name": "value", "count": "123"}) self.assertEqual(u"value", arguments.name) self.assertEqual(123, arguments.count) def test_extract_with_optional(self): """L{Schema.extract} can handle optional parameters.""" schema = Schema(Unicode("name"), Integer("count", optional=True)) arguments, _ = schema.extract({"name": "value"}) self.assertEqual(u"value", arguments.name) self.assertEqual(None, arguments.count) def test_extract_with_numbered(self): """ L{Schema.extract} can handle parameters with numbered values. """ schema = Schema(Unicode("name.n")) arguments, _ = schema.extract({"name.0": "Joe", "name.1": "Tom"}) self.assertEqual("Joe", arguments.name[0]) self.assertEqual("Tom", arguments.name[1]) def test_extract_with_single_numbered(self): """ L{Schema.extract} can handle a single parameter with a numbered value. 
""" schema = Schema(Unicode("name.n")) arguments, _ = schema.extract({"name.0": "Joe"}) self.assertEqual("Joe", arguments.name[0]) def test_extract_complex(self): """L{Schema} can cope with complex schemas.""" schema = Schema( Unicode("GroupName"), RawStr("IpPermissions.n.IpProtocol"), Integer("IpPermissions.n.FromPort"), Integer("IpPermissions.n.ToPort"), Unicode("IpPermissions.n.Groups.m.UserId", optional=True), Unicode("IpPermissions.n.Groups.m.GroupName", optional=True)) arguments, _ = schema.extract( {"GroupName": "Foo", "IpPermissions.1.IpProtocol": "tcp", "IpPermissions.1.FromPort": "1234", "IpPermissions.1.ToPort": "5678", "IpPermissions.1.Groups.1.GroupName": "Bar", "IpPermissions.1.Groups.2.GroupName": "Egg"}) self.assertEqual(u"Foo", arguments.GroupName) self.assertEqual(1, len(arguments.IpPermissions)) self.assertEqual(1234, arguments.IpPermissions[0].FromPort) self.assertEqual(5678, arguments.IpPermissions[0].ToPort) self.assertEqual(2, len(arguments.IpPermissions[0].Groups)) self.assertEqual("Bar", arguments.IpPermissions[0].Groups[0].GroupName) self.assertEqual("Egg", arguments.IpPermissions[0].Groups[1].GroupName) def test_extract_with_multiple_parameters_in_singular_schema(self): """ If multiple parameters are passed in to a Schema element that is not flagged as supporting multiple values then we should throw an C{APIError}. """ schema = Schema(Unicode("name")) params = {"name.1": "value", "name.2": "value2"} error = self.assertRaises(APIError, schema.extract, params) self.assertEqual(400, error.status) self.assertEqual("InvalidParameterCombination", error.code) self.assertEqual("The parameter 'name' may only be specified once.", error.message) def test_extract_with_mixed(self): """ L{Schema.extract} stores in the rest result all numbered parameters given without an index. 
""" schema = Schema(Unicode("name.n")) _, rest = schema.extract({"name": "foo", "name.1": "bar"}) self.assertEqual(rest, {"name": "foo"}) def test_extract_with_non_numbered_template(self): """ L{Schema.extract} accepts a single numbered argument even if the associated template is not numbered. """ schema = Schema(Unicode("name")) arguments, _ = schema.extract({"name.1": "foo"}) self.assertEqual("foo", arguments.name) def test_extract_with_non_integer_index(self): """ L{Schema.extract} raises an error when trying to pass a numbered parameter with a non-integer index. """ schema = Schema(Unicode("name.n")) params = {"name.one": "foo"} error = self.assertRaises(APIError, schema.extract, params) self.assertEqual(400, error.status) self.assertEqual("UnknownParameter", error.code) self.assertEqual("The parameter name.one is not recognized", error.message) def test_extract_with_negative_index(self): """ L{Schema.extract} raises an error when trying to pass a numbered parameter with a negative index. """ schema = Schema(Unicode("name.n")) params = {"name.-1": "foo"} error = self.assertRaises(APIError, schema.extract, params) self.assertEqual(400, error.status) self.assertEqual("UnknownParameter", error.code) self.assertEqual("The parameter name.-1 is not recognized", error.message) def test_bundle(self): """ L{Schema.bundle} returns a dictionary of raw parameters that can be used for an EC2-style query. """ schema = Schema(Unicode("name")) params = schema.bundle(name="foo") self.assertEqual({"name": "foo"}, params) def test_bundle_with_numbered(self): """ L{Schema.bundle} correctly handles numbered arguments. 
""" schema = Schema(Unicode("name.n")) params = schema.bundle(name=["foo", "bar"]) self.assertEqual({"name.1": "foo", "name.2": "bar"}, params) def test_bundle_with_none(self): """L{None} values are discarded in L{Schema.bundle}.""" schema = Schema(Unicode("name.n", optional=True)) params = schema.bundle(name=None) self.assertEqual({}, params) def test_bundle_with_empty_numbered(self): """ L{Schema.bundle} correctly handles an empty numbered arguments list. """ schema = Schema(Unicode("name.n")) params = schema.bundle(names=[]) self.assertEqual({}, params) def test_bundle_with_numbered_not_supplied(self): """ L{Schema.bundle} ignores parameters that are not present. """ schema = Schema(Unicode("name.n")) params = schema.bundle() self.assertEqual({}, params) def test_bundle_with_multiple(self): """ L{Schema.bundle} correctly handles multiple arguments. """ schema = Schema(Unicode("name.n"), Integer("count")) params = schema.bundle(name=["Foo", "Bar"], count=123) self.assertEqual({"name.1": "Foo", "name.2": "Bar", "count": "123"}, params) def test_bundle_with_arguments(self): """L{Schema.bundle} can bundle L{Arguments} too.""" schema = Schema(Unicode("name.n"), Integer("count")) arguments = Arguments({"name": Arguments({1: "Foo", 7: "Bar"}), "count": 123}) params = schema.bundle(arguments) self.assertEqual({"name.1": "Foo", "name.7": "Bar", "count": "123"}, params) def test_bundle_with_arguments_and_extra(self): """ L{Schema.bundle} can bundle L{Arguments} with keyword arguments too. Keyword arguments take precedence. """ schema = Schema(Unicode("name.n"), Integer("count")) arguments = Arguments({"name": {1: "Foo", 7: "Bar"}, "count": 321}) params = schema.bundle(arguments, count=123) self.assertEqual({"name.1": "Foo", "name.2": "Bar", "count": "123"}, params) def test_bundle_with_missing_parameter(self): """ L{Schema.bundle} raises an exception one of the given parameters doesn't exist in the schema. 
""" schema = Schema(Integer("count")) self.assertRaises(RuntimeError, schema.bundle, name="foo") def test_add_single_extra_schema_item(self): """New Parameters can be added to the Schema.""" schema = Schema(Unicode("name")) schema = schema.extend(Unicode("computer")) arguments, _ = schema.extract({"name": "value", "computer": "testing"}) self.assertEqual(u"value", arguments.name) self.assertEqual("testing", arguments.computer) def test_add_extra_schema_items(self): """A list of new Parameters can be added to the Schema.""" schema = Schema(Unicode("name")) schema = schema.extend(Unicode("computer"), Integer("count")) arguments, _ = schema.extract({"name": "value", "computer": "testing", "count": "5"}) self.assertEqual(u"value", arguments.name) self.assertEqual("testing", arguments.computer) self.assertEqual(5, arguments.count) txAWS-0.2.3/txaws/server/tests/test_method.py0000664000175000017500000000076611741311335022731 0ustar oubiwannoubiwann00000000000000from twisted.trial.unittest import TestCase from txaws.server.method import Method class MethodTestCase(TestCase): def setUp(self): super(MethodTestCase, self).setUp() self.method = Method() def test_defaults(self): """ By default a L{Method} applies to all API versions and handles a single action matching its class name. """ self.assertIdentical(None, self.method.actions) self.assertIdentical(None, self.method.versions) txAWS-0.2.3/txaws/server/call.py0000664000175000017500000000453311741311335020157 0ustar oubiwannoubiwann00000000000000from uuid import uuid4 from txaws.version import ec2_api as ec2_api_version from txaws.server.exception import APIError class Call(object): """Hold information about a single API call initiated by an HTTP request. @param raw_params: The raw parameters for the action to be executed, the format is a dictionary mapping parameter names to parameter values, like C{{'ParamName': param_value}}. @param principal: The principal issuing this API L{Call}. 
@param action: The action to be performed. @ivar id: A unique identifier for the API call. @ivar principal: The principal performing the call. @ivar args: An L{Arguments} object holding parameters extracted from the raw parameters according to a L{Schema}, it will be available after calling the C{parse} method. @ivar rest: Extra parameters not included in the given arguments schema, it will be available after calling the L{parse} method. @ivar version: The version of the API call. Defaults to 2009-11-30. """ def __init__(self, raw_params=None, principal=None, action=None, version=None, id=None): if id is None: id = str(uuid4()) self.id = id self._raw_params = {} if raw_params is not None: self._raw_params.update(raw_params) self.action = action if version is None: version = ec2_api_version self.version = version self.principal = principal def parse(self, schema, strict=True): """Update C{args} and C{rest}, parsing the raw request arguments. @param schema: The L{Schema} the parameters must be extracted with. @param strict: If C{True} an error is raised if parameters not included in the schema are found, otherwise the extra parameters will be saved in the C{rest} attribute. """ self.args, self.rest = schema.extract(self._raw_params) if strict and self.rest: raise APIError(400, "UnknownParameter", "The parameter %s is not " "recognized" % self.rest.keys()[0]) def get_raw_params(self): """Return a C{dict} holding the raw API call paramaters. The format of the dictionary is C{{'ParamName': param_value}}. """ return self._raw_params.copy() txAWS-0.2.3/txaws/service.py0000664000175000017500000001241311741311335017372 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Duncan McGreggor # Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. 
from txaws.credentials import AWSCredentials from txaws import regions from txaws.util import parse __all__ = ["AWSServiceEndpoint", "AWSServiceRegion", "REGION_US", "REGION_EU"] # These old variable names are maintained for backwards compatibility. REGION_US = regions.REGION_US REGION_EU = regions.REGION_EU EC2_ENDPOINT_US = regions.EC2_ENDPOINT_US EC2_ENDPOINT_EU = regions.EC2_ENDPOINT_EU S3_ENDPOINT = regions.S3_ENDPOINT class AWSServiceEndpoint(object): """ @param uri: The URL for the service. @param method: The HTTP method used when accessing a service. @param ssl_hostname_verification: Whether or not SSL hotname verification will be done when connecting to the endpoint. """ def __init__(self, uri="", method="GET", ssl_hostname_verification=False): self.host = "" self.port = None self.path = "/" self.method = method self.ssl_hostname_verification = ssl_hostname_verification self._parse_uri(uri) if not self.scheme: self.scheme = "http" def _parse_uri(self, uri): scheme, host, port, path = parse( str(uri), defaultPort=False) self.scheme = scheme self.host = host self.port = port self.path = path def set_host(self, host): self.host = host def get_host(self): return self.host def get_canonical_host(self): """ Return the canonical host as for the Host HTTP header specification. """ host = self.host.lower() if self.port is not None: host = "%s:%s" % (host, self.port) return host def set_canonical_host(self, canonical_host): """ Set host and port from a canonical host string as for the Host HTTP header specification. 
""" parts = canonical_host.lower().split(":") self.host = parts[0] if len(parts) > 1 and parts[1]: self.port = int(parts[1]) else: self.port = None def set_path(self, path): self.path = path def get_uri(self): """Get a URL representation of the service.""" uri = "%s://%s%s" % (self.scheme, self.get_canonical_host(), self.path) return uri def set_method(self, method): self.method = method class AWSServiceRegion(object): """ This object represents a collection of client factories that use the same credentials. With Amazon, this collection is associated with a region (e.g., US or EU). @param creds: an AWSCredentials instance, optional. @param access_key: The access key to use. This is only checked if no creds parameter was passed. @param secret_key: The secret key to use. This is only checked if no creds parameter was passed. @param region: a string value that represents the region that the associated creds will be used against a collection of services. @param uri: an endpoint URI that, if provided, will override the region parameter. @param method: The method argument forwarded to L{AWSServiceEndpoint}. """ # XXX update unit test to check for both ec2 and s3 endpoints def __init__(self, creds=None, access_key="", secret_key="", region=REGION_US, uri="", ec2_uri="", s3_uri="", method="GET"): if not creds: creds = AWSCredentials(access_key, secret_key) self.creds = creds # Provide backwards compatibility for the "uri" parameter. 
if uri and not ec2_uri: ec2_uri = uri if not ec2_uri and region == REGION_US: ec2_uri = EC2_ENDPOINT_US elif not ec2_uri and region == REGION_EU: ec2_uri = EC2_ENDPOINT_EU if not s3_uri: s3_uri = S3_ENDPOINT self._clients = {} self.ec2_endpoint = AWSServiceEndpoint(uri=ec2_uri, method=method) self.s3_endpoint = AWSServiceEndpoint(uri=s3_uri, method=method) def get_client(self, cls, purge_cache=False, *args, **kwds): """ This is a general method for getting a client: if present, it is pulled from the cache; if not, a new one is instantiated and then put into the cache. This method should not be called directly, but rather by other client-specific methods (e.g., get_ec2_client). """ key = str(cls) + str(args) + str(kwds) instance = self._clients.get(key) if purge_cache or not instance: instance = cls(*args, **kwds) self._clients[key] = instance return instance def get_ec2_client(self, creds=None): from txaws.ec2.client import EC2Client if creds: self.creds = creds return self.get_client(EC2Client, creds=self.creds, endpoint=self.ec2_endpoint, query_factory=None) def get_s3_client(self, creds=None): from txaws.s3.client import S3Client if creds: self.creds = creds return self.get_client(S3Client, creds=self.creds, endpoint=self.s3_endpoint, query_factory=None) txAWS-0.2.3/txaws/client/0000775000175000017500000000000011741312025016632 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/discover/0000775000175000017500000000000011741312025020450 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/discover/__init__.py0000664000175000017500000000000011741311335022552 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/discover/command.py0000664000175000017500000000637211741311335022453 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010 Jamu Kakar # Licenced under the txaws licence available at /LICENSE in the txaws source. 
""" A L{Command} object makes an arbitrary EC2 API method call and displays the response received from the backend cloud. """ import sys from txaws.ec2.client import Query from txaws.exception import AWSError from txaws.service import AWSServiceRegion class Command(object): """ An EC2 API method call command that can make a request and display the response received from the backend cloud. @param key: The AWS access key ID to use when making the method call. @param secret: The AWS secret key to sign the method call with. @param endpoint: The URL of the cloud to invoke the method on. @param parameters: A C{dict} with parameters to include with the method call. @param output: Optionally, a stream to write output to. Defaults to C{sys.stdout}. @param query_factory: Optionally, a factory to create the L{Query} object used to invoke the method. Defaults to returning a L{Query} instance. """ def __init__(self, key, secret, endpoint, action, parameters, output=None, query_factory=None): self.key = key self.secret = secret self.endpoint = endpoint self.action = action self.parameters = parameters if output is None: output = sys.stdout self.output = output if query_factory is None: query_factory = Query self.query_factory = query_factory def run(self): """ Run the configured method and write the HTTP response status and text to the output stream. 
""" region = AWSServiceRegion(access_key=self.key, secret_key=self.secret, uri=self.endpoint) query = self.query_factory(action=self.action, creds=region.creds, endpoint=region.ec2_endpoint, other_params=self.parameters) def write_response(response): print >> self.output, "URL: %s" % query.client.url print >> self.output print >> self.output, "HTTP status code: %s" % query.client.status print >> self.output print >> self.output, response def write_error(failure): if failure.check(AWSError): message = failure.value.original else: message = failure.getErrorMessage() if message.startswith("Error Message: "): message = message[len("Error Message: "):] print >> self.output, "URL: %s" % query.client.url print >> self.output if getattr(query.client, "status", None) is not None: print >> self.output, "HTTP status code: %s" % ( query.client.status,) print >> self.output print >> self.output, message if getattr(failure.value, "response", None) is not None: print >> self.output print >> self.output, failure.value.response deferred = query.submit() deferred.addCallback(write_response) deferred.addErrback(write_error) return deferred txAWS-0.2.3/txaws/client/discover/entry_point.py0000664000175000017500000001526511741311335023410 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010 Jamu Kakar # Licenced under the txaws licence available at /LICENSE in the txaws source. """A command-line client for discovering how the EC2 API works.""" import os import sys from txaws.client.discover.command import Command class OptionError(Exception): """ Raised if insufficient command-line arguments are provided when creating a L{Command}. """ class UsageError(Exception): """Raised if the usage message should be shown.""" USAGE_MESSAGE = """\ Purpose: Invoke an EC2 API method with arbitrary parameters. Usage: txaws-discover [--key KEY] [--secret SECRET] [--endpoint ENDPOINT] --action ACTION [PARAMETERS, ...] Options: --key The AWS access key to use when making the API request. 
--secret The AWS secret key to use when making the API request. --endpoint The region endpoint to make the API request against. --action The name of the EC2 API to invoke. -h, --help Show help message. Description: The purpose of this program is to aid discovery of the EC2 API. It can run any EC2 API method, with arbitrary parameters. The response received from the backend cloud is printed to the screen, to show exactly what happened in response to the request. The --key, --secret, --endpoint and --action command-line arguments are required. If AWS_ENDPOINT, AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY environment variables are defined the corresponding options can be omitted and the values defined in the environment variables will be used. Any additional parameters, beyond those defined above, will be included with the request as method parameters. Examples: The following examples omit the --key, --secret and --endpoint command-line arguments for brevity. They must be included unless corresponding values are available from the environment. Run the DescribeRegions method, without any optional parameters: txaws-discover --action DescribeRegions Run the DescribeRegions method, with an optional RegionName.0 parameter: txaws-discover --action DescribeRegions --RegionName.0 us-west-1 """ def parse_options(arguments): """Parse command line arguments. The parsing logic is fairly simple. It can only parse long-style parameters of the form:: --key value Several parameters can be defined in the environment and will be used unless explicitly overridden with command-line arguments. The access key, secret and endpoint values will be loaded from C{AWS_ACCESS_KEY_ID}, C{AWS_SECRET_ACCESS_KEY} and C{AWS_ENDPOINT} environment variables. @param arguments: A list of command-line arguments. The first item is expected to be the name of the program being run. 
@raises OptionError: Raised if incorrectly formed command-line arguments are specified, or if required command-line arguments are not present. @raises UsageError: Raised if C{--help} is present in command-line arguments. @return: A C{dict} with key/value pairs extracted from the argument list. """ arguments = arguments[1:] options = {} while arguments: key = arguments.pop(0) if key in ("-h", "--help"): raise UsageError("Help requested.") if key.startswith("--"): key = key[2:] try: value = arguments.pop(0) except IndexError: raise OptionError("'--%s' is missing a value." % key) options[key] = value else: raise OptionError("Encountered unexpected value '%s'." % key) default_key = os.environ.get("AWS_ACCESS_KEY_ID") if "key" not in options and default_key: options["key"] = default_key default_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") if "secret" not in options and default_secret: options["secret"] = default_secret default_endpoint = os.environ.get("AWS_ENDPOINT") if "endpoint" not in options and default_endpoint: options["endpoint"] = default_endpoint for name in ("key", "secret", "endpoint", "action"): if name not in options: raise OptionError( "The '--%s' command-line argument is required." % name) return options def get_command(arguments, output=None): """Parse C{arguments} and configure a L{Command} instance. An access key, secret key, endpoint and action are required. Additional parameters included with the request are passed as parameters to the method call. For example, the following command will create a L{Command} object that can invoke the C{DescribeRegions} method with the optional C{RegionName.0} parameter included in the request:: txaws-discover --key KEY --secret SECRET --endpoint URL \ --action DescribeRegions --RegionName.0 us-west-1 @param arguments: The command-line arguments to parse. @raises OptionError: Raised if C{arguments} can't be used to create a L{Command} object. 
@return: A L{Command} instance configured to make an EC2 API method call. """ options = parse_options(arguments) key = options.pop("key") secret = options.pop("secret") endpoint = options.pop("endpoint") action = options.pop("action") return Command(key, secret, endpoint, action, options, output) def main(arguments, output=None, testing_mode=None): """ Entry point parses command-line arguments, runs the specified EC2 API method and prints the response to the screen. @param arguments: Command-line arguments, typically retrieved from C{sys.argv}. @param output: Optionally, a stream to write output to. @param testing_mode: Optionally, a condition that specifies whether or not to run in test mode. When the value is true a reactor will not be run or stopped, to prevent interfering with the test suite. """ def run_command(arguments, output, reactor): if output is None: output = sys.stdout try: command = get_command(arguments, output) except UsageError: print >>output, USAGE_MESSAGE.strip() if reactor: reactor.callLater(0, reactor.stop) except Exception, e: print >>output, "ERROR:", str(e) if reactor: reactor.callLater(0, reactor.stop) else: deferred = command.run() if reactor: deferred.addCallback(lambda ignored: reactor.stop()) if not testing_mode: from twisted.internet import reactor reactor.callLater(0, run_command, arguments, output, reactor) reactor.run() else: run_command(arguments, output, None) txAWS-0.2.3/txaws/client/discover/tests/0000775000175000017500000000000011741312025021612 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/discover/tests/__init__.py0000664000175000017500000000000011741311335023714 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/discover/tests/test_entry_point.py0000664000175000017500000002472511741311335025612 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010 Jamu Kakar # Licenced under the txaws licence available at /LICENSE in the txaws source. 
"""Unit tests for L{get_command}, L{parse_options} and L{main} functions.""" from cStringIO import StringIO import os import sys from txaws.client.discover.entry_point import ( OptionError, UsageError, get_command, main, parse_options, USAGE_MESSAGE) from txaws.testing.base import TXAWSTestCase class ParseOptionsTestCase(TXAWSTestCase): def test_parse_options(self): """ L{parse_options} returns a C{dict} contains options parsed from the command-line. """ options = parse_options([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action", "--something.else", "something.else"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action", "something.else": "something.else"}, options) def test_parse_options_without_options(self): """An L{OptionError} is raised if no options are provided.""" self.assertRaises(OptionError, parse_options, ["txaws-discover"]) def test_parse_options_with_missing_value(self): """ An L{OptionError} is raised if an option is specified without a value. """ self.assertRaises(OptionError, parse_options, ["txaws-discover", "--key"]) def test_parse_options_with_missing_option(self): """ An L{OptionError} is raised if a value is specified without an option name. """ self.assertRaises( OptionError, parse_options, ["txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action", "random-value"]) def test_parse_options_without_required_arguments(self): """ An access key, access secret, endpoint and action can be specified as command-line arguments. An L{OptionError} is raised if any one of these is missing. 
""" self.assertRaises(OptionError, parse_options, ["txaws-discover", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertRaises(OptionError, parse_options, ["txaws-discover", "--key", "key", "--endpoint", "endpoint", "--action", "action"]) self.assertRaises(OptionError, parse_options, ["txaws-discover", "--key", "key", "--secret", "secret", "--action", "action"]) self.assertRaises(OptionError, parse_options, ["txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint"]) def test_parse_options_gets_key_from_environment(self): """ If the C{AWS_ACCESS_KEY_ID} environment variable is present, it will be used if the C{--key} command-line argument isn't specified. """ os.environ["AWS_ACCESS_KEY_ID"] = "key" options = parse_options([ "txaws-discover", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_prefers_explicit_key(self): """ If an explicit C{--key} command-line argument is specified it will be preferred over the value specified in the C{AWS_ACCESS_KEY_ID} environment variable. """ os.environ["AWS_ACCESS_KEY_ID"] = "fail" options = parse_options([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_gets_secret_from_environment(self): """ If the C{AWS_SECRET_ACCESS_KEY} environment variable is present, it will be used if the C{--secret} command-line argument isn't specified. 
""" os.environ["AWS_SECRET_ACCESS_KEY"] = "secret" options = parse_options([ "txaws-discover", "--key", "key", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_prefers_explicit_secret(self): """ If an explicit C{--secret} command-line argument is specified it will be preferred over the value specified in the C{AWS_SECRET_ACCESS_KEY} environment variable. """ os.environ["AWS_SECRET_ACCESS_KEY"] = "fail" options = parse_options([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_gets_endpoint_from_environment(self): """ If the C{AWS_ENDPOINT} environment variable is present, it will be used if the C{--endpoint} command-line argument isn't specified. """ os.environ["AWS_ENDPOINT"] = "endpoint" options = parse_options([ "txaws-discover", "--key", "key", "--secret", "secret", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_prefers_explicit_endpoint(self): """ If an explicit C{--endpoint} command-line argument is specified it will be preferred over the value specified in the C{AWS_ENDPOINT} environment variable. """ os.environ["AWS_ENDPOINT"] = "fail" options = parse_options([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual({"key": "key", "secret": "secret", "endpoint": "endpoint", "action": "action"}, options) def test_parse_options_raises_usage_error_when_help_specified(self): """ L{UsageError} is raised if C{-h} or C{--help} appears in command-line arguments. 
""" self.assertRaises(UsageError, parse_options, ["txaws-discover", "-h"]) self.assertRaises(UsageError, parse_options, ["txaws-discover", "--help"]) self.assertRaises(UsageError, parse_options, ["txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action", "--help"]) class GetCommandTestCase(TXAWSTestCase): def test_get_command_without_arguments(self): """An L{OptionError} is raised if no arguments are provided.""" self.assertRaises(OptionError, get_command, ["txaws-discover"]) def test_get_command(self): """ An access key, access secret, endpoint and action can be specified as command-line arguments. """ command = get_command([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertEqual("key", command.key) self.assertEqual("secret", command.secret) self.assertEqual("endpoint", command.endpoint) self.assertEqual("action", command.action) self.assertIdentical(sys.stdout, command.output) def test_get_command_with_custom_output_stream(self): output = StringIO() command = get_command([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"], output) self.assertIdentical(output, command.output) def test_get_command_without_required_arguments(self): """ An access key, access secret, endpoint and action can be specified as command-line arguments. An L{OptionError} is raised if any one of these is missing. 
""" self.assertRaises(OptionError, get_command, ["txaws-discover", "--secret", "secret", "--endpoint", "endpoint", "--action", "action"]) self.assertRaises(OptionError, get_command, ["txaws-discover", "--key", "key", "--endpoint", "endpoint", "--action", "action"]) self.assertRaises(OptionError, get_command, ["txaws-discover", "--key", "key", "--secret", "secret", "--action", "action"]) self.assertRaises(OptionError, get_command, ["txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint"]) def test_get_command_passes_additional_parameters_to_command(self): """ Command-line parameters beyond C{--key}, C{--secret}, C{--endpoint} and C{--action} are passed to the L{Command} in a parameter C{dict}. """ command = get_command([ "txaws-discover", "--key", "key", "--secret", "secret", "--endpoint", "endpoint", "--action", "DescribeRegions", "--Region.Name.0", "us-west-1"]) self.assertEqual({"Region.Name.0": "us-west-1"}, command.parameters) class MainTestCase(TXAWSTestCase): def test_usage_message(self): """ If a L{UsageError} is raised, the help screen is written to the output stream. """ output = StringIO() main(["txaws-discover", "--help"], output, True) self.assertEqual(USAGE_MESSAGE, output.getvalue()) def test_error_message(self): """ If an exception is raised, its message is written to the output stream. """ output = StringIO() main(["txaws-discover"], output, True) self.assertEqual( "ERROR: The '--key' command-line argument is required.\n", output.getvalue()) txAWS-0.2.3/txaws/client/discover/tests/test_command.py0000664000175000017500000001776411741311335024663 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010 Jamu Kakar # Licenced under the txaws licence available at /LICENSE in the txaws source. 
"""Unit tests for L{Command}.""" from cStringIO import StringIO from twisted.internet.defer import succeed, fail from twisted.web.error import Error from txaws.client.discover.command import Command from txaws.ec2.client import Query from txaws.testing.base import TXAWSTestCase class FakeHTTPClient(object): def __init__(self, status, url): self.status = status self.url = url class CommandTestCase(TXAWSTestCase): def prepare_command(self, response, status, action, parameters={}, get_page=None, error=None): """Prepare a L{Command} for testing.""" self.url = None self.method = None self.error = error self.response = response self.status = status self.output = StringIO() self.query = None if get_page is None: get_page = self.get_page self.get_page_function = get_page self.command = Command("key", "secret", "endpoint", action, parameters, self.output, self.query_factory) def query_factory(self, other_params=None, time_tuple=None, api_version=None, *args, **kwargs): """ Create a query with a hard-coded time to generate a fake response. """ time_tuple = (2010, 6, 4, 23, 40, 0, 0, 0, 0) self.query = Query(other_params, time_tuple, api_version, *args, **kwargs) self.query.get_page = self.get_page_function return self.query def get_page(self, url, method=None, timeout=0): """Fake C{get_page} method simulates a successful request.""" self.url = url self.method = method self.query.client = FakeHTTPClient(self.status, url) return succeed(self.response) def get_error_page(self, url, method=None, timeout=0): """Fake C{get_page} method simulates an error.""" self.url = url self.method = method self.query.client = FakeHTTPClient(self.status, url) return fail(self.error or Exception(self.response)) def test_run(self): """ When a method is invoked its HTTP status code and response text is written to the output stream. 
""" self.prepare_command("The response", 200, "DescribeRegions") def check(result): url = ( "http://endpoint?AWSAccessKeyId=key&" "Action=DescribeRegions&" "Signature=3%2BHSkQQosF1Sr9AL3kdY31tEfTWQ2whjJOUSc3kvc2c%3D&" "SignatureMethod=HmacSHA256&SignatureVersion=2&" "Timestamp=2010-06-04T23%3A40%3A00Z&Version=2009-11-30") self.assertEqual("GET", self.method) self.assertEqual(url, self.url) self.assertEqual("URL: %s\n" "\n" "HTTP status code: 200\n" "\n" "The response\n" % url, self.output.getvalue()) deferred = self.command.run() deferred.addCallback(check) return deferred def test_run_with_parameters(self): """Extra method parameters are included in the request.""" self.prepare_command("The response", 200, "DescribeRegions", {"RegionName.0": "us-west-1"}) def check(result): url = ( "http://endpoint?AWSAccessKeyId=key&" "Action=DescribeRegions&RegionName.0=us-west-1&" "Signature=6D8aCgSPQOYixowRHy26aRFzK2Vwgixl9uwegYX9nLA%3D&" "SignatureMethod=HmacSHA256&SignatureVersion=2&" "Timestamp=2010-06-04T23%3A40%3A00Z&Version=2009-11-30") self.assertEqual("GET", self.method) self.assertEqual(url, self.url) self.assertEqual("URL: %s\n" "\n" "HTTP status code: 200\n" "\n" "The response\n" % url, self.output.getvalue()) deferred = self.command.run() deferred.addCallback(check) return deferred def test_run_with_error(self): """ If an error message is returned by the backend cloud, it will be written to the output stream. 
""" self.prepare_command("The error response", 400, "DescribeRegions", {"RegionName.0": "us-west-1"}, self.get_error_page) def check(result): url = ( "http://endpoint?AWSAccessKeyId=key&" "Action=DescribeRegions&RegionName.0=us-west-1&" "Signature=6D8aCgSPQOYixowRHy26aRFzK2Vwgixl9uwegYX9nLA%3D&" "SignatureMethod=HmacSHA256&SignatureVersion=2&" "Timestamp=2010-06-04T23%3A40%3A00Z&Version=2009-11-30") self.assertEqual("GET", self.method) self.assertEqual(url, self.url) self.assertEqual("URL: %s\n" "\n" "HTTP status code: 400\n" "\n" "The error response\n" % url, self.output.getvalue()) deferred = self.command.run() return self.assertFailure(deferred, Exception).addErrback(check) def test_run_with_error_strips_non_response_text(self): """ The builtin L{AWSError} exception adds 'Error message: ' to beginning of the text retuned by the backend cloud. This is stripped when the message is written to the output stream. """ self.prepare_command("Error Message: The error response", 400, "DescribeRegions", {"RegionName.0": "us-west-1"}, self.get_error_page) def check(result): url = ( "http://endpoint?AWSAccessKeyId=key&" "Action=DescribeRegions&RegionName.0=us-west-1&" "Signature=P6C7cQJ7j93uIJyv2dTbpQG3EI7ArGBJT%2FzVH%2BDFhyY%3D&" "SignatureMethod=HmacSHA256&SignatureVersion=2&" "Timestamp=2010-06-04T23%3A40%3A00Z&Version=2009-11-30") self.assertEqual("GET", self.method) self.assertEqual(url, self.url) self.assertEqual("URL: %s\n" "\n" "HTTP status code: 400\n" "\n" "The error response\n" % url, self.output.getvalue()) deferred = self.command.run() deferred.addErrback(check) return deferred def test_run_with_error_payload(self): """ If the returned HTTP error contains a payload, it's printed out. 
""" self.prepare_command("Bad Request", 400, "DescribeRegions", {"RegionName.0": "us-west-1"}, self.get_error_page, Error(400, None, "bar")) def check(result): url = ( "http://endpoint?AWSAccessKeyId=key&" "Action=DescribeRegions&RegionName.0=us-west-1&" "Signature=6D8aCgSPQOYixowRHy26aRFzK2Vwgixl9uwegYX9nLA%3D&" "SignatureMethod=HmacSHA256&SignatureVersion=2&" "Timestamp=2010-06-04T23%3A40%3A00Z&Version=2009-11-30") self.assertEqual("GET", self.method) self.assertEqual(url, self.url) self.assertEqual("URL: %s\n" "\n" "HTTP status code: 400\n" "\n" "400 Bad Request\n" "\n" "bar\n" % url, self.output.getvalue()) deferred = self.command.run() deferred.addCallback(check) return deferred txAWS-0.2.3/txaws/client/__init__.py0000664000175000017500000000000011741311335020734 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/ssl.py0000664000175000017500000001054511741311335020015 0ustar oubiwannoubiwann00000000000000from glob import glob import os import re import sys from OpenSSL import SSL from OpenSSL.crypto import load_certificate, FILETYPE_PEM from twisted.internet.ssl import CertificateOptions from txaws import exception __all__ = ["VerifyingContextFactory", "get_ca_certs"] # Multiple defaults are supported; just add more paths, separated by colons. if sys.platform == "darwin": DEFAULT_CERTS_PATH = "/System/Library/OpenSSL/certs/:" # XXX Windows users can file a bug to add theirs, since we don't know what # the right path is else: DEFAULT_CERTS_PATH = "/etc/ssl/certs/:" class VerifyingContextFactory(CertificateOptions): """ A SSL context factory to pass to C{connectSSL} to check for hostname validity. 
""" def __init__(self, host, caCerts=None): if caCerts is None: caCerts = get_global_ca_certs() CertificateOptions.__init__(self, verify=True, caCerts=caCerts) self.host = host def _dnsname_match(self, dn, host): pats = [] for frag in dn.split(r"."): if frag == "*": pats.append("[^.]+") else: frag = re.escape(frag) pats.append(frag.replace(r"\*", "[^.]*")) rx = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE) return bool(rx.match(host)) def verify_callback(self, connection, x509, errno, depth, preverifyOK): # Only check depth == 0 on chained certificates. if depth == 0: dns_found = False if getattr(x509, "get_extension", None) is not None: for index in range(x509.get_extension_count()): extension = x509.get_extension(index) if extension.get_short_name() != "subjectAltName": continue data = str(extension) for element in data.split(", "): key, value = element.split(":") if key != "DNS": continue if self._dnsname_match(value, self.host): return preverifyOK dns_found = True break if not dns_found: commonName = x509.get_subject().commonName if commonName is None: return False if not self._dnsname_match(commonName, self.host): return False else: return False return preverifyOK def _makeContext(self): context = CertificateOptions._makeContext(self) context.set_verify( SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, self.verify_callback) return context def get_ca_certs(): """ Retrieve a list of CAs at either the DEFAULT_CERTS_PATH or the env override, TXAWS_CERTS_PATH. In order to find .pem files, this function checks first for presence of the TXAWS_CERTS_PATH environment variable that should point to a directory containing cert files. In the absense of this variable, the module-level DEFAULT_CERTS_PATH will be used instead. Note that both of these variables have have multiple paths in them, just like the familiar PATH environment variable (separated by colons). 
""" cert_paths = os.getenv("TXAWS_CERTS_PATH", DEFAULT_CERTS_PATH).split(":") certificate_authority_map = {} for path in cert_paths: for cert_file_name in glob(os.path.join(path, "*.pem")): # There might be some dead symlinks in there, so let's make sure # it's real. if not os.path.exists(cert_file_name): continue cert_file = open(cert_file_name) data = cert_file.read() cert_file.close() x509 = load_certificate(FILETYPE_PEM, data) digest = x509.digest("sha1") # Now, de-duplicate in case the same cert has multiple names. certificate_authority_map[digest] = x509 values = certificate_authority_map.values() if len(values) == 0: raise exception.CertsNotFoundError("Could not find any .pem files.") return values _ca_certs = None def get_global_ca_certs(): """Retrieve a singleton of CA certificates.""" global _ca_certs if _ca_certs is None: _ca_certs = get_ca_certs() return _ca_certs txAWS-0.2.3/txaws/client/gui/0000775000175000017500000000000011741312025017416 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/gui/__init__.py0000664000175000017500000000000011741311335021520 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/gui/tests/0000775000175000017500000000000011741312025020560 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/gui/tests/test_gtk.py0000664000175000017500000000043211741311335022760 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. from twisted.trial.unittest import TestCase class UITestCase(TestCase): pass # Really need some, but UI testing hurts my brain. txAWS-0.2.3/txaws/client/gui/tests/__init__.py0000664000175000017500000000000011741311335022662 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/gui/gtk.py0000664000175000017500000001646511741311335020574 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. 
"""A GTK client for working with aws.""" from __future__ import absolute_import import gnomekeyring import gobject import gtk # DO NOT IMPORT twisted.internet, or things that import # twisted.internet. # Doing so loads the default Reactor, which is bad. thanks. from txaws.credentials import AWSCredentials __all__ = ["main"] class AWSStatusIcon(gtk.StatusIcon): """A status icon shown when instances are running.""" def __init__(self, reactor): gtk.StatusIcon.__init__(self) self.set_from_stock(gtk.STOCK_NETWORK) self.set_visible(True) self.reactor = reactor self.connect("activate", self.on_activate) self.probing = False # Nested import because otherwise we get "reactor already installed". self.password_dialog = None self.region = None try: creds = AWSCredentials() except ValueError: creds = self.from_gnomekeyring() if self.region is None: self.set_region(creds) self.create_client(creds) menu = """ """ actions = [ ("Menu", None, "Menu"), ("Stop instances", gtk.STOCK_STOP, "_Stop instances...", None, "Stop instances", self.on_stop_instances), ] ag = gtk.ActionGroup("Actions") ag.add_actions(actions) self.manager = gtk.UIManager() self.manager.insert_action_group(ag, 0) self.manager.add_ui_from_string(menu) self.menu = self.manager.get_widget( "/Menubar/Menu/Stop instances").props.parent self.connect("popup-menu", self.on_popup_menu) def set_region(self, creds): from txaws.service import AWSServiceRegion self.region = AWSServiceRegion(creds) def create_client(self, creds): if creds is not None: if self.region is None: self.set_region(creds) self.client = self.region.get_ec2_client() self.on_activate(None) else: # waiting on user entered credentials. self.client = None def from_gnomekeyring(self): # Try for gtk gui specific credentials. 
try: items = gnomekeyring.find_items_sync( gnomekeyring.ITEM_GENERIC_SECRET, { "aws-host": "aws.amazon.com", }) except (gnomekeyring.NoMatchError, gnomekeyring.DeniedError): self.show_a_password_dialog() return None else: key_id, secret_key = items[0].secret.split(":") return AWSCredentials(access_key=key_id, secret_key=secret_key) def show_a_password_dialog(self): self.password_dialog = gtk.Dialog( "Enter your AWS credentals", None, gtk.DIALOG_MODAL, (gtk.STOCK_OK, gtk.RESPONSE_ACCEPT, gtk.STOCK_CANCEL, gtk.RESPONSE_REJECT)) content = self.password_dialog.get_content_area() def add_entry(name): box = gtk.HBox() box.show() content.add(box) label = gtk.Label(name) label.show() box.add(label) entry = gtk.Entry() entry.show() box.add(entry) label.set_use_underline(True) label.set_mnemonic_widget(entry) add_entry("AWS _Access Key ID") add_entry("AWS _Secret Key") self.password_dialog.show() self.password_dialog.connect("response", self.save_key) self.password_dialog.run() def on_activate(self, data): if self.probing or not self.client: # don't ask multiple times, and don't ask until we have # credentials. return self.probing = True deferred = self.client.describe_instances() deferred.addCallbacks(self.showhide, self.describe_error) def on_popup_menu(self, status, button, time): self.menu.popup(None, None, None, button, time) def on_stop_instances(self, data): # It would be nice to popup a window to select instances.. TODO. deferred = self.client.describe_instances() deferred.addCallbacks(self.shutdown_instances, self.show_error) def save_key(self, response_id, data): try: if data != gtk.RESPONSE_ACCEPT: # User cancelled. They can ask for the password again somehow. 
return content = self.password_dialog.get_content_area() key_id = content.get_children()[0].get_children()[1].get_text() secret_key = content.get_children()[1].get_children()[1].get_text() creds = AWSCredentials(access_key=key_id, secret_key=secret_key) self.create_client(creds) gnomekeyring.item_create_sync( None, gnomekeyring.ITEM_GENERIC_SECRET, "AWS access credentials", {"aws-host": "aws.amazon.com"}, "%s:%s" % (key_id, secret_key), True) finally: self.password_dialog.hide() # XXX? Does this leak? self.password_dialog = None def showhide(self, reservation): active = 0 for instance in reservation: if instance.instance_state == "running": active += 1 self.set_tooltip("AWS Status - %d instances" % active) self.set_visible(active != 0) self.queue_check() def shutdown_instances(self, reservation): d = self.client.terminate_instances( *[instance.instance_id for instance in reservation]) d.addCallbacks(self.on_activate, self.show_error) def queue_check(self): self.probing = False self.reactor.callLater(60, self.on_activate, None) def show_error(self, error): # debugging output for now. print error.value try: print error.value.response except: pass def describe_error(self, error): from twisted.internet.defer import TimeoutError if isinstance(error.value, TimeoutError): # timeout errors can be ignored - transient network issue or some # such. pass else: # debugging output for now. self.show_error(error) self.queue_check() def main(argv, reactor=None): """Run the client GUI. Typical use: >>> sys.exit(main(sys.argv)) @param argv: The arguments to run it with, e.g. sys.argv. @param reactor: The reactor to use. Must be compatible with gtk as this module uses gtk API"s. @return exitcode: The exit code it returned, as per sys.exit. 
""" if reactor is None: from twisted.internet import gtk2reactor gtk2reactor.install() from twisted.internet import reactor try: AWSStatusIcon(reactor) gobject.set_application_name("aws-status") reactor.run() except ValueError: # In this case, the user cancelled, and the exception bubbled to here. pass txAWS-0.2.3/txaws/client/tests/0000775000175000017500000000000011741312025017774 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/tests/__init__.py0000664000175000017500000000000011741311335022076 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/client/tests/test_base.py0000664000175000017500000001620011741311335022321 0ustar oubiwannoubiwann00000000000000import os from twisted.internet import reactor from twisted.internet.error import ConnectionRefusedError from twisted.protocols.policies import WrappingFactory from twisted.python import log from twisted.python.filepath import FilePath from twisted.python.failure import Failure from twisted.test.test_sslverify import makeCertificate from twisted.web import server, static from twisted.web.client import HTTPClientFactory from twisted.web.error import Error as TwistedWebError from txaws.client import ssl from txaws.client.base import BaseClient, BaseQuery, error_wrapper from txaws.service import AWSServiceEndpoint from txaws.testing.base import TXAWSTestCase class ErrorWrapperTestCase(TXAWSTestCase): def test_204_no_content(self): failure = Failure(TwistedWebError(204, "No content")) wrapped = error_wrapper(failure, None) self.assertEquals(wrapped, "204 No content") def test_302_found(self): # XXX I'm not sure we want to raise for 300s... 
failure = Failure(TwistedWebError(302, "found")) error = self.assertRaises( Exception, error_wrapper, failure, None) self.assertEquals(failure.type, type(error)) self.assertTrue(isinstance(error, TwistedWebError)) self.assertEquals(str(error), "302 found") def test_500(self): failure = Failure(TwistedWebError(500, "internal error")) error = self.assertRaises( Exception, error_wrapper, failure, None) self.assertTrue(isinstance(error, TwistedWebError)) self.assertEquals(str(error), "500 internal error") def test_timeout_error(self): failure = Failure(Exception("timeout")) error = self.assertRaises(Exception, error_wrapper, failure, None) self.assertTrue(isinstance(error, Exception)) self.assertEquals(str(error), "timeout") def test_connection_error(self): failure = Failure(ConnectionRefusedError("timeout")) error = self.assertRaises( Exception, error_wrapper, failure, ConnectionRefusedError) self.assertTrue(isinstance(error, ConnectionRefusedError)) class BaseClientTestCase(TXAWSTestCase): def test_creation(self): client = BaseClient("creds", "endpoint", "query factory", "parser") self.assertEquals(client.creds, "creds") self.assertEquals(client.endpoint, "endpoint") self.assertEquals(client.query_factory, "query factory") self.assertEquals(client.parser, "parser") class BaseQueryTestCase(TXAWSTestCase): def setUp(self): self.cleanupServerConnections = 0 name = self.mktemp() os.mkdir(name) FilePath(name).child("file").setContent("0123456789") r = static.File(name) self.site = server.Site(r, timeout=None) self.wrapper = WrappingFactory(self.site) self.port = self._listen(self.wrapper) self.portno = self.port.getHost().port def tearDown(self): # If the test indicated it might leave some server-side connections # around, clean them up. connections = self.wrapper.protocols.keys() # If there are fewer server-side connections than requested, # that's okay. Some might have noticed that the client closed # the connection and cleaned up after themselves. 
for n in range(min(len(connections), self.cleanupServerConnections)): proto = connections.pop() log.msg("Closing %r" % (proto,)) proto.transport.loseConnection() if connections: log.msg("Some left-over connections; this test is probably buggy.") return self.port.stopListening() def _listen(self, site): return reactor.listenTCP(0, site, interface="127.0.0.1") def _get_url(self, path): return "http://127.0.0.1:%d/%s" % (self.portno, path) def test_creation(self): query = BaseQuery("an action", "creds", "http://endpoint") self.assertEquals(query.factory, HTTPClientFactory) self.assertEquals(query.action, "an action") self.assertEquals(query.creds, "creds") self.assertEquals(query.endpoint, "http://endpoint") def test_init_requires_action(self): self.assertRaises(TypeError, BaseQuery) def test_init_requires_creds(self): self.assertRaises(TypeError, BaseQuery, None) def test_get_page(self): query = BaseQuery("an action", "creds", "http://endpoint") d = query.get_page(self._get_url("file")) d.addCallback(self.assertEquals, "0123456789") return d def test_get_request_headers_no_client(self): query = BaseQuery("an action", "creds", "http://endpoint") results = query.get_request_headers() self.assertEquals(results, None) def test_get_request_headers_with_client(self): def check_results(results): self.assertEquals(results.keys(), []) self.assertEquals(results.values(), []) query = BaseQuery("an action", "creds", "http://endpoint") d = query.get_page(self._get_url("file")) d.addCallback(query.get_request_headers) return d.addCallback(check_results) def test_get_response_headers_no_client(self): query = BaseQuery("an action", "creds", "http://endpoint") results = query.get_response_headers() self.assertEquals(results, None) def test_get_response_headers_with_client(self): def check_results(results): self.assertEquals(sorted(results.keys()), [ "accept-ranges", "content-length", "content-type", "date", "last-modified", "server"]) self.assertEquals(len(results.values()), 6) query 
= BaseQuery("an action", "creds", "http://endpoint") d = query.get_page(self._get_url("file")) d.addCallback(query.get_response_headers) return d.addCallback(check_results) # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test # will fail; instead, let's create some certs in a temp directory and set # the DEFAULT_CERT_PATH to point there. def test_ssl_hostname_verification(self): """ If the endpoint passed to L{BaseQuery} has C{ssl_hostname_verification} sets to C{True}, a L{VerifyingContextFactory} is passed to C{connectSSL}. """ class FakeReactor(object): def __init__(self): self.connects = [] def connectSSL(self, host, port, client, factory): self.connects.append((host, port, client, factory)) certs = makeCertificate(O="Test Certificate", CN="something")[1] self.patch(ssl, "_ca_certs", certs) fake_reactor = FakeReactor() endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) query = BaseQuery("an action", "creds", endpoint, fake_reactor) query.get_page("https://example.com/file") [(host, port, client, factory)] = fake_reactor.connects self.assertEqual("example.com", host) self.assertEqual(443, port) self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory)) self.assertEqual("example.com", factory.host) self.assertNotEqual([], factory.caCerts) txAWS-0.2.3/txaws/client/tests/test_ssl.py0000664000175000017500000001724311741311335022220 0ustar oubiwannoubiwann00000000000000import os import tempfile from OpenSSL.crypto import dump_certificate, load_certificate, FILETYPE_PEM from OpenSSL.SSL import Error as SSLError from OpenSSL.version import __version__ as pyopenssl_version from twisted.internet import reactor from twisted.internet.ssl import DefaultOpenSSLContextFactory from twisted.protocols.policies import WrappingFactory from twisted.python import log from twisted.python.filepath import FilePath from twisted.test.test_sslverify import makeCertificate from twisted.web import server, static from txaws import exception from 
txaws.client import ssl from txaws.client.base import BaseQuery from txaws.service import AWSServiceEndpoint from txaws.testing.base import TXAWSTestCase def sibpath(path): return os.path.join(os.path.dirname(__file__), path) PRIVKEY = sibpath("private.ssl") PUBKEY = sibpath("public.ssl") BADPRIVKEY = sibpath("badprivate.ssl") BADPUBKEY = sibpath("badpublic.ssl") PRIVSANKEY = sibpath("private_san.ssl") PUBSANKEY = sibpath("public_san.ssl") class BaseQuerySSLTestCase(TXAWSTestCase): def setUp(self): self.cleanupServerConnections = 0 name = self.mktemp() os.mkdir(name) FilePath(name).child("file").setContent("0123456789") r = static.File(name) self.site = server.Site(r, timeout=None) self.wrapper = WrappingFactory(self.site) pub_key = file(PUBKEY) pub_key_data = pub_key.read() pub_key.close() pub_key_san = file(PUBSANKEY) pub_key_san_data = pub_key_san.read() pub_key_san.close() ssl._ca_certs = [load_certificate(FILETYPE_PEM, pub_key_data), load_certificate(FILETYPE_PEM, pub_key_san_data)] def tearDown(self): ssl._ca_certs = None # If the test indicated it might leave some server-side connections # around, clean them up. connections = self.wrapper.protocols.keys() # If there are fewer server-side connections than requested, # that's okay. Some might have noticed that the client closed # the connection and cleaned up after themselves. for n in range(min(len(connections), self.cleanupServerConnections)): proto = connections.pop() log.msg("Closing %r" % (proto,)) proto.transport.loseConnection() if connections: log.msg("Some left-over connections; this test is probably buggy.") return self.port.stopListening() def _get_url(self, path): return "https://localhost:%d/%s" % (self.portno, path) def test_ssl_verification_positive(self): """ The L{VerifyingContextFactory} properly allows to connect to the endpoint if the certificates match. 
""" context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY) self.port = reactor.listenSSL( 0, self.site, context_factory, interface="127.0.0.1") self.portno = self.port.getHost().port endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) query = BaseQuery("an action", "creds", endpoint) d = query.get_page(self._get_url("file")) return d.addCallback(self.assertEquals, "0123456789") def test_ssl_verification_negative(self): """ The L{VerifyingContextFactory} fails with a SSL error the certificates can't be checked. """ context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) self.port = reactor.listenSSL( 0, self.site, context_factory, interface="127.0.0.1") self.portno = self.port.getHost().port endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) query = BaseQuery("an action", "creds", endpoint) d = query.get_page(self._get_url("file")) return self.assertFailure(d, SSLError) def test_ssl_verification_bypassed(self): """ L{BaseQuery} doesn't use L{VerifyingContextFactory} if C{ssl_hostname_verification} is C{False}, thus allowing to connect to non-secure endpoints. """ context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY) self.port = reactor.listenSSL( 0, self.site, context_factory, interface="127.0.0.1") self.portno = self.port.getHost().port endpoint = AWSServiceEndpoint(ssl_hostname_verification=False) query = BaseQuery("an action", "creds", endpoint) d = query.get_page(self._get_url("file")) return d.addCallback(self.assertEquals, "0123456789") def test_ssl_subject_alt_name(self): """ L{VerifyingContextFactory} supports checking C{subjectAltName} in the certificate if it's available. 
""" context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY) self.port = reactor.listenSSL( 0, self.site, context_factory, interface="127.0.0.1") self.portno = self.port.getHost().port endpoint = AWSServiceEndpoint(ssl_hostname_verification=True) query = BaseQuery("an action", "creds", endpoint) d = query.get_page("https://127.0.0.1:%d/file" % (self.portno,)) return d.addCallback(self.assertEquals, "0123456789") if pyopenssl_version < "0.12": test_ssl_subject_alt_name.skip = ( "subjectAltName not supported by older PyOpenSSL") class CertsFilesTestCase(TXAWSTestCase): def setUp(self): super(CertsFilesTestCase, self).setUp() # set up temp dir with no certs self.no_certs_dir = tempfile.mkdtemp() # create certs cert1 = makeCertificate(O="Server Certificate 1", CN="cn1") cert2 = makeCertificate(O="Server Certificate 2", CN="cn2") cert3 = makeCertificate(O="Server Certificate 3", CN="cn3") # set up temp dir with one cert self.one_cert_dir = tempfile.mkdtemp() self.cert1 = self._write_pem(cert1, self.one_cert_dir, "cert1.pem") # set up temp dir with two certs self.two_certs_dir = tempfile.mkdtemp() self.cert2 = self._write_pem(cert2, self.two_certs_dir, "cert2.pem") self.cert3 = self._write_pem(cert3, self.two_certs_dir, "cert3.pem") def tearDown(self): super(CertsFilesTestCase, self).tearDown() os.unlink(self.cert1) os.unlink(self.cert2) os.unlink(self.cert3) os.removedirs(self.no_certs_dir) os.removedirs(self.one_cert_dir) os.removedirs(self.two_certs_dir) def _write_pem(self, cert, dir, filename): data = dump_certificate(FILETYPE_PEM, cert[1]) full_path = os.path.join(dir, filename) fh = open(full_path, "w") fh.write(data) fh.close() return full_path def test_get_ca_certs_no_certs(self): os.environ["TXAWS_CERTS_PATH"] = self.no_certs_dir self.patch(ssl, "DEFAULT_CERTS_PATH", self.no_certs_dir) self.assertRaises(exception.CertsNotFoundError, ssl.get_ca_certs) def test_get_ca_certs_with_default_path(self): self.patch(ssl, "DEFAULT_CERTS_PATH", 
self.two_certs_dir) certs = ssl.get_ca_certs() self.assertEqual(len(certs), 2) def test_get_ca_certs_with_env_path(self): os.environ["TXAWS_CERTS_PATH"] = self.one_cert_dir certs = ssl.get_ca_certs() self.assertEqual(len(certs), 1) def test_get_ca_certs_multiple_paths(self): os.environ["TXAWS_CERTS_PATH"] = "%s:%s" % ( self.one_cert_dir, self.two_certs_dir) certs = ssl.get_ca_certs() self.assertEqual(len(certs), 3) def test_get_ca_certs_one_empty_path(self): os.environ["TXAWS_CERTS_PATH"] = "%s:%s" % ( self.no_certs_dir, self.one_cert_dir) certs = ssl.get_ca_certs() self.assertEqual(len(certs), 1) txAWS-0.2.3/txaws/client/base.py0000664000175000017500000001115511741311335020124 0ustar oubiwannoubiwann00000000000000try: from xml.etree.ElementTree import ParseError except ImportError: from xml.parsers.expat import ExpatError as ParseError from twisted.internet.ssl import ClientContextFactory from twisted.web import http from twisted.web.client import HTTPClientFactory from twisted.web.error import Error as TwistedWebError from txaws.util import parse from txaws.credentials import AWSCredentials from txaws.exception import AWSResponseParseError from txaws.service import AWSServiceEndpoint from txaws.client.ssl import VerifyingContextFactory def error_wrapper(error, errorClass): """ We want to see all error messages from cloud services. Amazon's EC2 says that their errors are accompanied either by a 400-series or 500-series HTTP response code. As such, the first thing we want to do is check to see if the error is in that range. If it is, we then need to see if the error message is an EC2 one. In the event that an error is not a Twisted web error nor an EC2 one, the original exception is raised. 
""" http_status = 0 if error.check(TwistedWebError): xml_payload = error.value.response if error.value.status: http_status = int(error.value.status) else: error.raiseException() if http_status >= 400: if not xml_payload: error.raiseException() try: fallback_error = errorClass( xml_payload, error.value.status, str(error.value), error.value.response) except (ParseError, AWSResponseParseError): error_message = http.RESPONSES.get(http_status) fallback_error = TwistedWebError( http_status, error_message, error.value.response) raise fallback_error elif 200 <= http_status < 300: return str(error.value) else: error.raiseException() class BaseClient(object): """Create an AWS client. @param creds: User authentication credentials to use. @param endpoint: The service endpoint URI. @param query_factory: The class or function that produces a query object for making requests to the EC2 service. @param parser: A parser object for parsing responses from the EC2 service. """ def __init__(self, creds=None, endpoint=None, query_factory=None, parser=None): if creds is None: creds = AWSCredentials() if endpoint is None: endpoint = AWSServiceEndpoint() self.creds = creds self.endpoint = endpoint self.query_factory = query_factory self.parser = parser class BaseQuery(object): def __init__(self, action=None, creds=None, endpoint=None, reactor=None): if not action: raise TypeError("The query requires an action parameter.") self.factory = HTTPClientFactory self.action = action self.creds = creds self.endpoint = endpoint if reactor is None: from twisted.internet import reactor self.reactor = reactor self.client = None def get_page(self, url, *args, **kwds): """ Define our own get_page method so that we can easily override the factory when we need to. 
This was copied from the following: * twisted.web.client.getPage * twisted.web.client._makeGetterFactory """ contextFactory = None scheme, host, port, path = parse(url) self.client = self.factory(url, *args, **kwds) if scheme == "https": if self.endpoint.ssl_hostname_verification: contextFactory = VerifyingContextFactory(host) else: contextFactory = ClientContextFactory() self.reactor.connectSSL(host, port, self.client, contextFactory) else: self.reactor.connectTCP(host, port, self.client) return self.client.deferred def get_request_headers(self, *args, **kwds): """ A convenience method for obtaining the headers that were sent to the S3 server. The AWS S3 API depends upon setting headers. This method is provided as a convenience for debugging issues with the S3 communications. """ if self.client: return self.client.headers def get_response_headers(self, *args, **kwargs): """ A convenience method for obtaining the headers that were sent from the S3 server. The AWS S3 API depends upon setting headers. This method is used by the head_object API call for getting a S3 object's metadata. """ if self.client: return self.client.response_headers txAWS-0.2.3/txaws/tests/0000775000175000017500000000000011741312025016516 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/tests/test_credentials.py0000664000175000017500000000266411741311335022437 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Licenced under the txaws licence available at /LICENSE in the txaws source. 
import os from txaws.credentials import AWSCredentials, ENV_ACCESS_KEY, ENV_SECRET_KEY from txaws.testing.base import TXAWSTestCase class CredentialsTestCase(TXAWSTestCase): def test_no_access_errors(self): # Without anything in os.environ, AWSService() blows up os.environ[ENV_SECRET_KEY] = "bar" self.assertRaises(ValueError, AWSCredentials) def test_no_secret_errors(self): # Without anything in os.environ, AWSService() blows up os.environ[ENV_ACCESS_KEY] = "foo" self.assertRaises(ValueError, AWSCredentials) def test_found_values_used(self): os.environ[ENV_ACCESS_KEY] = "foo" os.environ[ENV_SECRET_KEY] = "bar" service = AWSCredentials() self.assertEqual("foo", service.access_key) self.assertEqual("bar", service.secret_key) def test_explicit_access_key(self): os.environ[ENV_SECRET_KEY] = "foo" service = AWSCredentials(access_key="bar") self.assertEqual("foo", service.secret_key) self.assertEqual("bar", service.access_key) def test_explicit_secret_key(self): os.environ[ENV_ACCESS_KEY] = "bar" service = AWSCredentials(secret_key="foo") self.assertEqual("foo", service.secret_key) self.assertEqual("bar", service.access_key) txAWS-0.2.3/txaws/tests/test_service.py0000664000175000017500000002134611741311335021600 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Duncan McGreggor # Licenced under the txaws licence available at /LICENSE in the txaws source. 
from txaws.credentials import AWSCredentials from txaws.ec2.client import EC2Client try: from txaws.s3.client import S3Client except ImportError: s3clientSkip = ("S3Client couldn't be imported (perhaps because dateutil, " "on which it depends, isn't present)") else: s3clientSkip = None from txaws.service import (AWSServiceEndpoint, AWSServiceRegion, EC2_ENDPOINT_EU, EC2_ENDPOINT_US, REGION_EU) from txaws.testing.base import TXAWSTestCase class AWSServiceEndpointTestCase(TXAWSTestCase): def setUp(self): self.endpoint = AWSServiceEndpoint(uri="http://my.service/da_endpoint") def test_simple_creation(self): endpoint = AWSServiceEndpoint() self.assertEquals(endpoint.scheme, "http") self.assertEquals(endpoint.host, "") self.assertEquals(endpoint.port, None) self.assertEquals(endpoint.path, "/") self.assertEquals(endpoint.method, "GET") def test_custom_method(self): endpoint = AWSServiceEndpoint( uri="http://service/endpoint", method="PUT") self.assertEquals(endpoint.method, "PUT") def test_parse_uri(self): self.assertEquals(self.endpoint.scheme, "http") self.assertEquals(self.endpoint.host, "my.service") self.assertIdentical(self.endpoint.port, None) self.assertEquals(self.endpoint.path, "/da_endpoint") def test_parse_uri_https_and_custom_port(self): endpoint = AWSServiceEndpoint(uri="https://my.service:8080/endpoint") self.assertEquals(endpoint.scheme, "https") self.assertEquals(endpoint.host, "my.service") self.assertEquals(endpoint.port, 8080) self.assertEquals(endpoint.path, "/endpoint") def test_get_uri(self): uri = self.endpoint.get_uri() self.assertEquals(uri, "http://my.service/da_endpoint") def test_get_uri_custom_port(self): uri = "https://my.service:8080/endpoint" endpoint = AWSServiceEndpoint(uri=uri) new_uri = endpoint.get_uri() self.assertEquals(new_uri, uri) def test_set_host(self): self.assertEquals(self.endpoint.host, "my.service") self.endpoint.set_host("newhost.com") self.assertEquals(self.endpoint.host, "newhost.com") def test_get_host(self): 
self.assertEquals(self.endpoint.host, self.endpoint.get_host()) def test_get_canonical_host(self): """ If the port is not specified the canonical host is the same as the host. """ uri = "http://my.service/endpoint" endpoint = AWSServiceEndpoint(uri=uri) self.assertEquals("my.service", endpoint.get_canonical_host()) def test_get_canonical_host_with_non_default_port(self): """ If the port is not the default, the canonical host includes it. """ uri = "http://my.service:99/endpoint" endpoint = AWSServiceEndpoint(uri=uri) self.assertEquals("my.service:99", endpoint.get_canonical_host()) def test_get_canonical_host_is_lower_case(self): """ The canonical host is guaranteed to be lower case. """ uri = "http://MY.SerVice:99/endpoint" endpoint = AWSServiceEndpoint(uri=uri) self.assertEquals("my.service:99", endpoint.get_canonical_host()) def test_set_canonical_host(self): """ The canonical host is converted to lower case. """ endpoint = AWSServiceEndpoint() endpoint.set_canonical_host("My.Service") self.assertEquals("my.service", endpoint.host) self.assertIdentical(None, endpoint.port) def test_set_canonical_host_with_port(self): """ The canonical host can optionally have a port. """ endpoint = AWSServiceEndpoint() endpoint.set_canonical_host("my.service:99") self.assertEquals("my.service", endpoint.host) self.assertEquals(99, endpoint.port) def test_set_canonical_host_with_empty_port(self): """ The canonical host can also have no port. 
""" endpoint = AWSServiceEndpoint() endpoint.set_canonical_host("my.service:") self.assertEquals("my.service", endpoint.host) self.assertIdentical(None, endpoint.port) def test_set_path(self): self.endpoint.set_path("/newpath") self.assertEquals( self.endpoint.get_uri(), "http://my.service/newpath") def test_set_method(self): self.assertEquals(self.endpoint.method, "GET") self.endpoint.set_method("PUT") self.assertEquals(self.endpoint.method, "PUT") class AWSServiceRegionTestCase(TXAWSTestCase): def setUp(self): self.creds = AWSCredentials("foo", "bar") self.region = AWSServiceRegion(creds=self.creds) def test_simple_creation(self): self.assertEquals(self.creds, self.region.creds) self.assertEquals(self.region._clients, {}) self.assertEquals(self.region.ec2_endpoint.get_uri(), EC2_ENDPOINT_US) def test_creation_with_keys(self): region = AWSServiceRegion(access_key="baz", secret_key="quux") self.assertEquals(region.creds.access_key, "baz") self.assertEquals(region.creds.secret_key, "quux") def test_creation_with_keys_and_creds(self): """ creds take precedence over individual access key/secret key pairs. 
""" region = AWSServiceRegion(self.creds, access_key="baz", secret_key="quux") self.assertEquals(region.creds.access_key, "foo") self.assertEquals(region.creds.secret_key, "bar") def test_creation_with_uri(self): region = AWSServiceRegion( creds=self.creds, ec2_uri="http://foo/bar") self.assertEquals(region.ec2_endpoint.get_uri(), "http://foo/bar") def test_creation_with_uri_backwards_compatible(self): region = AWSServiceRegion( creds=self.creds, uri="http://foo/bar") self.assertEquals(region.ec2_endpoint.get_uri(), "http://foo/bar") def test_creation_with_uri_and_region(self): region = AWSServiceRegion( creds=self.creds, region=REGION_EU, ec2_uri="http://foo/bar") self.assertEquals(region.ec2_endpoint.get_uri(), "http://foo/bar") def test_creation_with_region_override(self): region = AWSServiceRegion(creds=self.creds, region=REGION_EU) self.assertEquals(region.ec2_endpoint.get_uri(), EC2_ENDPOINT_EU) def test_get_ec2_client_with_empty_cache(self): key = str(EC2Client) + str(self.creds) + str(self.region.ec2_endpoint) original_client = self.region._clients.get(key) new_client = self.region.get_client( EC2Client, creds=self.creds, endpoint=self.region.ec2_endpoint) self.assertEquals(original_client, None) self.assertTrue(isinstance(new_client, EC2Client)) self.assertNotEquals(original_client, new_client) def test_get_ec2_client_from_cache_default(self): client1 = self.region.get_ec2_client() client2 = self.region.get_ec2_client() self.assertTrue(isinstance(client1, EC2Client)) self.assertTrue(isinstance(client2, EC2Client)) self.assertEquals(client1, client2) def test_get_ec2_client_from_cache(self): client1 = self.region.get_client( EC2Client, creds=self.creds, endpoint=self.region.ec2_endpoint) client2 = self.region.get_client( EC2Client, creds=self.creds, endpoint=self.region.ec2_endpoint) self.assertTrue(isinstance(client1, EC2Client)) self.assertTrue(isinstance(client2, EC2Client)) self.assertEquals(client1, client2) def 
test_get_ec2_client_from_cache_with_purge(self): client1 = self.region.get_client( EC2Client, creds=self.creds, endpoint=self.region.ec2_endpoint, purge_cache=True) client2 = self.region.get_client( EC2Client, creds=self.creds, endpoint=self.region.ec2_endpoint, purge_cache=True) self.assertTrue(isinstance(client1, EC2Client)) self.assertTrue(isinstance(client2, EC2Client)) self.assertNotEquals(client1, client2) def test_get_s3_client_with_empty_cache(self): key = str(S3Client) + str(self.creds) + str(self.region.s3_endpoint) original_client = self.region._clients.get(key) new_client = self.region.get_client( S3Client, creds=self.creds, endpoint=self.region.s3_endpoint) self.assertEquals(original_client, None) self.assertTrue(isinstance(new_client, S3Client)) self.assertNotEquals(original_client, new_client) test_get_s3_client_with_empty_cache.skip = s3clientSkip txAWS-0.2.3/txaws/tests/__init__.py0000664000175000017500000000000011741311335020620 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/tests/test_exception.py0000664000175000017500000001013411741311335022127 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. 
from twisted.trial.unittest import TestCase from txaws.exception import AWSError from txaws.exception import AWSResponseParseError from txaws.util import XML REQUEST_ID = "0ef9fc37-6230-4d81-b2e6-1b36277d4247" class AWSErrorTestCase(TestCase): def test_creation(self): error = AWSError("", 500, "Server Error", "") self.assertEquals(error.status, 500) self.assertEquals(error.response, "") self.assertEquals(error.original, "") self.assertEquals(error.errors, []) self.assertEquals(error.request_id, "") def test_node_to_dict(self): xml = "text1text2" error = AWSError("", 400) data = error._node_to_dict(XML(xml)) self.assertEquals(data, {"child1": "text1", "child2": "text2"}) def test_set_request_id(self): xml = "%s" % REQUEST_ID error = AWSError("", 400) error._set_request_id(XML(xml)) self.assertEquals(error.request_id, REQUEST_ID) def test_set_host_id(self): host_id = "ASD@#FDG$E%FG" xml = "%s" % host_id error = AWSError("", 400) error._set_host_id(XML(xml)) self.assertEquals(error.host_id, host_id) def test_set_empty_errors(self): xml = "" error = AWSError("", 500) error._set_500_error(XML(xml)) self.assertEquals(error.errors, []) def test_set_empty_error(self): xml = "" error = AWSError("", 500) error._set_500_error(XML(xml)) self.assertEquals(error.errors, []) def test_parse_without_xml(self): xml = "" error = AWSError(xml, 400) error.parse() self.assertEquals(error.original, xml) def test_parse_with_xml(self): xml1 = "" xml2 = "" error = AWSError(xml1, 400) error.parse(xml2) self.assertEquals(error.original, xml2) def test_parse_html(self): xml = "a page" self.assertRaises(AWSResponseParseError, AWSError, xml, 400) def test_empty_xml(self): self.assertRaises(ValueError, AWSError, "", 400) def test_no_request_id(self): errors = "" xml = "%s" % errors error = AWSError(xml, 400) self.assertEquals(error.request_id, "") def test_no_request_id_node(self): errors = "" xml = "%s" % errors error = AWSError(xml, 400) self.assertEquals(error.request_id, "") def 
test_no_errors_node(self): xml = "" error = AWSError(xml, 400) self.assertEquals(error.errors, []) def test_no_error_node(self): xml = "" error = AWSError(xml, 400) self.assertEquals(error.errors, []) def test_no_error_code_node(self): errors = "" xml = "%s" % errors error = AWSError(xml, 400) self.assertEquals(error.errors, []) def test_no_error_message_node(self): errors = "" xml = "%s" % errors error = AWSError(xml, 400) self.assertEquals(error.errors, []) def test_set_500_error(self): xml = "500Oops" error = AWSError("", 500) error._set_500_error(XML(xml)) self.assertEquals(error.errors[0]["Code"], "500") self.assertEquals(error.errors[0]["Message"], "Oops") txAWS-0.2.3/txaws/tests/test_util.py0000664000175000017500000000463311741311335021115 0ustar oubiwannoubiwann00000000000000from urlparse import urlparse from twisted.trial.unittest import TestCase from txaws.util import hmac_sha1, iso8601time, parse class MiscellaneousTestCase(TestCase): def test_hmac_sha1(self): cases = [ ("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b".decode("hex"), "Hi There", "thcxhlUFcmTii8C2+zeMjvFGvgA="), ("Jefe", "what do ya want for nothing?", "7/zfauXrL6LSdBbV8YTfnCWafHk="), ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".decode("hex"), "\xdd" * 50, "El1zQrmsEc2Ro5r0iqF7T2PxddM="), ] for key, data, expected in cases: self.assertEqual(hmac_sha1(key, data), expected) def test_iso8601time(self): self.assertEqual("2006-07-07T15:04:56Z", iso8601time((2006, 7, 7, 15, 4, 56, 0, 0, 0))) class ParseUrlTestCase(TestCase): """ Test URL parsing facility and defaults values. """ def test_parse(self): """ L{parse} correctly parses a URL into its various components. """ # The default port for HTTP is 80. self.assertEqual( parse("http://127.0.0.1/"), ("http", "127.0.0.1", 80, "/")) # The default port for HTTPS is 443. self.assertEqual( parse("https://127.0.0.1/"), ("https", "127.0.0.1", 443, "/")) # Specifying a port. 
self.assertEqual( parse("http://spam:12345/"), ("http", "spam", 12345, "/")) # Weird (but commonly accepted) structure uses default port. self.assertEqual( parse("http://spam:/"), ("http", "spam", 80, "/")) # Spaces in the hostname are trimmed, the default path is /. self.assertEqual( parse("http://foo "), ("http", "foo", 80, "/")) def test_externalUnicodeInterference(self): """ L{parse} should return C{str} for the scheme, host, and path elements of its return tuple, even when passed an URL which has previously been passed to L{urlparse} as a C{unicode} string. """ badInput = u"http://example1.com/path" goodInput = badInput.encode("ascii") urlparse(badInput) scheme, host, port, path = parse(goodInput) self.assertTrue(isinstance(scheme, str)) self.assertTrue(isinstance(host, str)) self.assertTrue(isinstance(path, str)) txAWS-0.2.3/txaws/tests/test_wsdl.py0000664000175000017500000010172511741311335021111 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010-2012 Canonical Ltd. # Licenced under the txaws licence available at /LICENSE in the txaws source. import os from twisted.trial.unittest import TestCase from txaws.wsdl import ( WSDLParseError, LeafSchema, NodeSchema, NodeItem, SequenceSchema, SequenceItem, WSDLParser, etree) class WsdlBaseTestCase(TestCase): if not etree: skip = "lxml is either not installed or broken on your system." class NodeSchemaTestCase(WsdlBaseTestCase): def test_create_with_bad_tag(self): """ L{NodeSchema.create} raises an error if the tag of the given element doesn't match the expected one. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("spam") error = self.assertRaises(WSDLParseError, schema.create, root) self.assertEqual("Expected response with tag 'foo', but got " "'egg' instead", error.args[0]) def test_add_with_invalid_min(self): """ L{NodeSchema.add} allows the C{min_occurs} parameter to only be C{None}, zero or one. 
""" schema = NodeSchema("foo") self.assertRaises(RuntimeError, schema.add, LeafSchema("bar"), min_occurs=-1) self.assertRaises(RuntimeError, schema.add, LeafSchema("bar"), min_occurs=2) def test_dump(self): """ L{NodeSchema.dump} creates an L{etree.Element} out of a L{NodeItem}. """ schema = NodeSchema("foo", [LeafSchema("bar")]) foo = NodeItem(schema) foo.bar = "spam" self.assertEqual("spam", etree.tostring(schema.dump(foo))) def test_dump_with_multiple_children(self): """ L{NodeSchema.dump} supports multiple children. """ schema = NodeSchema("foo", [LeafSchema("bar"), LeafSchema("egg")]) foo = NodeItem(schema) foo.bar = "spam1" foo.egg = "spam2" self.assertEqual("spam1spam2", etree.tostring(schema.dump(foo))) def test_dump_with_missing_attribute(self): """ L{NodeSchema.dump} ignores missing attributes if C{min_occurs} is zero. """ schema = NodeSchema("foo") schema.add(LeafSchema("bar"), min_occurs=0) foo = NodeItem(schema) self.assertEqual("", etree.tostring(schema.dump(foo))) class NodeItemTestCase(WsdlBaseTestCase): def test_get(self): """ The child leaf elements of a L{NodeItem} can be accessed as attributes. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("egg") foo = schema.create(root) self.assertEqual("egg", foo.bar) def test_get_with_many_children(self): """ Multiple children are supported. """ schema = NodeSchema("foo", [LeafSchema("bar"), LeafSchema("egg")]) root = etree.fromstring("spam1spam2") foo = schema.create(root) self.assertEqual("spam1", foo.bar) self.assertEqual("spam2", foo.egg) def test_get_with_namespace(self): """ The child leaf elements of a L{NodeItem} can be accessed as attributes. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("egg") foo = schema.create(root) self.assertEqual("egg", foo.bar) def test_get_with_unknown_tag(self): """ An error is raised when trying to access an attribute not in the schema. 
""" schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("eggboom") foo = schema.create(root) error = self.assertRaises(WSDLParseError, getattr, foo, "spam") self.assertEqual("Unknown tag 'spam'", error.args[0]) def test_get_with_duplicate_tag(self): """ An error is raised when trying to access an attribute associated with a tag that appears more than once. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("spam1spam2") item = schema.create(root) error = self.assertRaises(WSDLParseError, getattr, item, "bar") self.assertEqual("Duplicate tag 'bar'", error.args[0]) def test_get_with_missing_required_tag(self): """ An error is raised when trying to access a required attribute and the associated tag is missing. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("") item = schema.create(root) error = self.assertRaises(WSDLParseError, getattr, item, "bar") self.assertEqual("Missing tag 'bar'", error.args[0]) def test_get_with_empty_required_tag(self): """ An error is raised if an expected required tag is found but has and empty value. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("") item = schema.create(root) error = self.assertRaises(WSDLParseError, getattr, item, "bar") self.assertEqual("Missing tag 'bar'", error.args[0]) def test_get_with_non_required_tag(self): """ No error is raised if a tag is missing and its min count is zero. """ schema = NodeSchema("foo") schema.add(LeafSchema("bar"), min_occurs=0) root = etree.fromstring("") foo = schema.create(root) self.assertIdentical(None, foo.bar) def test_get_with_reserved_keyword(self): """ Attributes associated to tags named against required attributes can be accessed appending a '_' to the name. """ schema = NodeSchema("foo", [LeafSchema("return")]) root = etree.fromstring("true") foo = schema.create(root) self.assertEqual("true", foo.return_) def test_get_with_nested(self): """ It is possible to access nested nodes. 
""" schema = NodeSchema("foo", [NodeSchema("bar", [LeafSchema("egg")])]) root = etree.fromstring("spam") foo = schema.create(root) self.assertEqual("spam", foo.bar.egg) def test_get_with_non_required_nested(self): """ It is possible to access a non-required nested node that has no associated element in the XML yet, in that case a new element is created for it. """ schema = NodeSchema("foo") schema.add(NodeSchema("bar", [LeafSchema("egg")]), min_occurs=0) root = etree.fromstring("") foo = schema.create(root) foo.bar.egg = "spam" self.assertEqual("spam", etree.tostring(schema.dump(foo))) def test_set_with_unknown_tag(self): """ An error is raised when trying to set an attribute not in the schema. """ schema = NodeSchema("foo") foo = schema.create() error = self.assertRaises(WSDLParseError, setattr, foo, "bar", "egg") self.assertEqual("Unknown tag 'bar'", error.args[0]) def test_set_with_duplicate_tag(self): """ An error is raised when trying to set an attribute associated with a tag that appears more than once. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("spam1spam2") foo = schema.create(root) error = self.assertRaises(WSDLParseError, setattr, foo, "bar", "egg") self.assertEqual("Duplicate tag 'bar'", error.args[0]) def test_set_with_required_tag(self): """ An error is raised when trying to set a required attribute to C{None}. """ schema = NodeSchema("foo", [LeafSchema("bar")]) root = etree.fromstring("spam") foo = schema.create(root) error = self.assertRaises(WSDLParseError, setattr, foo, "bar", None) self.assertEqual("Missing tag 'bar'", error.args[0]) self.assertEqual("spam", foo.bar) def test_set_with_non_required_tag(self): """ It is possible to set a non-required tag value to C{None}, in that case the element will be removed if present. 
""" schema = NodeSchema("foo") schema.add(LeafSchema("bar"), min_occurs=0) root = etree.fromstring("spam") foo = schema.create(root) foo.bar = None self.assertEqual("", etree.tostring(schema.dump(foo))) def test_set_with_non_leaf_tag(self): """ An error is raised when trying to set a non-leaf attribute to a value other than C{None}. """ schema = NodeSchema("foo", [NodeSchema("bar", [LeafSchema("egg")])]) root = etree.fromstring("spam") foo = schema.create(root) error = self.assertRaises(WSDLParseError, setattr, foo, "bar", "yo") self.assertEqual("Can't set non-leaf tag 'bar'", error.args[0]) def test_set_with_optional_node_tag(self): """ It is possible to set an optional node tag to C{None}, in that case it will be removed from the tree. """ schema = NodeSchema("foo") schema.add(NodeSchema("bar", [LeafSchema("egg")]), min_occurs=0) root = etree.fromstring("spam") foo = schema.create(root) foo.bar = None self.assertEqual("", etree.tostring(schema.dump(foo))) def test_set_with_sequence_tag(self): """ It is possible to set a sequence tag to C{None}, in that case all its children will be removed """ schema = NodeSchema("foo") schema.add(SequenceSchema("bar", NodeSchema("item", [LeafSchema("egg")]))) root = etree.fromstring("" "spam<" "/foo>") foo = schema.create(root) foo.bar = None self.assertEqual("", etree.tostring(schema.dump(foo))) def test_set_with_required_non_leaf_tag(self): """ An error is raised when trying to set a required non-leaf tag to C{None}. """ schema = NodeSchema("foo", [NodeSchema("bar", [LeafSchema("egg")])]) root = etree.fromstring("spam") foo = schema.create(root) error = self.assertRaises(WSDLParseError, setattr, foo, "bar", None) self.assertEqual("Missing tag 'bar'", error.args[0]) self.assertTrue(hasattr(foo, "bar")) class SequenceSchemaTestCase(WsdlBaseTestCase): def test_create_with_bad_tag(self): """ L{SequenceSchema.create} raises an error if the tag of the given element doesn't match the expected one. 
""" schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("egg") error = self.assertRaises(WSDLParseError, schema.create, root) self.assertEqual("Expected response with tag 'foo', but got " "'spam' instead", error.args[0]) def test_set_with_leaf(self): """ L{SequenceSchema.set} raises an error if the given child is a leaf node """ schema = SequenceSchema("foo") error = self.assertRaises(RuntimeError, schema.set, LeafSchema("bar")) self.assertEqual("Sequence can't have leaf children", str(error)) def test_set_with_previous_child(self): """ L{SequenceSchema.set} raises an error if the sequence has already a child. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) error = self.assertRaises(RuntimeError, schema.set, NodeSchema("egg")) self.assertEqual("Sequence has already a child", str(error)) def test_set_with_no_min_or_max(self): """ L{SequenceSchema.set} raises an error if no values are provided for the min and max parameters. """ schema = SequenceSchema("foo") child = NodeSchema("item", [LeafSchema("bar")]) error = self.assertRaises(RuntimeError, schema.set, child, min_occurs=0, max_occurs=None) self.assertEqual("Sequence node without min or max", str(error)) error = self.assertRaises(RuntimeError, schema.set, child, min_occurs=None, max_occurs=1) self.assertEqual("Sequence node without min or max", str(error)) def test_dump(self): """ L{SequenceSchema.dump} creates a L{etree.Element} out of a L{SequenceItem}. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) foo = SequenceItem(schema) foo.append().bar = "egg" self.assertEqual("egg", etree.tostring(schema.dump(foo))) def test_dump_with_many_items(self): """ L{SequenceSchema.dump} supports many child items in the sequence. 
""" schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) foo = SequenceItem(schema) foo.append().bar = "spam0" foo.append().bar = "spam1" self.assertEqual("" "spam0" "spam1" "", etree.tostring(schema.dump(foo))) class SequenceItemTestCase(WsdlBaseTestCase): def test_get(self): """ The child elements of a L{SequenceItem} can be accessed as attributes. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("egg") foo = schema.create(root) self.assertEqual("egg", foo[0].bar) def test_get_items(self): """L{SequenceItem} supports elements with many child items.""" schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("" "egg0" "egg1" "") foo = schema.create(root) self.assertEqual("egg0", foo[0].bar) self.assertEqual("egg1", foo[1].bar) def test_get_with_namespace(self): """ The child elements of a L{SequenceItem} can be accessed as attributes. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("" "egg" "") foo = schema.create(root) self.assertEqual("egg", foo[0].bar) def test_get_with_non_existing_index(self): """An error is raised when trying to access a non existing item.""" schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("egg") foo = schema.create(root) error = self.assertRaises(WSDLParseError, foo.__getitem__, 1) self.assertEqual("Non existing item in tag 'foo'", error.args[0]) def test_get_with_index_higher_than_max(self): """ An error is raised when trying to access an item above the allowed max value. 
""" schema = SequenceSchema("foo") schema.set(NodeSchema("item", [LeafSchema("bar")]), min_occurs=0, max_occurs=1) root = etree.fromstring("" "egg0" "egg1" "") foo = schema.create(root) error = self.assertRaises(WSDLParseError, foo.__getitem__, 1) self.assertEqual("Out of range item in tag 'foo'", error.args[0]) def test_append(self): """ L{SequenceItem.append} adds a new item to the sequence, appending it at the end. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("egg0") foo = schema.create(root) foo.append().bar = "egg1" self.assertEqual("egg1", foo[1].bar) self.assertEqual("" "egg0" "egg1" "", etree.tostring(schema.dump(foo))) def test_append_with_too_many_items(self): """ An error is raised when trying to append items above the max. """ schema = SequenceSchema("foo") schema.set(NodeSchema("item", [LeafSchema("bar")]), min_occurs=0, max_occurs=1) root = etree.fromstring("egg") foo = schema.create(root) error = self.assertRaises(WSDLParseError, foo.append) self.assertEqual("Too many items in tag 'foo'", error.args[0]) self.assertEqual(1, len(list(foo))) def test_delitem(self): """ L{SequenceItem.__delitem__} removes from the sequence the item with the given index. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("" "egg0" "egg1" "") foo = schema.create(root) del foo[0] self.assertEqual("egg1", foo[0].bar) self.assertEqual("egg1", etree.tostring(schema.dump(foo))) def test_delitem_with_not_enough_items(self): """ L{SequenceItem.__delitem__} raises an error if trying to remove an item would make the sequence shorter than the required minimum. 
""" schema = SequenceSchema("foo") schema.set(NodeSchema("item", [LeafSchema("bar")]), min_occurs=1, max_occurs=10) root = etree.fromstring("egg") foo = schema.create(root) error = self.assertRaises(WSDLParseError, foo.__delitem__, 0) self.assertEqual("Not enough items in tag 'foo'", error.args[0]) self.assertEqual(1, len(list(foo))) def test_remove(self): """ L{SequenceItem.remove} removes the given item from the sequence. """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("" "egg0" "egg1" "") foo = schema.create(root) foo.remove(foo[0]) self.assertEqual("egg1", foo[0].bar) self.assertEqual("egg1", etree.tostring(schema.dump(foo))) def test_remove_with_non_existing_item(self): """ L{SequenceItem.remove} raises an exception when trying to remove a non existing item """ schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("egg") foo = schema.create(root) item = foo.remove(foo[0]) error = self.assertRaises(WSDLParseError, foo.remove, item) self.assertEqual("Non existing item in tag 'foo'", error.args[0]) def test_iter(self): """L{SequenceItem} objects are iterable.""" schema = SequenceSchema("foo", NodeSchema("item", [LeafSchema("bar")])) root = etree.fromstring("" "egg0" "egg1" "") foo = schema.create(root) [item0, item1] = list(foo) self.assertEqual("egg0", item0.bar) self.assertEqual("egg1", item1.bar) class WDSLParserTestCase(WsdlBaseTestCase): def setUp(self): super(WDSLParserTestCase, self).setUp() parser = WSDLParser() wsdl_dir = os.path.join(os.path.dirname(__file__), "../../wsdl") wsdl_path = os.path.join(wsdl_dir, "2009-11-30.ec2.wsdl") self.schemas = parser.parse(open(wsdl_path).read()) def test_parse_create_key_pair_response(self): """Parse a CreateKeyPairResponse payload.""" schema = self.schemas["CreateKeyPairResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "65d85081-abbc" "foo" "9a:81:96:46" "MIIEowIBAAKCAQEAi" "" % xmlns) response = 
schema.create(etree.fromstring(xml)) self.assertEqual("65d85081-abbc", response.requestId) self.assertEqual("foo", response.keyName) self.assertEqual("9a:81:96:46", response.keyFingerprint) self.assertEqual("MIIEowIBAAKCAQEAi", response.keyMaterial) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_parse_delete_key_pair_response(self): """Parse a DeleteKeyPairResponse payload.""" schema = self.schemas["DeleteKeyPairResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "acc41b73-4c47-4f80" "true" "" % xmlns) root = etree.fromstring(xml) response = schema.create(root) self.assertEqual("acc41b73-4c47-4f80", response.requestId) self.assertEqual("true", response.return_) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_parse_describe_key_pairs_response(self): """Parse a DescribeKeyPairsResponse payload.""" schema = self.schemas["DescribeKeyPairsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "3ef0aa1d-57dd-4272" "" "" "europe-key" "94:88:29:60:cf" "" "" "" % xmlns) root = etree.fromstring(xml) response = schema.create(root) self.assertEqual("3ef0aa1d-57dd-4272", response.requestId) self.assertEqual("europe-key", response.keySet[0].keyName) self.assertEqual("94:88:29:60:cf", response.keySet[0].keyFingerprint) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_modify_describe_key_pairs_response(self): """Modify a DescribeKeyPairsResponse payload.""" schema = self.schemas["DescribeKeyPairsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "3ef0aa1d-57dd-4272" "" "" "europe-key" "94:88:29:60:cf" "" "" "" % xmlns) root = etree.fromstring(xml) response = schema.create(root) response.keySet[0].keyName = "new-key" xml = ("" "3ef0aa1d-57dd-4272" "" "" "new-key" "94:88:29:60:cf" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_create_describe_key_pairs_response(self): """Create a DescribeKeyPairsResponse payload.""" 
schema = self.schemas["DescribeKeyPairsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" response = schema.create(namespace=xmlns) response.requestId = "abc" key = response.keySet.append() key.keyName = "some-key" key.keyFingerprint = "11:22:33:44" xml = ("" "abc" "" "" "some-key" "11:22:33:44" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_create_describe_addresses_response(self): """Create a DescribeAddressesResponse payload. """ schema = self.schemas["DescribeAddressesResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" response = schema.create(namespace=xmlns) response.requestId = "abc" address = response.addressesSet.append() address.publicIp = "192.168.0.1" xml = ("" "abc" "" "" "192.168.0.1" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_create_describe_instances_response_with_username(self): """Create a DescribeInstancesResponse payload. """ schema = self.schemas["DescribeInstancesResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" response = schema.create(namespace=xmlns) response.requestId = "abc" reservation = response.reservationSet.append() instance = reservation.instancesSet.append() instance.instanceId = "i-01234567" xml = ("" "abc" "" "" "" "" "i-01234567" "" "" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_create_describe_instances_response(self): """Create a DescribeInstancesResponse payload. 
""" schema = self.schemas["DescribeInstancesResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" response = schema.create(namespace=xmlns) response.requestId = "abc" reservation = response.reservationSet.append() instance = reservation.instancesSet.append() instance.instanceId = "i-01234567" xml = ("" "abc" "" "" "" "" "i-01234567" "" "" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_parse_describe_security_groups_response(self): """Parse a DescribeSecurityGroupsResponse payload.""" schema = self.schemas["DescribeSecurityGroupsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "3ef0aa1d-57dd-4272" "" "" "UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM" "WebServers" "Web" "" "" "tcp" "80" "80" "" "" "" "0.0.0.0/0" "" "" "" "" "" "" "" % xmlns) root = etree.fromstring(xml) response = schema.create(root) self.assertEqual("3ef0aa1d-57dd-4272", response.requestId) self.assertEqual("UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM", response.securityGroupInfo[0].ownerId) self.assertEqual("WebServers", response.securityGroupInfo[0].groupName) self.assertEqual("Web", response.securityGroupInfo[0].groupDescription) self.assertEqual(xml, etree.tostring(schema.dump(response))) def test_modify_describe_security_groups_response(self): """Modify a DescribeSecurityGroupsResponse payload.""" schema = self.schemas["DescribeSecurityGroupsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" xml = ("" "3ef0aa1d-57dd-4272" "" "" "UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM" "WebServers" "Web" "" "" "tcp" "80" "80" "" "" "" "0.0.0.0/0" "" "" "" "" "" "" "" % xmlns) root = etree.fromstring(xml) response = schema.create(root) response.securityGroupInfo[0].ownerId = "abc123" response.securityGroupInfo[0].groupName = "Everybody" response.securityGroupInfo[0].groupDescription = "All People" xml = ("" "3ef0aa1d-57dd-4272" "" "" "abc123" "Everybody" "All People" "" "" "tcp" "80" "80" "" "" "" "0.0.0.0/0" "" "" "" "" "" "" "" % xmlns) self.assertEqual(xml, 
etree.tostring(schema.dump(response))) def test_create_describe_security_groups_response(self): """Create a DescribeSecurityGroupsResponse payload.""" schema = self.schemas["DescribeSecurityGroupsResponse"] xmlns = "http://ec2.amazonaws.com/doc/2008-12-01/" response = schema.create(namespace=xmlns) response.requestId = "requestId123" group = response.securityGroupInfo.append() group.ownerId = "deadbeef31337" group.groupName = "hexadecimalonly" group.groupDescription = "All people that love hex" xml = ("" "requestId123" "" "" "deadbeef31337" "hexadecimalonly" "All people that love hex" "" "" "" % xmlns) self.assertEqual(xml, etree.tostring(schema.dump(response))) txAWS-0.2.3/txaws/reactor.py0000664000175000017500000000145711741311335017377 0ustar oubiwannoubiwann00000000000000'''Reactor utilities.''' def get_exitcode_reactor(): """ This is only neccesary until a fix like the one outlined here is implemented for Twisted: http://twistedmatrix.com/trac/ticket/2182 """ from twisted.internet.main import installReactor from twisted.internet.selectreactor import SelectReactor class ExitCodeReactor(SelectReactor): def stop(self, exitStatus=0): super(ExitCodeReactor, self).stop() self.exitStatus = exitStatus def run(self, *args, **kwargs): super(ExitCodeReactor, self).run(*args, **kwargs) return self.exitStatus reactor = ExitCodeReactor() installReactor(reactor) return reactor try: reactor = get_exitcode_reactor() except: from twisted.internet import reactor txAWS-0.2.3/txaws/ec2/0000775000175000017500000000000011741312025016025 5ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/ec2/exception.py0000664000175000017500000000106011741311335020375 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from txaws.exception import AWSError class EC2Error(AWSError): """ A error class providing custom methods on EC2 errors. 
""" def _set_400_error(self, tree): errors_node = tree.find(".//Errors") if errors_node is not None: for error in errors_node: data = self._node_to_dict(error) if data: self.errors.append(data) txAWS-0.2.3/txaws/ec2/__init__.py0000664000175000017500000000000011741311335020127 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/ec2/client.py0000664000175000017500000012537411741311335017674 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Licenced under the txaws licence available at /LICENSE in the txaws source. """EC2 client support.""" from datetime import datetime from urllib import quote from base64 import b64encode from txaws import version from txaws.client.base import BaseClient, BaseQuery, error_wrapper from txaws.ec2 import model from txaws.ec2.exception import EC2Error from txaws.util import iso8601time, XML __all__ = ["EC2Client", "Query", "Parser"] def ec2_error_wrapper(error): error_wrapper(error, EC2Error) class EC2Client(BaseClient): """A client for EC2.""" def __init__(self, creds=None, endpoint=None, query_factory=None, parser=None): if query_factory is None: query_factory = Query if parser is None: parser = Parser() super(EC2Client, self).__init__(creds, endpoint, query_factory, parser) def describe_instances(self, *instance_ids): """Describe current instances.""" instances = {} for pos, instance_id in enumerate(instance_ids): instances["InstanceId.%d" % (pos + 1)] = instance_id query = self.query_factory( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, other_params=instances) d = query.submit() return d.addCallback(self.parser.describe_instances) def run_instances(self, image_id, min_count, max_count, security_groups=None, key_name=None, instance_type=None, user_data=None, availability_zone=None, kernel_id=None, ramdisk_id=None): """Run new instances. 
TODO: blockDeviceMapping, monitoring, subnetId """ params = {"ImageId": image_id, "MinCount": str(min_count), "MaxCount": str(max_count)} if key_name is not None: params["KeyName"] = key_name if security_groups is not None: for i, name in enumerate(security_groups): params["SecurityGroup.%d" % (i + 1)] = name if user_data is not None: params["UserData"] = b64encode(user_data) if instance_type is not None: params["InstanceType"] = instance_type if availability_zone is not None: params["Placement.AvailabilityZone"] = availability_zone if kernel_id is not None: params["KernelId"] = kernel_id if ramdisk_id is not None: params["RamdiskId"] = ramdisk_id query = self.query_factory( action="RunInstances", creds=self.creds, endpoint=self.endpoint, other_params=params) d = query.submit() return d.addCallback(self.parser.run_instances) def terminate_instances(self, *instance_ids): """Terminate some instances. @param instance_ids: The ids of the instances to terminate. @return: A deferred which on success gives an iterable of (id, old-state, new-state) tuples. """ instances = {} for pos, instance_id in enumerate(instance_ids): instances["InstanceId.%d" % (pos + 1)] = instance_id query = self.query_factory( action="TerminateInstances", creds=self.creds, endpoint=self.endpoint, other_params=instances) d = query.submit() return d.addCallback(self.parser.terminate_instances) def describe_security_groups(self, *names): """Describe security groups. @param names: Optionally, a list of security group names to describe. Defaults to all security groups in the account. @return: A C{Deferred} that will fire with a list of L{SecurityGroup}s retrieved from the cloud. 
""" group_names = {} if names: group_names = dict([("GroupName.%d" % (i + 1), name) for i, name in enumerate(names)]) query = self.query_factory( action="DescribeSecurityGroups", creds=self.creds, endpoint=self.endpoint, other_params=group_names) d = query.submit() return d.addCallback(self.parser.describe_security_groups) def create_security_group(self, name, description): """Create security group. @param name: Name of the new security group. @param description: Description of the new security group. @return: A C{Deferred} that will fire with a truth value for the success of the operation. """ parameters = {"GroupName": name, "GroupDescription": description} query = self.query_factory( action="CreateSecurityGroup", creds=self.creds, endpoint=self.endpoint, other_params=parameters) d = query.submit() return d.addCallback(self.parser.truth_return) def delete_security_group(self, name): """ @param name: Name of the new security group. @return: A C{Deferred} that will fire with a truth value for the success of the operation. """ parameter = {"GroupName": name} query = self.query_factory( action="DeleteSecurityGroup", creds=self.creds, endpoint=self.endpoint, other_params=parameter) d = query.submit() return d.addCallback(self.parser.truth_return) def authorize_security_group( self, group_name, source_group_name="", source_group_owner_id="", ip_protocol="", from_port="", to_port="", cidr_ip=""): """ There are two ways to use C{authorize_security_group}: 1) associate an existing group (source group) with the one that you are targeting (group_name) with an authorization update; or 2) associate a set of IP permissions with the group you are targeting with an authorization update. @param group_name: The group you will be modifying with a new authorization. Optionally, the following parameters: @param source_group_name: Name of security group to authorize access to when operating on a user/group pair. 
@param source_group_owner_id: Owner of security group to authorize access to when operating on a user/group pair. If those parameters are not specified, then the following must be: @param ip_protocol: IP protocol to authorize access to when operating on a CIDR IP. @param from_port: Bottom of port range to authorize access to when operating on a CIDR IP. This contains the ICMP type if ICMP is being authorized. @param to_port: Top of port range to authorize access to when operating on a CIDR IP. This contains the ICMP code if ICMP is being authorized. @param cidr_ip: CIDR IP range to authorize access to when operating on a CIDR IP. @return: A C{Deferred} that will fire with a truth value for the success of the operation. """ if source_group_name and source_group_owner_id: parameters = { "SourceSecurityGroupName": source_group_name, "SourceSecurityGroupOwnerId": source_group_owner_id, } elif ip_protocol and from_port and to_port and cidr_ip: parameters = { "IpProtocol": ip_protocol, "FromPort": from_port, "ToPort": to_port, "CidrIp": cidr_ip, } else: msg = ("You must specify either both group parameters or " "all the ip parameters.") raise ValueError(msg) parameters["GroupName"] = group_name query = self.query_factory( action="AuthorizeSecurityGroupIngress", creds=self.creds, endpoint=self.endpoint, other_params=parameters) d = query.submit() return d.addCallback(self.parser.truth_return) def authorize_group_permission( self, group_name, source_group_name, source_group_owner_id): """ This is a convenience function that wraps the "authorize group" functionality of the C{authorize_security_group} method. For an explanation of the parameters, see C{authorize_security_group}. 
""" d = self.authorize_security_group( group_name, source_group_name=source_group_name, source_group_owner_id=source_group_owner_id) return d def authorize_ip_permission( self, group_name, ip_protocol, from_port, to_port, cidr_ip): """ This is a convenience function that wraps the "authorize ip permission" functionality of the C{authorize_security_group} method. For an explanation of the parameters, see C{authorize_security_group}. """ d = self.authorize_security_group( group_name, ip_protocol=ip_protocol, from_port=from_port, to_port=to_port, cidr_ip=cidr_ip) return d def revoke_security_group( self, group_name, source_group_name="", source_group_owner_id="", ip_protocol="", from_port="", to_port="", cidr_ip=""): """ There are two ways to use C{revoke_security_group}: 1) associate an existing group (source group) with the one that you are targeting (group_name) with the revoke update; or 2) associate a set of IP permissions with the group you are targeting with a revoke update. @param group_name: The group you will be modifying with an authorization removal. Optionally, the following parameters: @param source_group_name: Name of security group to revoke access from when operating on a user/group pair. @param source_group_owner_id: Owner of security group to revoke access from when operating on a user/group pair. If those parameters are not specified, then the following must be: @param ip_protocol: IP protocol to revoke access from when operating on a CIDR IP. @param from_port: Bottom of port range to revoke access from when operating on a CIDR IP. This contains the ICMP type if ICMP is being revoked. @param to_port: Top of port range to revoke access from when operating on a CIDR IP. This contains the ICMP code if ICMP is being revoked. @param cidr_ip: CIDR IP range to revoke access from when operating on a CIDR IP. @return: A C{Deferred} that will fire with a truth value for the success of the operation. 
""" if source_group_name and source_group_owner_id: parameters = { "SourceSecurityGroupName": source_group_name, "SourceSecurityGroupOwnerId": source_group_owner_id, } elif ip_protocol and from_port and to_port and cidr_ip: parameters = { "IpProtocol": ip_protocol, "FromPort": from_port, "ToPort": to_port, "CidrIp": cidr_ip, } else: msg = ("You must specify either both group parameters or " "all the ip parameters.") raise ValueError(msg) parameters["GroupName"] = group_name query = self.query_factory( action="RevokeSecurityGroupIngress", creds=self.creds, endpoint=self.endpoint, other_params=parameters) d = query.submit() return d.addCallback(self.parser.truth_return) def revoke_group_permission( self, group_name, source_group_name, source_group_owner_id): """ This is a convenience function that wraps the "authorize group" functionality of the C{authorize_security_group} method. For an explanation of the parameters, see C{revoke_security_group}. """ d = self.revoke_security_group( group_name, source_group_name=source_group_name, source_group_owner_id=source_group_owner_id) return d def revoke_ip_permission( self, group_name, ip_protocol, from_port, to_port, cidr_ip): """ This is a convenience function that wraps the "authorize ip permission" functionality of the C{authorize_security_group} method. For an explanation of the parameters, see C{revoke_security_group}. 
""" d = self.revoke_security_group( group_name, ip_protocol=ip_protocol, from_port=from_port, to_port=to_port, cidr_ip=cidr_ip) return d def describe_volumes(self, *volume_ids): """Describe available volumes.""" volumeset = {} for pos, volume_id in enumerate(volume_ids): volumeset["VolumeId.%d" % (pos + 1)] = volume_id query = self.query_factory( action="DescribeVolumes", creds=self.creds, endpoint=self.endpoint, other_params=volumeset) d = query.submit() return d.addCallback(self.parser.describe_volumes) def create_volume(self, availability_zone, size=None, snapshot_id=None): """Create a new volume.""" params = {"AvailabilityZone": availability_zone} if ((snapshot_id is None and size is None) or (snapshot_id is not None and size is not None)): raise ValueError("Please provide either size or snapshot_id") if size is not None: params["Size"] = str(size) if snapshot_id is not None: params["SnapshotId"] = snapshot_id query = self.query_factory( action="CreateVolume", creds=self.creds, endpoint=self.endpoint, other_params=params) d = query.submit() return d.addCallback(self.parser.create_volume) def delete_volume(self, volume_id): query = self.query_factory( action="DeleteVolume", creds=self.creds, endpoint=self.endpoint, other_params={"VolumeId": volume_id}) d = query.submit() return d.addCallback(self.parser.truth_return) def describe_snapshots(self, *snapshot_ids): """Describe available snapshots. TODO: ownerSet, restorableBySet """ snapshot_set = {} for pos, snapshot_id in enumerate(snapshot_ids): snapshot_set["SnapshotId.%d" % (pos + 1)] = snapshot_id query = self.query_factory( action="DescribeSnapshots", creds=self.creds, endpoint=self.endpoint, other_params=snapshot_set) d = query.submit() return d.addCallback(self.parser.snapshots) def create_snapshot(self, volume_id): """Create a new snapshot of an existing volume. 
TODO: description """ query = self.query_factory( action="CreateSnapshot", creds=self.creds, endpoint=self.endpoint, other_params={"VolumeId": volume_id}) d = query.submit() return d.addCallback(self.parser.create_snapshot) def delete_snapshot(self, snapshot_id): """Remove a previously created snapshot.""" query = self.query_factory( action="DeleteSnapshot", creds=self.creds, endpoint=self.endpoint, other_params={"SnapshotId": snapshot_id}) d = query.submit() return d.addCallback(self.parser.truth_return) def attach_volume(self, volume_id, instance_id, device): """Attach the given volume to the specified instance at C{device}.""" query = self.query_factory( action="AttachVolume", creds=self.creds, endpoint=self.endpoint, other_params={"VolumeId": volume_id, "InstanceId": instance_id, "Device": device}) d = query.submit() return d.addCallback(self.parser.attach_volume) def describe_keypairs(self, *keypair_names): """Returns information about key pairs available.""" keypairs = {} for index, keypair_name in enumerate(keypair_names): keypairs["KeyName.%d" % (index + 1)] = keypair_name query = self.query_factory( action="DescribeKeyPairs", creds=self.creds, endpoint=self.endpoint, other_params=keypairs) d = query.submit() return d.addCallback(self.parser.describe_keypairs) def create_keypair(self, keypair_name): """ Create a new 2048 bit RSA key pair and return a unique ID that can be used to reference the created key pair when launching new instances. 
""" query = self.query_factory( action="CreateKeyPair", creds=self.creds, endpoint=self.endpoint, other_params={"KeyName": keypair_name}) d = query.submit() return d.addCallback(self.parser.create_keypair) def delete_keypair(self, keypair_name): """Delete a given keypair.""" query = self.query_factory( action="DeleteKeyPair", creds=self.creds, endpoint=self.endpoint, other_params={"KeyName": keypair_name}) d = query.submit() return d.addCallback(self.parser.truth_return) def import_keypair(self, keypair_name, key_material): """ Import an existing SSH key into EC2. It supports: * OpenSSH public key format (e.g., the format in ~/.ssh/authorized_keys) * Base64 encoded DER format * SSH public key file format as specified in RFC4716 @param keypair_name: The name of the key to create. @param key_material: The material in one of the supported format. @return: A L{Deferred} firing with a L{model.Keypair} instance if successful. TODO: there is no corresponding method in the 2009-11-30 version of the ec2 wsdl. Delete this? """ query = self.query_factory( action="ImportKeyPair", creds=self.creds, endpoint=self.endpoint, other_params={"KeyName": keypair_name, "PublicKeyMaterial": b64encode(key_material)}) d = query.submit() return d.addCallback(self.parser.import_keypair, key_material) def allocate_address(self): """ Acquire an elastic IP address to be attached subsequently to EC2 instances. @return: the IP address allocated. """ # XXX remove empty other_params query = self.query_factory( action="AllocateAddress", creds=self.creds, endpoint=self.endpoint, other_params={}) d = query.submit() return d.addCallback(self.parser.allocate_address) def release_address(self, address): """ Release a previously allocated address returned by C{allocate_address}. @return: C{True} if the operation succeeded. 
""" query = self.query_factory( action="ReleaseAddress", creds=self.creds, endpoint=self.endpoint, other_params={"PublicIp": address}) d = query.submit() return d.addCallback(self.parser.truth_return) def associate_address(self, instance_id, address): """ Associate an allocated C{address} with the instance identified by C{instance_id}. @return: C{True} if the operation succeeded. """ query = self.query_factory( action="AssociateAddress", creds=self.creds, endpoint=self.endpoint, other_params={"InstanceId": instance_id, "PublicIp": address}) d = query.submit() return d.addCallback(self.parser.truth_return) def disassociate_address(self, address): """ Disassociate an address previously associated with C{associate_address}. This is an idempotent operation, so it can be called several times without error. """ query = self.query_factory( action="DisassociateAddress", creds=self.creds, endpoint=self.endpoint, other_params={"PublicIp": address}) d = query.submit() return d.addCallback(self.parser.truth_return) def describe_addresses(self, *addresses): """ List the elastic IPs allocated in this account. @param addresses: if specified, the addresses to get information about. @return: a C{list} of (address, instance_id). If the elastic IP is not associated currently, C{instance_id} will be C{None}. 
""" address_set = {} for pos, address in enumerate(addresses): address_set["PublicIp.%d" % (pos + 1)] = address query = self.query_factory( action="DescribeAddresses", creds=self.creds, endpoint=self.endpoint, other_params=address_set) d = query.submit() return d.addCallback(self.parser.describe_addresses) def describe_availability_zones(self, names=None): zone_names = None if names: zone_names = dict([("ZoneName.%d" % (i + 1), name) for i, name in enumerate(names)]) query = self.query_factory( action="DescribeAvailabilityZones", creds=self.creds, endpoint=self.endpoint, other_params=zone_names) d = query.submit() return d.addCallback(self.parser.describe_availability_zones) class Parser(object): """A parser for EC2 responses""" def instances_set(self, root, reservation): """Parse instance data out of an XML payload. @param root: The root node of the XML payload. @param reservation: The L{Reservation} associated with the instances from the response. @return: A C{list} of L{Instance}s. """ instances = [] for instance_data in root.find("instancesSet"): instances.append(self.instance(instance_data, reservation)) return instances def instance(self, instance_data, reservation): """Parse instance data out of an XML payload. @param instance_data: An XML node containing instance data. @param reservation: The L{Reservation} associated with the instance. @return: An L{Instance}. TODO: reason, platform, monitoring, subnetId, vpcId, privateIpAddress, ipAddress, stateReason, architecture, rootDeviceName, blockDeviceMapping, instanceLifecycle, spotInstanceRequestId. 
""" instance_id = instance_data.findtext("instanceId") instance_state = instance_data.find( "instanceState").findtext("name") private_dns_name = instance_data.findtext("privateDnsName") dns_name = instance_data.findtext("dnsName") private_ip_address = instance_data.findtext("privateIpAddress") ip_address = instance_data.findtext("ipAddress") key_name = instance_data.findtext("keyName") ami_launch_index = instance_data.findtext("amiLaunchIndex") products = [] product_codes = instance_data.find("productCodes") if product_codes is not None: for product_data in instance_data.find("productCodes"): products.append(product_data.text) instance_type = instance_data.findtext("instanceType") launch_time = instance_data.findtext("launchTime") placement = instance_data.find("placement").findtext( "availabilityZone") kernel_id = instance_data.findtext("kernelId") ramdisk_id = instance_data.findtext("ramdiskId") image_id = instance_data.findtext("imageId") instance = model.Instance( instance_id, instance_state, instance_type, image_id, private_dns_name, dns_name, private_ip_address, ip_address, key_name, ami_launch_index, launch_time, placement, products, kernel_id, ramdisk_id, reservation=reservation) return instance def describe_instances(self, xml_bytes): """ Parse the reservations XML payload that is returned from an AWS describeInstances API call. Instead of returning the reservations as the "top-most" object, we return the object that most developers and their code will be interested in: the instances. In instances reservation is available on the instance object. The following instance attributes are optional: * ami_launch_index * key_name * kernel_id * product_codes * ramdisk_id * reason @param xml_bytes: raw XML payload from AWS. """ root = XML(xml_bytes) results = [] # May be a more elegant way to do this: for reservation_data in root.find("reservationSet"): # Get the security group information. 
groups = [] for group_data in reservation_data.find("groupSet"): group_id = group_data.findtext("groupId") groups.append(group_id) # Create a reservation object with the parsed data. reservation = model.Reservation( reservation_id=reservation_data.findtext("reservationId"), owner_id=reservation_data.findtext("ownerId"), groups=groups) # Get the list of instances. instances = self.instances_set( reservation_data, reservation) results.extend(instances) return results def run_instances(self, xml_bytes): """ Parse the reservations XML payload that is returned from an AWS RunInstances API call. @param xml_bytes: raw XML bytes with a C{RunInstancesResponse} root element. """ root = XML(xml_bytes) # Get the security group information. groups = [] for group_data in root.find("groupSet"): group_id = group_data.findtext("groupId") groups.append(group_id) # Create a reservation object with the parsed data. reservation = model.Reservation( reservation_id=root.findtext("reservationId"), owner_id=root.findtext("ownerId"), groups=groups) # Get the list of instances. instances = self.instances_set(root, reservation) return instances def terminate_instances(self, xml_bytes): """Parse the XML returned by the C{TerminateInstances} function. @param xml_bytes: XML bytes with a C{TerminateInstancesResponse} root element. @return: An iterable of C{tuple} of (instanceId, previousState, currentState) for the ec2 instances that where terminated. """ root = XML(xml_bytes) result = [] # May be a more elegant way to do this: instances = root.find("instancesSet") if instances is not None: for instance in instances: instanceId = instance.findtext("instanceId") previousState = instance.find("previousState").findtext( "name") currentState = instance.find("currentState").findtext( "name") result.append((instanceId, previousState, currentState)) return result def describe_security_groups(self, xml_bytes): """Parse the XML returned by the C{DescribeSecurityGroups} function. 
@param xml_bytes: XML bytes with a C{DescribeSecurityGroupsResponse} root element. @return: A list of L{SecurityGroup} instances. """ root = XML(xml_bytes) result = [] for group_info in root.findall("securityGroupInfo/item"): name = group_info.findtext("groupName") description = group_info.findtext("groupDescription") owner_id = group_info.findtext("ownerId") allowed_groups = [] allowed_ips = [] ip_permissions = group_info.find("ipPermissions") if ip_permissions is None: ip_permissions = () for ip_permission in ip_permissions: # openstack doesn't handle self authorized groups properly # XXX this is an upstream problem and should be addressed there # lp bug #829609 ip_protocol = ip_permission.findtext("ipProtocol") from_port = ip_permission.findtext("fromPort") to_port = ip_permission.findtext("toPort") if from_port: from_port = int(from_port) if to_port: to_port = int(to_port) for groups in ip_permission.findall("groups/item") or (): user_id = groups.findtext("userId") group_name = groups.findtext("groupName") if user_id and group_name: if (user_id, group_name) not in allowed_groups: allowed_groups.append((user_id, group_name)) for ip_ranges in ip_permission.findall("ipRanges/item") or (): cidr_ip = ip_ranges.findtext("cidrIp") allowed_ips.append( model.IPPermission( ip_protocol, from_port, to_port, cidr_ip)) allowed_groups = [model.UserIDGroupPair(user_id, group_name) for user_id, group_name in allowed_groups] security_group = model.SecurityGroup( name, description, owner_id=owner_id, groups=allowed_groups, ips=allowed_ips) result.append(security_group) return result def truth_return(self, xml_bytes): """Parse the XML for a truth value. @param xml_bytes: XML bytes. @return: True if the node contains "return" otherwise False. """ root = XML(xml_bytes) return root.findtext("return") == "true" def describe_volumes(self, xml_bytes): """Parse the XML returned by the C{DescribeVolumes} function. @param xml_bytes: XML bytes with a C{DescribeVolumesResponse} root element. 
@return: A list of L{Volume} instances. TODO: attachementSetItemResponseType#deleteOnTermination """ root = XML(xml_bytes) result = [] for volume_data in root.find("volumeSet"): volume_id = volume_data.findtext("volumeId") size = int(volume_data.findtext("size")) snapshot_id = volume_data.findtext("snapshotId") availability_zone = volume_data.findtext("availabilityZone") status = volume_data.findtext("status") create_time = volume_data.findtext("createTime") create_time = datetime.strptime( create_time[:19], "%Y-%m-%dT%H:%M:%S") volume = model.Volume( volume_id, size, status, create_time, availability_zone, snapshot_id) result.append(volume) for attachment_data in volume_data.find("attachmentSet"): instance_id = attachment_data.findtext("instanceId") device = attachment_data.findtext("device") status = attachment_data.findtext("status") attach_time = attachment_data.findtext("attachTime") attach_time = datetime.strptime( attach_time[:19], "%Y-%m-%dT%H:%M:%S") attachment = model.Attachment( instance_id, device, status, attach_time) volume.attachments.append(attachment) return result def create_volume(self, xml_bytes): """Parse the XML returned by the C{CreateVolume} function. @param xml_bytes: XML bytes with a C{CreateVolumeResponse} root element. @return: The L{Volume} instance created. """ root = XML(xml_bytes) volume_id = root.findtext("volumeId") size = int(root.findtext("size")) snapshot_id = root.findtext("snapshotId") availability_zone = root.findtext("availabilityZone") status = root.findtext("status") create_time = root.findtext("createTime") create_time = datetime.strptime( create_time[:19], "%Y-%m-%dT%H:%M:%S") volume = model.Volume( volume_id, size, status, create_time, availability_zone, snapshot_id) return volume def snapshots(self, xml_bytes): """Parse the XML returned by the C{DescribeSnapshots} function. @param xml_bytes: XML bytes with a C{DescribeSnapshotsResponse} root element. @return: A list of L{Snapshot} instances. 
TODO: ownersSet, restorableBySet, ownerId, volumeSize, description, ownerAlias. """ root = XML(xml_bytes) result = [] for snapshot_data in root.find("snapshotSet"): snapshot_id = snapshot_data.findtext("snapshotId") volume_id = snapshot_data.findtext("volumeId") status = snapshot_data.findtext("status") start_time = snapshot_data.findtext("startTime") start_time = datetime.strptime( start_time[:19], "%Y-%m-%dT%H:%M:%S") progress = snapshot_data.findtext("progress")[:-1] progress = float(progress or "0") / 100. snapshot = model.Snapshot( snapshot_id, volume_id, status, start_time, progress) result.append(snapshot) return result def create_snapshot(self, xml_bytes): """Parse the XML returned by the C{CreateSnapshot} function. @param xml_bytes: XML bytes with a C{CreateSnapshotResponse} root element. @return: The L{Snapshot} instance created. TODO: ownerId, volumeSize, description. """ root = XML(xml_bytes) snapshot_id = root.findtext("snapshotId") volume_id = root.findtext("volumeId") status = root.findtext("status") start_time = root.findtext("startTime") start_time = datetime.strptime( start_time[:19], "%Y-%m-%dT%H:%M:%S") progress = root.findtext("progress")[:-1] progress = float(progress or "0") / 100. return model.Snapshot( snapshot_id, volume_id, status, start_time, progress) def attach_volume(self, xml_bytes): """Parse the XML returned by the C{AttachVolume} function. @param xml_bytes: XML bytes with a C{AttachVolumeResponse} root element. @return: a C{dict} with status and attach_time keys. TODO: volumeId, instanceId, device """ root = XML(xml_bytes) status = root.findtext("status") attach_time = root.findtext("attachTime") attach_time = datetime.strptime( attach_time[:19], "%Y-%m-%dT%H:%M:%S") return {"status": status, "attach_time": attach_time} def describe_keypairs(self, xml_bytes): """Parse the XML returned by the C{DescribeKeyPairs} function. @param xml_bytes: XML bytes with a C{DescribeKeyPairsResponse} root element. @return: a C{list} of L{Keypair}. 
""" results = [] root = XML(xml_bytes) keypairs = root.find("keySet") if keypairs is None: return results for keypair_data in keypairs: key_name = keypair_data.findtext("keyName") key_fingerprint = keypair_data.findtext("keyFingerprint") results.append(model.Keypair(key_name, key_fingerprint)) return results def create_keypair(self, xml_bytes): """Parse the XML returned by the C{CreateKeyPair} function. @param xml_bytes: XML bytes with a C{CreateKeyPairResponse} root element. @return: The L{Keypair} instance created. """ keypair_data = XML(xml_bytes) key_name = keypair_data.findtext("keyName") key_fingerprint = keypair_data.findtext("keyFingerprint") key_material = keypair_data.findtext("keyMaterial") return model.Keypair(key_name, key_fingerprint, key_material) def import_keypair(self, xml_bytes, key_material): """Extract the key name and the fingerprint from the result. TODO: there is no corresponding method in the 2009-11-30 version of the ec2 wsdl. Delete this? """ keypair_data = XML(xml_bytes) key_name = keypair_data.findtext("keyName") key_fingerprint = keypair_data.findtext("keyFingerprint") return model.Keypair(key_name, key_fingerprint, key_material) def allocate_address(self, xml_bytes): """Parse the XML returned by the C{AllocateAddress} function. @param xml_bytes: XML bytes with a C{AllocateAddress} root element. @return: The public ip address as a string. """ address_data = XML(xml_bytes) return address_data.findtext("publicIp") def describe_addresses(self, xml_bytes): """Parse the XML returned by the C{DescribeAddresses} function. @param xml_bytes: XML bytes with a C{DescribeAddressesResponse} root element. @return: a C{list} of L{tuple} of (publicIp, instancId). 
""" results = [] root = XML(xml_bytes) for address_data in root.find("addressesSet"): address = address_data.findtext("publicIp") instance_id = address_data.findtext("instanceId") results.append((address, instance_id)) return results def describe_availability_zones(self, xml_bytes): """Parse the XML returned by the C{DescribeAvailibilityZones} function. @param xml_bytes: XML bytes with a C{DescribeAvailibilityZonesResponse} root element. @return: a C{list} of L{AvailabilityZone}. TODO: regionName, messageSet """ results = [] root = XML(xml_bytes) for zone_data in root.find("availabilityZoneInfo"): zone_name = zone_data.findtext("zoneName") zone_state = zone_data.findtext("zoneState") results.append(model.AvailabilityZone(zone_name, zone_state)) return results class Query(BaseQuery): """A query that may be submitted to EC2.""" timeout = 30 def __init__(self, other_params=None, time_tuple=None, api_version=None, *args, **kwargs): """Create a Query to submit to EC2.""" super(Query, self).__init__(*args, **kwargs) # Currently, txAWS only supports version 2009-11-30 if api_version is None: api_version = version.ec2_api self.params = { "Version": api_version, "SignatureVersion": "2", "Action": self.action, "AWSAccessKeyId": self.creds.access_key, } if other_params is None or "Expires" not in other_params: # Only add a Timestamp parameter, if Expires isn't used, # since both can't be used in the same request. self.params["Timestamp"] = iso8601time(time_tuple) if other_params: self.params.update(other_params) self.signature = Signature(self.creds, self.endpoint, self.params) def sign(self, hash_type="sha256"): """Sign this query using its built in credentials. @param hash_type: if the SignatureVersion is 2, specify the type of hash to use, either "sha1" or "sha256". It defaults to the latter. This prepares it to be sent, and should be done as the last step before submitting the query. Signing is done automatically - this is a public method to facilitate testing. 
""" version = self.params["SignatureVersion"] if version == "2": self.params["SignatureMethod"] = "Hmac%s" % hash_type.upper() self.params["Signature"] = self.signature.compute() def submit(self): """Submit this query. @return: A deferred from get_page """ self.sign() url = self.endpoint.get_uri() method = self.endpoint.method params = self.signature.get_canonical_query_params() headers = {} kwargs = {"method": method} if method == "POST": headers["Content-Type"] = "application/x-www-form-urlencoded" kwargs["postdata"] = params else: url += "?%s" % params if self.endpoint.get_host() != self.endpoint.get_canonical_host(): headers["Host"] = self.endpoint.get_canonical_host() if headers: kwargs["headers"] = headers if self.timeout: kwargs["timeout"] = self.timeout d = self.get_page(url, **kwargs) return d.addErrback(ec2_error_wrapper) class Signature(object): """Compute EC2-compliant signatures for requests. @ivar creds: The L{AWSCredentials} to use to compute the signature. @ivar endpoint: The {AWSServiceEndpoint} to consider. @ivar params: A C{dict} of parameters to consider. They should be byte strings, but unicode strings are supported and will be encoded in UTF-8. 
""" def __init__(self, creds, endpoint, params): """Create a Query to submit to EC2.""" self.creds = creds self.endpoint = endpoint self.params = params def compute(self): """Compute and return the signature according to the given data.""" if "Signature" in self.params: raise RuntimeError("Existing signature in parameters") version = self.params["SignatureVersion"] if version == "1": bytes = self.old_signing_text() hash_type = "sha1" elif version == "2": bytes = self.signing_text() hash_type = self.params["SignatureMethod"][len("Hmac"):].lower() else: raise RuntimeError("Unsupported SignatureVersion: '%s'" % version) return self.creds.sign(bytes, hash_type) def old_signing_text(self): """Return the text needed for signing using SignatureVersion 1.""" result = [] lower_cmp = lambda x, y: cmp(x[0].lower(), y[0].lower()) for key, value in sorted(self.params.items(), cmp=lower_cmp): result.append("%s%s" % (key, value)) return "".join(result) def signing_text(self): """Return the text to be signed when signing the query.""" result = "%s\n%s\n%s\n%s" % (self.endpoint.method, self.endpoint.get_canonical_host(), self.endpoint.path, self.get_canonical_query_params()) return result def get_canonical_query_params(self): """Return the canonical query params (used in signing).""" result = [] for key, value in self.sorted_params(): result.append("%s=%s" % (self.encode(key), self.encode(value))) return "&".join(result) def encode(self, string): """Encode a_string as per the canonicalisation encoding rules. See the AWS dev reference page 186 (2009-11-30 version). @return: a_string encoded. 
""" if isinstance(string, unicode): string = string.encode("utf-8") return quote(string, safe="~") def sorted_params(self): """Return the query parameters sorted appropriately for signing.""" return sorted(self.params.items()) txAWS-0.2.3/txaws/ec2/model.py0000664000175000017500000001325611741311335017511 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Licenced under the txaws licence available at /LICENSE in the txaws source. """EC2 client support.""" class Reservation(object): """An Amazon EC2 Reservation. @attrib reservation_id: Unique ID of the reservation. @attrib owner_id: AWS Access Key ID of the user who owns the reservation. @attrib groups: A list of security groups. """ def __init__(self, reservation_id, owner_id, groups=None): self.reservation_id = reservation_id self.owner_id = owner_id self.groups = groups or [] class Instance(object): """An Amazon EC2 Instance. @attrib instance_id: The instance ID of this instance. @attrib instance_state: The current state of this instance. @attrib instance_type: The instance type. @attrib image_id: Image ID of the AMI used to launch the instance. @attrib private_dns_name: The private DNS name assigned to the instance. This DNS name can only be used inside the Amazon EC2 network. This element remains empty until the instance enters a running state. @attrib dns_name: The public DNS name assigned to the instance. This DNS name is contactable from outside the Amazon EC2 network. This element remains empty until the instance enters a running state. @attrib private_ip_address: The private IP address assigned to the instance. @attrib ip_address: The IP address of the instance. @attrib key_name: If this instance was launched with an associated key pair, this displays the key pair name. @attrib ami_launch_index: The AMI launch index, which can be used to find this instance within the launch group. 
@attrib product_codes: Product codes attached to this instance. @attrib launch_time: The time the instance launched. @attrib placement: The location where the instance launched. @attrib kernel_id: Optional. Kernel associated with this instance. @attrib ramdisk_id: Optional. RAM disk associated with this instance. """ def __init__(self, instance_id, instance_state, instance_type="", image_id="", private_dns_name="", dns_name="", private_ip_address="", ip_address="", key_name="", ami_launch_index="", launch_time="", placement="", product_codes=[], kernel_id=None, ramdisk_id=None, reservation=None): self.instance_id = instance_id self.instance_state = instance_state self.instance_type = instance_type self.image_id = image_id self.private_dns_name = private_dns_name self.dns_name = dns_name self.private_ip_address = private_ip_address self.ip_address = ip_address self.key_name = key_name self.ami_launch_index = ami_launch_index self.launch_time = launch_time self.placement = placement self.product_codes = product_codes self.kernel_id = kernel_id self.ramdisk_id = ramdisk_id self.reservation = reservation class SecurityGroup(object): """An EC2 security group. @ivar owner_id: The AWS access key ID of the owner of this security group. @ivar name: The name of the security group. @ivar description: The description of this security group. @ivar allowed_groups: The sequence of L{UserIDGroupPair} instances for this security group. @ivar allowed_ips: The sequence of L{IPPermission} instances for this security group. 
""" def __init__(self, name, description, owner_id="", groups=None, ips=None): self.name = name self.description = description self.owner_id = owner_id self.allowed_groups = groups or [] self.allowed_ips = ips or [] class UserIDGroupPair(object): """A user ID/group name pair associated with a L{SecurityGroup}.""" def __init__(self, user_id, group_name): self.user_id = user_id self.group_name = group_name class IPPermission(object): """An IP permission associated with a L{SecurityGroup}.""" def __init__(self, ip_protocol, from_port, to_port, cidr_ip): self.ip_protocol = ip_protocol self.from_port = from_port self.to_port = to_port self.cidr_ip = cidr_ip class Volume(object): """An EBS volume instance.""" def __init__(self, id, size, status, create_time, availability_zone, snapshot_id): self.id = id self.size = size self.status = status self.create_time = create_time self.availability_zone = availability_zone self.snapshot_id = snapshot_id self.attachments = [] class Attachment(object): """An attachment of a L{Volume}.""" def __init__(self, instance_id, device, status, attach_time): self.instance_id = instance_id self.device = device self.status = status self.attach_time = attach_time class Snapshot(object): """A snapshot of a L{Volume}.""" def __init__(self, id, volume_id, status, start_time, progress): self.id = id self.volume_id = volume_id self.status = status self.start_time = start_time self.progress = progress class Keypair(object): """A convenience object for holding keypair data.""" def __init__(self, name, fingerprint, material=None): self.name = name self.fingerprint = fingerprint self.material = material class AvailabilityZone(object): """A convenience object for holding availability zone data.""" def __init__(self, name, state): self.name = name self.state = state txAWS-0.2.3/txaws/ec2/tests/0000775000175000017500000000000011741312025017167 5ustar 
oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/ec2/tests/__init__.py0000664000175000017500000000000011741311335021271 0ustar oubiwannoubiwann00000000000000txAWS-0.2.3/txaws/ec2/tests/test_model.py0000664000175000017500000000367411741311335021715 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from txaws.ec2 import model from txaws.testing.base import TXAWSTestCase class SecurityGroupTestCase(TXAWSTestCase): def test_creation_defaults(self): group = model.SecurityGroup("name", "desc") self.assertEquals(group.name, "name") self.assertEquals(group.description, "desc") self.assertEquals(group.owner_id, "") self.assertEquals(group.allowed_groups, []) self.assertEquals(group.allowed_ips, []) def test_creation_all_parameters(self): user = "somegal24" other_groups = [ model.SecurityGroup("other1", "another group 1"), model.SecurityGroup("other2", "another group 2")] user_group_pairs = [ model.UserIDGroupPair(user, other_groups[0].name), model.UserIDGroupPair(user, other_groups[1].name)] ips = [model.IPPermission("tcp", "80", "80", "10.0.1.0/24")] group = model.SecurityGroup( "name", "desc", owner_id="me", groups=user_group_pairs, ips=ips) self.assertEquals(group.name, "name") self.assertEquals(group.description, "desc") self.assertEquals(group.owner_id, "me") self.assertEquals(group.allowed_groups[0].user_id, "somegal24") self.assertEquals(group.allowed_groups[0].group_name, "other1") self.assertEquals(group.allowed_groups[1].user_id, "somegal24") self.assertEquals(group.allowed_groups[1].group_name, "other2") self.assertEquals(group.allowed_ips[0].cidr_ip, "10.0.1.0/24") class UserIDGroupPairTestCase(TXAWSTestCase): def test_creation(self): user_id = "cowboy22" group_name = "Rough Riders" user_group_pair = model.UserIDGroupPair(user_id, group_name) self.assertEquals(user_group_pair.user_id, "cowboy22") self.assertEquals(user_group_pair.group_name, "Rough Riders") 
txAWS-0.2.3/txaws/ec2/tests/test_exception.py0000664000175000017500000000646111741311335022610 0ustar oubiwannoubiwann00000000000000# Copyright (c) 2009 Canonical Ltd # Licenced under the txaws licence available at /LICENSE in the txaws source. from twisted.trial.unittest import TestCase from txaws.ec2.exception import EC2Error from txaws.testing import payload from txaws.util import XML REQUEST_ID = "0ef9fc37-6230-4d81-b2e6-1b36277d4247" class EC2ErrorTestCase(TestCase): def test_set_400_error(self): errorsXML = "12" xml = "%s" % errorsXML error = EC2Error("", 400) error._set_400_error(XML(xml)) self.assertEquals(error.errors[0]["Code"], "1") self.assertEquals(error.errors[0]["Message"], "2") def test_has_error(self): errorsXML = "Code12" xml = "%s" % errorsXML error = EC2Error(xml, 400) self.assertTrue(error.has_error("Code1")) def test_single_error(self): error = EC2Error(payload.sample_ec2_error_message, 400) self.assertEquals(len(error.errors), 1) def test_multiple_errors(self): error = EC2Error(payload.sample_ec2_error_messages, 400) self.assertEquals(len(error.errors), 2) def test_single_error_str(self): error = EC2Error(payload.sample_ec2_error_message, 400) self.assertEquals(str(error), "Error Message: Message for Error.Code") def test_multiple_errors_str(self): error = EC2Error(payload.sample_ec2_error_messages, 400) self.assertEquals(str(error), "Multiple EC2 Errors.") def test_single_error_repr(self): error = EC2Error(payload.sample_ec2_error_message, 400) self.assertEquals( repr(error), "") def test_multiple_errors_repr(self): error = EC2Error(payload.sample_ec2_error_messages, 400) self.assertEquals(repr(error), "") def test_dupliate_keypair_result(self): error = EC2Error(payload.sample_duplicate_keypair_result, 400) self.assertEquals( error.get_error_messages(), "The key pair 'key1' already exists.") def test_dupliate_create_security_group_result(self): error = EC2Error( payload.sample_duplicate_create_security_group_result, 400) self.assertEquals( 
error.get_error_messages(), "The security group 'group1' already exists.") def test_invalid_create_security_group_result(self): error = EC2Error( payload.sample_invalid_create_security_group_result, 400) self.assertEquals( error.get_error_messages(), "Specified group name is a reserved name.") def test_invalid_client_token_id(self): error = EC2Error(payload.sample_invalid_client_token_result, 400) self.assertEquals( error.get_error_messages(), ("The AWS Access Key Id you provided does not exist in our " "records.")) def test_restricted_resource_access_attempt(self): error = EC2Error(payload.sample_restricted_resource_result, 400) self.assertEquals( error.get_error_messages(), "Unauthorized attempt to access restricted resource") txAWS-0.2.3/txaws/ec2/tests/test_client.py0000664000175000017500000024520611741311335022072 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2009 Robert Collins # Copyright (C) 2009 Canonical Ltd # Copyright (C) 2009 Duncan McGreggor # Licenced under the txaws licence available at /LICENSE in the txaws source. 
from datetime import datetime import os from twisted.internet import reactor from twisted.internet.defer import succeed, fail from twisted.internet.error import ConnectionRefusedError from twisted.python.failure import Failure from twisted.python.filepath import FilePath from twisted.web import server, static, util from twisted.web.error import Error as TwistedWebError from twisted.protocols.policies import WrappingFactory from txaws.util import iso8601time from txaws.credentials import AWSCredentials from txaws.ec2 import client from txaws.ec2 import model from txaws.ec2.exception import EC2Error from txaws.service import AWSServiceEndpoint, EC2_ENDPOINT_US from txaws.testing import payload from txaws.testing.base import TXAWSTestCase from txaws.testing.ec2 import FakePageGetter class ReservationTestCase(TXAWSTestCase): def test_reservation_creation(self): reservation = model.Reservation( "id1", "owner", groups=["one", "two"]) self.assertEquals(reservation.reservation_id, "id1") self.assertEquals(reservation.owner_id, "owner") self.assertEquals(reservation.groups, ["one", "two"]) class InstanceTestCase(TXAWSTestCase): def test_instance_creation(self): instance = model.Instance( "id1", "running", "type", "id2", "dns1", "dns2", "ip1", "ip2", "key", "ami", "time", "placement", ["prod1", "prod2"], "id3", "id4") self.assertEquals(instance.instance_id, "id1") self.assertEquals(instance.instance_state, "running") self.assertEquals(instance.instance_type, "type") self.assertEquals(instance.image_id, "id2") self.assertEquals(instance.private_dns_name, "dns1") self.assertEquals(instance.dns_name, "dns2") self.assertEquals(instance.private_ip_address, "ip1") self.assertEquals(instance.ip_address, "ip2") self.assertEquals(instance.key_name, "key") self.assertEquals(instance.ami_launch_index, "ami") self.assertEquals(instance.launch_time, "time") self.assertEquals(instance.placement, "placement") self.assertEquals(instance.product_codes, ["prod1", "prod2"]) 
self.assertEquals(instance.kernel_id, "id3") self.assertEquals(instance.ramdisk_id, "id4") class EC2ClientTestCase(TXAWSTestCase): def test_init_no_creds(self): os.environ["AWS_SECRET_ACCESS_KEY"] = "foo" os.environ["AWS_ACCESS_KEY_ID"] = "bar" ec2 = client.EC2Client() self.assertNotEqual(None, ec2.creds) def test_post_method(self): """ If the method of the endpoint is POST, the parameters are passed in the body. """ self.addCleanup(setattr, client.Query, "get_page", client.Query.get_page) def get_page(query, url, *args, **kwargs): self.assertEquals(args, ()) self.assertEquals( kwargs["headers"], {"Content-Type": "application/x-www-form-urlencoded"}) self.assertIn("postdata", kwargs) self.assertEquals(kwargs["method"], "POST") self.assertEquals(kwargs["timeout"], 30) return succeed(payload.sample_describe_instances_result) client.Query.get_page = get_page creds = AWSCredentials("foo", "bar") endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US, method="POST") ec2 = client.EC2Client(creds=creds, endpoint=endpoint) return ec2.describe_instances() def test_init_no_creds_non_available_errors(self): self.assertRaises(ValueError, client.EC2Client) def test_init_explicit_creds(self): creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds=creds) self.assertEqual(creds, ec2.creds) def test_describe_availability_zones_single(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeAvailabilityZones") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual( other_params, {"ZoneName.1": "us-east-1a"}) def submit(self): return succeed( payload.sample_describe_availability_zones_single_result) def check_parsed_availability_zone(results): self.assertEquals(len(results), 1) [zone] = results self.assertEquals(zone.name, "us-east-1a") self.assertEquals(zone.state, "available") creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, 
query_factory=StubQuery) d = ec2.describe_availability_zones(["us-east-1a"]) d.addCallback(check_parsed_availability_zone) return d def test_describe_availability_zones_multiple(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeAvailabilityZones") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") def submit(self): return succeed( payload. sample_describe_availability_zones_multiple_results) def check_parsed_availability_zones(results): self.assertEquals(len(results), 3) self.assertEquals(results[0].name, "us-east-1a") self.assertEquals(results[0].state, "available") self.assertEquals(results[1].name, "us-east-1b") self.assertEquals(results[1].state, "available") self.assertEquals(results[2].name, "us-east-1c") self.assertEquals(results[2].state, "available") creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_availability_zones() d.addCallback(check_parsed_availability_zones) return d class EC2ClientInstancesTestCase(TXAWSTestCase): def check_parsed_instances(self, results): instance = results[0] # check reservations reservation = instance.reservation self.assertEquals(reservation.reservation_id, "r-cf24b1a6") self.assertEquals(reservation.owner_id, "123456789012") # check groups group = reservation.groups[0] self.assertEquals(group, "default") # check instance self.assertEquals(instance.instance_id, "i-abcdef01") self.assertEquals(instance.instance_state, "running") self.assertEquals(instance.instance_type, "c1.xlarge") self.assertEquals(instance.image_id, "ami-12345678") self.assertEquals( instance.private_dns_name, "domU-12-31-39-03-15-11.compute-1.internal") self.assertEquals( instance.dns_name, "ec2-75-101-245-65.compute-1.amazonaws.com") self.assertEquals(instance.private_ip_address, "10.0.0.1") self.assertEquals(instance.ip_address, "75.101.245.65") 
self.assertEquals(instance.key_name, "keyname") self.assertEquals(instance.ami_launch_index, "0") self.assertEquals(instance.launch_time, "2009-04-27T02:23:18.000Z") self.assertEquals(instance.placement, "us-east-1c") self.assertEquals(instance.product_codes, ["774F4FF8"]) self.assertEquals(instance.kernel_id, "aki-b51cf9dc") self.assertEquals(instance.ramdisk_id, "ari-b31cf9da") def check_parsed_instances_required(self, results): instance = results[0] # check reservations reservation = instance.reservation self.assertEquals(reservation.reservation_id, "r-cf24b1a6") self.assertEquals(reservation.owner_id, "123456789012") # check groups group = reservation.groups[0] self.assertEquals(group, "default") # check instance self.assertEquals(instance.instance_id, "i-abcdef01") self.assertEquals(instance.instance_state, "running") self.assertEquals(instance.instance_type, "c1.xlarge") self.assertEquals(instance.image_id, "ami-12345678") self.assertEquals( instance.private_dns_name, "domU-12-31-39-03-15-11.compute-1.internal") self.assertEquals( instance.dns_name, "ec2-75-101-245-65.compute-1.amazonaws.com") self.assertEquals(instance.private_ip_address, "10.0.0.1") self.assertEquals(instance.ip_address, "75.101.245.65") self.assertEquals(instance.key_name, None) self.assertEquals(instance.ami_launch_index, None) self.assertEquals(instance.launch_time, "2009-04-27T02:23:18.000Z") self.assertEquals(instance.placement, "us-east-1c") self.assertEquals(instance.product_codes, []) self.assertEquals(instance.kernel_id, None) self.assertEquals(instance.ramdisk_id, None) def test_parse_reservation(self): creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds=creds) results = ec2.parser.describe_instances( payload.sample_describe_instances_result) self.check_parsed_instances(results) def test_describe_instances(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeInstances") 
self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_describe_instances_result) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_instances() d.addCallback(self.check_parsed_instances) return d def test_describe_instances_required(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeInstances") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEquals(other_params, {}) def submit(self): return succeed( payload.sample_required_describe_instances_result) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_instances() d.addCallback(self.check_parsed_instances_required) return d def test_describe_instances_specific_instances(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeInstances") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEquals( other_params, {"InstanceId.1": "i-16546401", "InstanceId.2": "i-49873415"}) def submit(self): return succeed( payload.sample_required_describe_instances_result) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_instances("i-16546401", "i-49873415") d.addCallback(self.check_parsed_instances_required) return d def test_terminate_instances(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "TerminateInstances") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual( other_params, {"InstanceId.1": "i-1234", "InstanceId.2": "i-5678"}) def submit(self): return 
succeed(payload.sample_terminate_instances_result) def check_transition(changes): self.assertEqual([("i-1234", "running", "shutting-down"), ("i-5678", "shutting-down", "shutting-down")], sorted(changes)) creds = AWSCredentials("foo", "bar") endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) ec2 = client.EC2Client(creds=creds, endpoint=endpoint, query_factory=StubQuery) d = ec2.terminate_instances("i-1234", "i-5678") d.addCallback(check_transition) return d def check_parsed_run_instances(self, results): instance = results[0] # check reservations reservation = instance.reservation self.assertEquals(reservation.reservation_id, "r-47a5402e") self.assertEquals(reservation.owner_id, "495219933132") # check groups group = reservation.groups[0] self.assertEquals(group, "default") # check instance self.assertEquals(instance.instance_id, "i-2ba64342") self.assertEquals(instance.instance_state, "pending") self.assertEquals(instance.instance_type, "m1.small") self.assertEquals(instance.placement, "us-east-1b") instance = results[1] self.assertEquals(instance.instance_id, "i-2bc64242") self.assertEquals(instance.instance_state, "pending") self.assertEquals(instance.instance_type, "m1.small") self.assertEquals(instance.placement, "us-east-1b") instance = results[2] self.assertEquals(instance.instance_id, "i-2be64332") self.assertEquals(instance.instance_state, "pending") self.assertEquals(instance.instance_type, "m1.small") self.assertEquals(instance.placement, "us-east-1b") def test_run_instances(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "RunInstances") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEquals( other_params, {"ImageId": "ami-1234", "MaxCount": "2", "MinCount": "1", "SecurityGroup.1": u"group1", "KeyName": u"default", "UserData": "Zm9v", "InstanceType": u"m1.small", "Placement.AvailabilityZone": u"us-east-1b", "KernelId": 
u"k-1234", "RamdiskId": u"r-1234"}) def submit(self): return succeed( payload.sample_run_instances_result) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.run_instances("ami-1234", 1, 2, security_groups=[u"group1"], key_name=u"default", user_data=u"foo", instance_type=u"m1.small", availability_zone=u"us-east-1b", kernel_id=u"k-1234", ramdisk_id=u"r-1234") d.addCallback(self.check_parsed_run_instances) class EC2ClientSecurityGroupsTestCase(TXAWSTestCase): def test_describe_security_groups(self): """ L{EC2Client.describe_security_groups} returns a C{Deferred} that eventually fires with a list of L{SecurityGroup} instances created using XML data received from the cloud. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSecurityGroups") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, {}) def submit(self): return succeed(payload.sample_describe_security_groups_result) def check_results(security_groups): [security_group] = security_groups self.assertEquals(security_group.owner_id, "UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM") self.assertEquals(security_group.name, "WebServers") self.assertEquals(security_group.description, "Web Servers") self.assertEquals(security_group.allowed_groups, []) self.assertEquals( [(ip.ip_protocol, ip.from_port, ip.to_port, ip.cidr_ip) for ip in security_group.allowed_ips], [("tcp", 80, 80, "0.0.0.0/0")]) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_security_groups() return d.addCallback(check_results) def test_describe_security_groups_with_multiple_results(self): """ The C{DescribeSecurityGroupsResponse} XML payload retrieved when L{EC2Client.describe_security_groups} is called can contain information about more than one L{SecurityGroup}. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSecurityGroups") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, {}) def submit(self): return succeed( payload.sample_describe_security_groups_multiple_result) def check_results(security_groups): self.assertEquals(len(security_groups), 2) security_group = security_groups[0] self.assertEquals(security_group.owner_id, "UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM") self.assertEquals(security_group.name, "MessageServers") self.assertEquals(security_group.description, "Message Servers") self.assertEquals(security_group.allowed_groups, []) self.assertEquals( [(ip.ip_protocol, ip.from_port, ip.to_port, ip.cidr_ip) for ip in security_group.allowed_ips], [("tcp", 80, 80, "0.0.0.0/0")]) security_group = security_groups[1] self.assertEquals(security_group.owner_id, "UYY3TLBUXIEON5NQVUUX6OMPWBZIQNFM") self.assertEquals(security_group.name, "WebServers") self.assertEquals(security_group.description, "Web Servers") self.assertEquals([(pair.user_id, pair.group_name) for pair in security_group.allowed_groups], [("group-user-id", "group-name1"), ("group-user-id", "group-name2")]) self.assertEquals( [(ip.ip_protocol, ip.from_port, ip.to_port, ip.cidr_ip) for ip in security_group.allowed_ips], [("tcp", 80, 80, "0.0.0.0/0")]) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_security_groups() return d.addCallback(check_results) def test_describe_security_groups_with_multiple_groups(self): """ Several groups can be contained in a single ip permissions content, and there are recognized by the group parser. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSecurityGroups") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, {}) def submit(self): return succeed( payload.sample_describe_security_groups_multiple_groups) def check_results(security_groups): self.assertEquals(len(security_groups), 1) security_group = security_groups[0] self.assertEquals(security_group.name, "web/ssh") self.assertEquals([(pair.user_id, pair.group_name) for pair in security_group.allowed_groups], [("170723411662", "default"), ("175723011368", "test1")]) self.assertEquals( [(ip.ip_protocol, ip.from_port, ip.to_port, ip.cidr_ip) for ip in security_group.allowed_ips], [('tcp', 22, 22, '0.0.0.0/0'), ("tcp", 80, 80, "0.0.0.0/0")]) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_security_groups() return d.addCallback(check_results) def test_describe_security_groups_with_name(self): """ L{EC2Client.describe_security_groups} optionally takes a list of security group names to limit results to. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSecurityGroups") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, {"GroupName.1": "WebServers"}) def submit(self): return succeed(payload.sample_describe_security_groups_result) def check_result(security_groups): [security_group] = security_groups self.assertEquals(security_group.name, "WebServers") creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_security_groups("WebServers") return d.addCallback(check_result) def test_describe_security_groups_with_openstack(self): """ L{EC2Client.describe_security_groups} can work with openstack responses, which may lack proper port information for self-referencing group. Verifying that the response doesn't cause an internal error, workaround for nova launchpad bug #829609. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSecurityGroups") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, {"GroupName.1": "WebServers"}) def submit(self): return succeed( payload.sample_describe_security_groups_with_openstack) def check_result(security_groups): [security_group] = security_groups self.assertEquals(security_group.name, "WebServers") self.assertEqual( security_group.allowed_groups[0].group_name, "WebServers") creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.describe_security_groups("WebServers") return d.addCallback(check_result) def test_create_security_group(self): """ L{EC2Client.create_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "CreateSecurityGroup") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "GroupDescription": "The group for the web server farm.", }) def submit(self): return succeed(payload.sample_create_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.create_security_group( "WebServers", "The group for the web server farm.") return self.assertTrue(d) def test_delete_security_group(self): """ L{EC2Client.delete_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteSecurityGroup") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", }) def submit(self): return succeed(payload.sample_delete_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.delete_security_group("WebServers") return self.assertTrue(d) def test_delete_security_group_failure(self): """ L{EC2Client.delete_security_group} returns a C{Deferred} that eventually fires with a failure when EC2 is asked to delete a group that another group uses in that other group's policy. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteSecurityGroup") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "GroupReferredTo", }) def submit(self): error = EC2Error( payload.sample_delete_security_group_failure, 400) return fail(error) def check_error(error): self.assertEquals( str(error), ("Error Message: Group groupID1:GroupReferredTo is used by " "groups: groupID2:UsingGroup")) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) failure = ec2.delete_security_group("GroupReferredTo") d = self.assertFailure(failure, EC2Error) return d.addCallback(check_error) def test_authorize_security_group_with_user_group_pair(self): """ L{EC2Client.authorize_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. There are two ways to use the method: set another group's IP permissions or set new IP permissions; this test checks the first way. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AuthorizeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "SourceSecurityGroupName": "AppServers", "SourceSecurityGroupOwnerId": "123456789123", }) def submit(self): return succeed(payload.sample_authorize_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.authorize_security_group( "WebServers", source_group_name="AppServers", source_group_owner_id="123456789123") return self.assertTrue(d) def test_authorize_security_group_with_ip_permissions(self): """ L{EC2Client.authorize_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. There are two ways to use the method: set another group's IP permissions or set new IP permissions; this test checks the second way. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AuthorizeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "FromPort": "22", "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", }) def submit(self): return succeed(payload.sample_authorize_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.authorize_security_group( "WebServers", ip_protocol="tcp", from_port="22", to_port="80", cidr_ip="0.0.0.0/0") return self.assertTrue(d) def test_authorize_security_group_with_missing_parameters(self): """ L{EC2Client.authorize_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. 
There are two ways to use the method: set another group's IP permissions or set new IP permissions. If not all group-setting parameters are set and not all IP permission parameters are set, an error is raised. """ creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds) self.assertRaises(ValueError, ec2.authorize_security_group, "WebServers", ip_protocol="tcp", from_port="22") try: ec2.authorize_security_group( "WebServers", ip_protocol="tcp", from_port="22") except Exception, error: self.assertEquals( str(error), ("You must specify either both group parameters or all the " "ip parameters.")) def test_authorize_group_permission(self): """ L{EC2Client.authorize_group_permission} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AuthorizeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "SourceSecurityGroupName": "AppServers", "SourceSecurityGroupOwnerId": "123456789123", }) def submit(self): return succeed(payload.sample_authorize_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.authorize_group_permission( "WebServers", source_group_name="AppServers", source_group_owner_id="123456789123") return self.assertTrue(d) def test_authorize_ip_permission(self): """ L{EC2Client.authorize_ip_permission} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AuthorizeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "FromPort": "22", "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", }) def submit(self): return succeed(payload.sample_authorize_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.authorize_ip_permission( "WebServers", ip_protocol="tcp", from_port="22", to_port="80", cidr_ip="0.0.0.0/0") return self.assertTrue(d) def test_revoke_security_group_with_user_group_pair(self): """ L{EC2Client.revoke_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. There are two ways to use the method: set another group's IP permissions or set new IP permissions; this test checks the first way. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "RevokeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "SourceSecurityGroupName": "AppServers", "SourceSecurityGroupOwnerId": "123456789123", }) def submit(self): return succeed(payload.sample_revoke_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.revoke_security_group( "WebServers", source_group_name="AppServers", source_group_owner_id="123456789123") return self.assertTrue(d) def test_revoke_security_group_with_ip_permissions(self): """ L{EC2Client.revoke_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. 
There are two ways to use the method: set another group's IP permissions or set new IP permissions; this test checks the second way. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "RevokeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "FromPort": "22", "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", }) def submit(self): return succeed(payload.sample_revoke_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.revoke_security_group( "WebServers", ip_protocol="tcp", from_port="22", to_port="80", cidr_ip="0.0.0.0/0") return self.assertTrue(d) def test_revoke_security_group_with_missing_parameters(self): """ L{EC2Client.revoke_security_group} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. There are two ways to use the method: set another group's IP permissions or set new IP permissions. If not all group-setting parameters are set and not all IP permission parameters are set, an error is raised. """ creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds) self.assertRaises(ValueError, ec2.authorize_security_group, "WebServers", ip_protocol="tcp", from_port="22") try: ec2.authorize_security_group( "WebServers", ip_protocol="tcp", from_port="22") except Exception, error: self.assertEquals( str(error), ("You must specify either both group parameters or all the " "ip parameters.")) def test_revoke_group_permission(self): """ L{EC2Client.revoke_group_permission} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. 
""" class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "RevokeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "SourceSecurityGroupName": "AppServers", "SourceSecurityGroupOwnerId": "123456789123", }) def submit(self): return succeed(payload.sample_revoke_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.revoke_group_permission( "WebServers", source_group_name="AppServers", source_group_owner_id="123456789123") return self.assertTrue(d) def test_revoke_ip_permission(self): """ L{EC2Client.revoke_ip_permission} returns a C{Deferred} that eventually fires with a true value, indicating the success of the operation. """ class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "RevokeSecurityGroupIngress") self.assertEqual(creds.access_key, "foo") self.assertEqual(creds.secret_key, "bar") self.assertEqual(other_params, { "GroupName": "WebServers", "FromPort": "22", "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", }) def submit(self): return succeed(payload.sample_revoke_security_group) creds = AWSCredentials("foo", "bar") ec2 = client.EC2Client(creds, query_factory=StubQuery) d = ec2.revoke_ip_permission( "WebServers", ip_protocol="tcp", from_port="22", to_port="80", cidr_ip="0.0.0.0/0") return self.assertTrue(d) class EC2ClientEBSTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials("foo", "bar") self.endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) def check_parsed_volumes(self, volumes): self.assertEquals(len(volumes), 1) volume = volumes[0] self.assertEquals(volume.id, "vol-4282672b") self.assertEquals(volume.size, 800) self.assertEquals(volume.status, "in-use") 
self.assertEquals(volume.availability_zone, "us-east-1a") self.assertEquals(volume.snapshot_id, "snap-12345678") create_time = datetime(2008, 05, 07, 11, 51, 50) self.assertEquals(volume.create_time, create_time) self.assertEquals(len(volume.attachments), 1) attachment = volume.attachments[0] self.assertEquals(attachment.instance_id, "i-6058a509") self.assertEquals(attachment.status, "attached") self.assertEquals(attachment.device, u"/dev/sdh") attach_time = datetime(2008, 05, 07, 12, 51, 50) self.assertEquals(attachment.attach_time, attach_time) def test_describe_volumes(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeVolumes") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_describe_volumes_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_volumes() d.addCallback(self.check_parsed_volumes) return d def test_describe_specified_volumes(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeVolumes") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals( other_params, {"VolumeId.1": "vol-4282672b"}) def submit(self): return succeed(payload.sample_describe_volumes_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_volumes("vol-4282672b") d.addCallback(self.check_parsed_volumes) return d def check_parsed_snapshots(self, snapshots): self.assertEquals(len(snapshots), 1) snapshot = snapshots[0] self.assertEquals(snapshot.id, "snap-78a54011") self.assertEquals(snapshot.volume_id, "vol-4d826724") self.assertEquals(snapshot.status, "pending") start_time = datetime(2008, 05, 07, 12, 51, 50) 
self.assertEquals(snapshot.start_time, start_time) self.assertEquals(snapshot.progress, 0.8) def test_describe_snapshots(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSnapshots") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_describe_snapshots_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_snapshots() d.addCallback(self.check_parsed_snapshots) return d def test_describe_specified_snapshots(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeSnapshots") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals( other_params, {"SnapshotId.1": "snap-78a54011"}) def submit(self): return succeed(payload.sample_describe_snapshots_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_snapshots("snap-78a54011") d.addCallback(self.check_parsed_snapshots) return d def test_create_volume(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "CreateVolume") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"AvailabilityZone": "us-east-1", "Size": "800"}) def submit(self): return succeed(payload.sample_create_volume_result) def check_parsed_volume(volume): self.assertEquals(volume.id, "vol-4d826724") self.assertEquals(volume.size, 800) self.assertEquals(volume.snapshot_id, "") create_time = datetime(2008, 05, 07, 11, 51, 50) self.assertEquals(volume.create_time, create_time) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = 
ec2.create_volume("us-east-1", size=800) d.addCallback(check_parsed_volume) return d def test_create_volume_with_snapshot(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "CreateVolume") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"AvailabilityZone": "us-east-1", "SnapshotId": "snap-12345678"}) def submit(self): return succeed(payload.sample_create_volume_result) def check_parsed_volume(volume): self.assertEquals(volume.id, "vol-4d826724") self.assertEquals(volume.size, 800) create_time = datetime(2008, 05, 07, 11, 51, 50) self.assertEquals(volume.create_time, create_time) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.create_volume("us-east-1", snapshot_id="snap-12345678") d.addCallback(check_parsed_volume) return d def test_create_volume_no_params(self): ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint) error = self.assertRaises(ValueError, ec2.create_volume, "us-east-1") self.assertEquals( str(error), "Please provide either size or snapshot_id") def test_create_volume_both_params(self): ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint) error = self.assertRaises(ValueError, ec2.create_volume, "us-east-1", size=800, snapshot_id="snap-12345678") self.assertEquals( str(error), "Please provide either size or snapshot_id") def test_delete_volume(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteVolume") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"VolumeId": "vol-4282672b"}) def submit(self): return succeed(payload.sample_delete_volume_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.delete_volume("vol-4282672b") 
d.addCallback(self.assertEquals, True) return d def test_create_snapshot(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "CreateSnapshot") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"VolumeId": "vol-4d826724"}) def submit(self): return succeed(payload.sample_create_snapshot_result) def check_parsed_snapshot(snapshot): self.assertEquals(snapshot.id, "snap-78a54011") self.assertEquals(snapshot.volume_id, "vol-4d826724") self.assertEquals(snapshot.status, "pending") start_time = datetime(2008, 05, 07, 12, 51, 50) self.assertEquals(snapshot.start_time, start_time) self.assertEquals(snapshot.progress, 0) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.create_snapshot("vol-4d826724") d.addCallback(check_parsed_snapshot) return d def test_delete_snapshot(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteSnapshot") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"SnapshotId": "snap-78a54011"}) def submit(self): return succeed(payload.sample_delete_snapshot_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.delete_snapshot("snap-78a54011") d.addCallback(self.assertEquals, True) return d def test_attach_volume(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AttachVolume") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEqual( other_params, {"VolumeId": "vol-4d826724", "InstanceId": "i-6058a509", "Device": "/dev/sdh"}) def submit(self): return succeed(payload.sample_attach_volume_result) def check_parsed_response(response): self.assertEquals( 
response, {"status": "attaching", "attach_time": datetime(2008, 05, 07, 11, 51, 50)}) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.attach_volume("vol-4d826724", "i-6058a509", "/dev/sdh") d.addCallback(check_parsed_response) return d def check_parsed_keypairs(self, results): self.assertEquals(len(results), 1) keypair = results[0] self.assertEquals(keypair.name, "gsg-keypair") self.assertEquals( keypair.fingerprint, "1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:ca:9f:f5:f1:6f") def test_single_describe_keypairs(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeKeyPairs") self.assertEqual("foo", creds) self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_single_describe_keypairs_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.describe_keypairs() d.addCallback(self.check_parsed_keypairs) return d def test_multiple_describe_keypairs(self): def check_parsed_keypairs(results): self.assertEquals(len(results), 2) keypair1, keypair2 = results self.assertEquals(keypair1.name, "gsg-keypair-1") self.assertEquals( keypair1.fingerprint, "1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:ca:9f:f5:f1:6f") self.assertEquals(keypair2.name, "gsg-keypair-2") self.assertEquals( keypair2.fingerprint, "1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:ca:9f:f5:f1:70") class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeKeyPairs") self.assertEqual("foo", creds) self.assertEquals(other_params, {}) def submit(self): return succeed( payload.sample_multiple_describe_keypairs_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.describe_keypairs() d.addCallback(check_parsed_keypairs) return d def test_describe_specified_keypairs(self): class StubQuery(object): def __init__(stub, action="", 
creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeKeyPairs") self.assertEqual("foo", creds) self.assertEquals( other_params, {"KeyName.1": "gsg-keypair"}) def submit(self): return succeed(payload.sample_single_describe_keypairs_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.describe_keypairs("gsg-keypair") d.addCallback(self.check_parsed_keypairs) return d def test_create_keypair(self): def check_parsed_create_keypair(keypair): self.assertEquals(keypair.name, "example-key-name") self.assertEquals( keypair.fingerprint, "1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:ca:9f:f5:f1:6f") self.assertTrue(keypair.material.startswith( "-----BEGIN RSA PRIVATE KEY-----")) self.assertTrue(keypair.material.endswith( "-----END RSA PRIVATE KEY-----")) self.assertEquals(len(keypair.material), 1670) class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "CreateKeyPair") self.assertEqual("foo", creds) self.assertEquals( other_params, {"KeyName": "example-key-name"}) def submit(self): return succeed(payload.sample_create_keypair_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.create_keypair("example-key-name") d.addCallback(check_parsed_create_keypair) return d def test_delete_keypair_true_result(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteKeyPair") self.assertEqual("foo", creds) self.assertEqual("http:///", endpoint.get_uri()) self.assertEquals( other_params, {"KeyName": "example-key-name"}) def submit(self): return succeed(payload.sample_delete_keypair_true_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.delete_keypair("example-key-name") d.addCallback(self.assertTrue) return d def test_delete_keypair_false_result(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, 
other_params={}): self.assertEqual(action, "DeleteKeyPair") self.assertEqual("foo", creds) self.assertEqual("http:///", endpoint.get_uri()) self.assertEquals( other_params, {"KeyName": "example-key-name"}) def submit(self): return succeed(payload.sample_delete_keypair_false_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.delete_keypair("example-key-name") d.addCallback(self.assertFalse) return d def test_delete_keypair_no_result(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DeleteKeyPair") self.assertEqual("foo", creds) self.assertEqual("http:///", endpoint.get_uri()) self.assertEquals( other_params, {"KeyName": "example-key-name"}) def submit(self): return succeed(payload.sample_delete_keypair_no_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) d = ec2.delete_keypair("example-key-name") d.addCallback(self.assertFalse) return d def test_import_keypair(self): """ L{client.EC2Client.import_keypair} calls the C{ImportKeyPair} method with the given arguments, encoding the key material in base64, and returns a C{Keypair} instance. 
""" def check_parsed_import_keypair(keypair): self.assertEquals(keypair.name, "example-key-name") self.assertEquals( keypair.fingerprint, "1f:51:ae:28:bf:89:e9:d8:1f:25:5d:37:2d:7d:b8:ca:9f:f5:f1:6f") self.assertEquals(keypair.material, material) class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "ImportKeyPair") self.assertEqual("foo", creds) self.assertEquals( other_params, {"KeyName": "example-key-name", "PublicKeyMaterial": "c3NoLWRzcyBBQUFBQjNOemFDMWtjM01BQUFDQkFQNmFjakFQeitUR" "jJkREtmZGlhcnp2cXBBcjhlbUl6UElBWUp6QXNoTFgvUTJCZ2tWc0" "42eGI2QUlIUGE1MUFtWXVieU5PYjMxeVhWS2FRQTF6L213SHZtRld" "LQ1ZFQ0wwPSkgdXNlckBob3N0"}) def submit(self): return succeed(payload.sample_import_keypair_result) ec2 = client.EC2Client(creds="foo", query_factory=StubQuery) material = ( "ssh-dss AAAAB3NzaC1kc3MAAACBAP6acjAPz+TF2dDKfdiarzvqpAr8emIzPIAY" "JzAshLX/Q2BgkVsN6xb6AIHPa51AmYubyNOb31yXVKaQA1z/mwHvmFWKCVECL0=)" " user@host") d = ec2.import_keypair("example-key-name", material) d.addCallback(check_parsed_import_keypair) return d class EC2ErrorWrapperTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) def make_failure(self, status=None, type=None, message="", response=""): if type == TwistedWebError: error = type(status) elif message: error = type(message) else: error = type() failure = Failure(error) if not response: response = payload.sample_ec2_error_message failure.value.response = response failure.value.status = status return failure def test_302_error(self): failure = self.make_failure(302, Exception, "found") error = self.assertRaises(Exception, client.ec2_error_wrapper, failure) self.assertEquals(failure.type, type(error)) self.assertFalse(isinstance(error, EC2Error)) self.assertTrue(isinstance(error, Exception)) self.assertEquals(str(error), "found") def test_400_error(self): failure = self.make_failure(400, TwistedWebError) error = self.assertRaises(EC2Error, client.ec2_error_wrapper, 
failure) self.assertNotEquals(failure.type, type(error)) self.assertTrue(isinstance(error, EC2Error)) self.assertEquals(error.get_error_codes(), "Error.Code") self.assertEquals(error.get_error_messages(), "Message for Error.Code") def test_404_error(self): failure = self.make_failure(404, TwistedWebError) error = self.assertRaises(EC2Error, client.ec2_error_wrapper, failure) self.assertNotEquals(failure.type, type(error)) self.assertTrue(isinstance(error, EC2Error)) self.assertEquals(error.get_error_codes(), "Error.Code") self.assertEquals(error.get_error_messages(), "Message for Error.Code") def test_non_EC2_404_error(self): """ The error wrapper should handle cases where an endpoint returns a non-EC2 404. """ some_html = "404" failure = self.make_failure(404, TwistedWebError, "not found", some_html) error = self.assertRaises( TwistedWebError, client.ec2_error_wrapper, failure) self.assertTrue(isinstance(error, TwistedWebError)) self.assertEquals(error.status, 404) self.assertEquals(str(error), "404 Not Found") def test_500_error(self): failure = self.make_failure( 500, type=TwistedWebError, response=payload.sample_server_internal_error_result) error = self.assertRaises(EC2Error, client.ec2_error_wrapper, failure) self.assertTrue(isinstance(error, EC2Error)) self.assertEquals(error.get_error_codes(), "InternalError") self.assertEquals( error.get_error_messages(), "We encountered an internal error. 
Please try again.") self.assertEquals(error.request_id, "A2A7E5395E27DFBB") self.assertEquals( error.host_id, "f691zulHNsUqonsZkjhILnvWwD3ZnmOM4ObM1wXTc6xuS3GzPmjArp8QC/sGsn6K") def test_non_EC2_500_error(self): failure = self.make_failure(500, Exception, "A server error occurred") error = self.assertRaises(Exception, client.ec2_error_wrapper, failure) self.assertFalse(isinstance(error, EC2Error)) self.assertTrue(isinstance(error, Exception)) self.assertEquals(str(error), "A server error occurred") def test_timeout_error(self): failure = self.make_failure(type=Exception, message="timeout") error = self.assertRaises(Exception, client.ec2_error_wrapper, failure) self.assertFalse(isinstance(error, EC2Error)) self.assertTrue(isinstance(error, Exception)) self.assertEquals(str(error), "timeout") def test_connection_error(self): failure = self.make_failure(type=ConnectionRefusedError) error = self.assertRaises(ConnectionRefusedError, client.ec2_error_wrapper, failure) self.assertFalse(isinstance(error, EC2Error)) self.assertTrue(isinstance(error, ConnectionRefusedError)) def test_response_parse_error(self): bad_payload = "" failure = self.make_failure(400, type=TwistedWebError, response=bad_payload) error = self.assertRaises(Exception, client.ec2_error_wrapper, failure) self.assertEquals(str(error), "400 Bad Request") class QueryTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials("foo", "bar") self.endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) def test_init_minimum(self): query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint) self.assertTrue("Timestamp" in query.params) del query.params["Timestamp"] self.assertEqual( query.params, {"AWSAccessKeyId": "foo", "Action": "DescribeInstances", "SignatureVersion": "2", "Version": "2009-11-30"}) def test_init_other_args_are_params(self): query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, 
other_params={"InstanceId.0": "12345"}, time_tuple=(2007, 11, 12, 13, 14, 15, 0, 0, 0)) self.assertEqual( query.params, {"AWSAccessKeyId": "foo", "Action": "DescribeInstances", "InstanceId.0": "12345", "SignatureVersion": "2", "Timestamp": "2007-11-12T13:14:15Z", "Version": "2009-11-30"}) def test_no_timestamp_if_expires_in_other_params(self): """ If Expires is present in other_params, Timestamp won't be added, since a request should contain either Expires or Timestamp, but not both. """ query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, other_params={"Expires": "2007-11-12T13:14:15Z"}) self.assertEqual( query.params, {"AWSAccessKeyId": "foo", "Action": "DescribeInstances", "SignatureVersion": "2", "Expires": "2007-11-12T13:14:15Z", "Version": "2009-11-30"}) def test_sign(self): query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, time_tuple=(2007, 11, 12, 13, 14, 15, 0, 0, 0)) query.sign() self.assertEqual("G4c2NtQaFNhWWT8EWPVIIOpHVr0mGUYwJVYss9krsMU=", query.params["Signature"]) def test_old_sign(self): query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, time_tuple=(2007, 11, 12, 13, 14, 15, 0, 0, 0), other_params={"SignatureVersion": "1"}) query.sign() self.assertEqual( "9xP+PIs/3QXW+4mWX6WGR4nGqfE=", query.params["Signature"]) def test_unsupported_sign(self): query = client.Query( action="DescribeInstances", creds=self.creds, endpoint=self.endpoint, time_tuple=(2007, 11, 12, 13, 14, 15, 0, 0, 0), other_params={"SignatureVersion": "0"}) self.assertRaises(RuntimeError, query.sign) def test_submit_with_port(self): """ If the endpoint port differs from the default one, the Host header of the request will include it. 
""" self.addCleanup(setattr, client.Query, "get_page", client.Query.get_page) def get_page(query, url, **kwargs): self.assertEqual("example.com:99", kwargs["headers"]["Host"]) return succeed(None) client.Query.get_page = get_page endpoint = AWSServiceEndpoint(uri="http://example.com:99/foo") query = client.Query(action="SomeQuery", creds=self.creds, endpoint=endpoint) d = query.submit() return d def test_submit_400(self): """A 4xx response status from EC2 should raise a txAWS EC2Error.""" status = 400 self.addCleanup(setattr, client.Query, "get_page", client.Query.get_page) fake_page_getter = FakePageGetter( status, payload.sample_ec2_error_message) client.Query.get_page = fake_page_getter.get_page_with_exception def check_error(error): self.assertTrue(isinstance(error, EC2Error)) self.assertEquals(error.get_error_codes(), "Error.Code") self.assertEquals( error.get_error_messages(), "Message for Error.Code") self.assertEquals(error.status, status) self.assertEquals(error.response, payload.sample_ec2_error_message) query = client.Query( action='BadQuery', creds=self.creds, endpoint=self.endpoint, time_tuple=(2009, 8, 15, 13, 14, 15, 0, 0, 0)) failure = query.submit() d = self.assertFailure(failure, TwistedWebError) d.addCallback(check_error) return d def test_submit_non_EC2_400(self): """ A 4xx response status from a non-EC2 compatible service should raise a Twisted web error. 
""" status = 400 self.addCleanup(setattr, client.Query, "get_page", client.Query.get_page) fake_page_getter = FakePageGetter( status, payload.sample_ec2_error_message) client.Query.get_page = fake_page_getter.get_page_with_exception def check_error(error): self.assertTrue(isinstance(error, TwistedWebError)) self.assertEquals(error.status, status) query = client.Query( action='BadQuery', creds=self.creds, endpoint=self.endpoint, time_tuple=(2009, 8, 15, 13, 14, 15, 0, 0, 0)) failure = query.submit() d = self.assertFailure(failure, TwistedWebError) d.addCallback(check_error) return d def test_submit_500(self): """ A 5xx response status from EC2 should raise the original Twisted exception. """ status = 500 self.addCleanup(setattr, client.Query, "get_page", client.Query.get_page) fake_page_getter = FakePageGetter( status, payload.sample_server_internal_error_result) client.Query.get_page = fake_page_getter.get_page_with_exception def check_error(error): self.assertTrue(isinstance(error, EC2Error)) self.assertEquals(error.status, status) self.assertEquals(error.get_error_codes(), "InternalError") self.assertEquals( error.get_error_messages(), "We encountered an internal error. 
Please try again.") query = client.Query( action='BadQuery', creds=self.creds, endpoint=self.endpoint, time_tuple=(2009, 8, 15, 13, 14, 15, 0, 0, 0)) failure = query.submit() d = self.assertFailure(failure, TwistedWebError) d.addCallback(check_error) return d class SignatureTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials("foo", "bar") self.endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) self.params = {} def test_encode_unreserved(self): all_unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz0123456789-_.~") signature = client.Signature(self.creds, self.endpoint, self.params) self.assertEqual(all_unreserved, signature.encode(all_unreserved)) def test_encode_space(self): """This may be just 'url encode', but the AWS manual isn't clear.""" signature = client.Signature(self.creds, self.endpoint, self.params) self.assertEqual("a%20space", signature.encode("a space")) def test_encode_unicode(self): """ L{Signature.encode} accepts unicode strings and encode them un UTF-8. 
""" signature = client.Signature(self.creds, self.endpoint, self.params) self.assertEqual( "f%C3%A9e", signature.encode(u"f\N{LATIN SMALL LETTER E WITH ACUTE}e")) def test_canonical_query(self): signature = client.Signature(self.creds, self.endpoint, self.params) time_tuple = (2007, 11, 12, 13, 14, 15, 0, 0, 0) self.params.update({"AWSAccessKeyId": "foo", "fu n": "g/ames", "argwithnovalue": "", "SignatureVersion": "2", "Timestamp": iso8601time(time_tuple), "Version": "2009-11-30", "Action": "DescribeInstances", "InstanceId.1": "i-1234"}) expected_params = ("AWSAccessKeyId=foo&Action=DescribeInstances" "&InstanceId.1=i-1234" "&SignatureVersion=2&" "Timestamp=2007-11-12T13%3A14%3A15Z&Version=2009-11-30&" "argwithnovalue=&fu%20n=g%2Fames") self.assertEqual( expected_params, signature.get_canonical_query_params()) def test_signing_text(self): signature = client.Signature(self.creds, self.endpoint, self.params) self.params.update({"AWSAccessKeyId": "foo", "SignatureVersion": "2", "Action": "DescribeInstances"}) signing_text = ("GET\n%s\n/\n" % self.endpoint.host + "AWSAccessKeyId=foo&Action=DescribeInstances&" + "SignatureVersion=2") self.assertEqual(signing_text, signature.signing_text()) def test_signing_text_with_non_default_port(self): """ The signing text uses the canonical host name, which includes the port number, if it differs from the default one. 
""" endpoint = AWSServiceEndpoint(uri="http://example.com:99/path") signature = client.Signature(self.creds, endpoint, self.params) self.params.update({"AWSAccessKeyId": "foo", "SignatureVersion": "2", "Action": "DescribeInstances"}) signing_text = ("GET\n" "example.com:99\n" "/path\n" "AWSAccessKeyId=foo&" "Action=DescribeInstances&" "SignatureVersion=2") self.assertEqual(signing_text, signature.signing_text()) def test_old_signing_text(self): signature = client.Signature(self.creds, self.endpoint, self.params) self.params.update({"AWSAccessKeyId": "foo", "SignatureVersion": "1", "Action": "DescribeInstances"}) signing_text = ( "ActionDescribeInstancesAWSAccessKeyIdfooSignatureVersion1") self.assertEqual(signing_text, signature.old_signing_text()) def test_sorted_params(self): signature = client.Signature(self.creds, self.endpoint, self.params) self.params.update({"AWSAccessKeyId": "foo", "fun": "games", "SignatureVersion": "2", "Version": "2009-11-30", "Action": "DescribeInstances"}) self.assertEqual([ ("AWSAccessKeyId", "foo"), ("Action", "DescribeInstances"), ("SignatureVersion", "2"), ("Version", "2009-11-30"), ("fun", "games"), ], signature.sorted_params()) class QueryPageGetterTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials("foo", "bar") self.endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) self.twisted_client_test_setup() self.cleanupServerConnections = 0 def tearDown(self): """Copied from twisted.web.test.test_webclient.""" # If the test indicated it might leave some server-side connections # around, clean them up. connections = self.wrapper.protocols.keys() # If there are fewer server-side connections than requested, # that's okay. Some might have noticed that the client closed # the connection and cleaned up after themselves. 
for n in range(min(len(connections), self.cleanupServerConnections)): proto = connections.pop() #msg("Closing %r" % (proto,)) proto.transport.loseConnection() if connections: #msg("Some left-over connections; this test is probably buggy.") pass return self.port.stopListening() def _listen(self, site): return reactor.listenTCP(0, site, interface="127.0.0.1") def twisted_client_test_setup(self): name = self.mktemp() os.mkdir(name) FilePath(name).child("file").setContent("0123456789") resource = static.File(name) resource.putChild("redirect", util.Redirect("/file")) self.site = server.Site(resource, timeout=None) self.wrapper = WrappingFactory(self.site) self.port = self._listen(self.wrapper) self.portno = self.port.getHost().port def get_url(self, path): return "http://127.0.0.1:%d/%s" % (self.portno, path) def test_get_page(self): """Copied from twisted.web.test.test_webclient.""" query = client.Query( action="DummyQuery", creds=self.creds, endpoint=self.endpoint, time_tuple=(2009, 8, 15, 13, 14, 15, 0, 0, 0)) deferred = query.get_page(self.get_url("file")) deferred.addCallback(self.assertEquals, "0123456789") return deferred class EC2ClientAddressTestCase(TXAWSTestCase): def setUp(self): TXAWSTestCase.setUp(self) self.creds = AWSCredentials("foo", "bar") self.endpoint = AWSServiceEndpoint(uri=EC2_ENDPOINT_US) def test_describe_addresses(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeAddresses") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_describe_addresses_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_addresses() d.addCallback( self.assertEquals, [("67.202.55.255", "i-28a64341"), ("67.202.55.233", None)]) return d def test_describe_specified_addresses(self): class StubQuery(object): def 
__init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DescribeAddresses") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals( other_params, {"PublicIp.1": "67.202.55.255"}) def submit(self): return succeed(payload.sample_describe_addresses_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.describe_addresses("67.202.55.255") d.addCallback( self.assertEquals, [("67.202.55.255", "i-28a64341"), ("67.202.55.233", None)]) return d def test_associate_address(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AssociateAddress") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals( other_params, {"InstanceId": "i-28a64341", "PublicIp": "67.202.55.255"}) def submit(self): return succeed(payload.sample_associate_address_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.associate_address("i-28a64341", "67.202.55.255") d.addCallback(self.assertTrue) return d def test_allocate_address(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "AllocateAddress") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals(other_params, {}) def submit(self): return succeed(payload.sample_allocate_address_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.allocate_address() d.addCallback(self.assertEquals, "67.202.55.255") return d def test_release_address(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "ReleaseAddress") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) 
self.assertEquals(other_params, {"PublicIp": "67.202.55.255"}) def submit(self): return succeed(payload.sample_release_address_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.release_address("67.202.55.255") d.addCallback(self.assertTrue) return d def test_disassociate_address(self): class StubQuery(object): def __init__(stub, action="", creds=None, endpoint=None, other_params={}): self.assertEqual(action, "DisassociateAddress") self.assertEqual(self.creds, creds) self.assertEqual(self.endpoint, endpoint) self.assertEquals(other_params, {"PublicIp": "67.202.55.255"}) def submit(self): return succeed(payload.sample_disassociate_address_result) ec2 = client.EC2Client(creds=self.creds, endpoint=self.endpoint, query_factory=StubQuery) d = ec2.disassociate_address("67.202.55.255") d.addCallback(self.assertTrue) return d class EC2ParserTestCase(TXAWSTestCase): def setUp(self): self.parser = client.Parser() def test_ec2_terminate_instances(self): """ Given a well formed response from EC2, parse the correct thing. """ ec2_xml = """ d0adc305-7f97-4652-b7c2-6993b2bb8260 i-cab0c1aa 32 shutting-down 16 running """ ec2_response = self.parser.terminate_instances(ec2_xml) self.assertEquals( [('i-cab0c1aa', 'running', 'shutting-down')], ec2_response) def test_nova_terminate_instances(self): """ Ensure parser can handle the somewhat non-standard response from nova Note that the bug has been reported in nova here: https://launchpad.net/bugs/862680 """ nova_xml = ( '' '4fe6643d-2346-4add-adb7-a1f61f37c043' 'true') nova_response = self.parser.terminate_instances(nova_xml) self.assertEquals([], nova_response) txAWS-0.2.3/txaws/wsdl.py0000664000175000017500000005226411741311335016713 0ustar oubiwannoubiwann00000000000000# Copyright (C) 2010-2012 Canonical Ltd. # Licenced under the txaws licence available at /LICENSE in the txaws source. """Parse WSDL definitions and generate schemas. 
To understand how the machinery in this module works, let's consider the following bit of the WSDL definition, that specifies the format for the response of a DescribeKeyPairs query: The L{WSDLParser} will take the above XML input and automatically generate a top-level L{NodeSchema} that can be used to access and modify the XML content of an actual DescribeKeyPairsResponse payload in an easy way. The automatically generated L{NodeSchema} object will be the same as the following manually created one: >>> child1 = LeafSchema('requestId') >>> sub_sub_child1 = LeafSchema('key_name') >>> sub_sub_child2 = LeafSchema('key_fingerprint') >>> sub_child = NodeSchema('item') >>> sub_child.add(sub_sub_child1) >>> sub_child.add(sub_sub_child2) >>> child2 = SequenceSchema('keySet') >>> child2.set(sub_child) >>> schema = NodeSchema('DescribeKeyPairsResponse') >>> schema.add(child1) >>> schema.add(child2) Now this L{NodeSchema} object can be used to access and modify a response XML payload, for example: 3ef0aa1d-57dd-4272 some-key 94:88:29:60:cf Let's assume to have an 'xml' variable that holds the XML payload above, now we can: >>> response = schema.create(etree.fromstring(xml)) >>> response.requestId 3ef0aa1d-57dd-4272 >>> response.keySet[0].keyName some-key >>> response.keySet[0].keyFingerprint 94:88:29:60:cf Note that there is no upfront parsing, the schema just makes sure that the response elements one actually accesses are consistent with the WDSL definition and that all modifications of those items are consistent as well. """ try: from lxml import etree except ImportError: etree = None class WSDLParseError(Exception): """Raised when a response doesn't comply with its schema.""" class LeafSchema(object): """Schema for a single XML leaf element in a response. @param tag: The name of the XML element tag this schema is for. """ def __init__(self, tag): self.tag = tag class NodeSchema(object): """Schema for a single XML inner node in a response. 
A L{Node} can have other L{Node} or L{LeafSchema} objects as children. @param tag: The name of the XML element tag this schema is for. @param _children: Optionally, the schemas for the child nodes, used only by tests. """ reserved = ["return"] def __init__(self, tag, _children=None): self.tag = tag self.children = {} self.children_min_occurs = {} if _children: for child in _children: self.add(child) def create(self, root=None, namespace=None): """Create an inner node element. @param root: The inner C{etree.Element} the item will be rooted at. @result: A L{NodeItem} with the given root, or a new one if none. @raises L{ECResponseError}: If the given C{root} has a bad tag. """ if root is not None: tag = root.tag if root.nsmap: namespace = root.nsmap[None] tag = tag[len(namespace) + 2:] if tag != self.tag: raise WSDLParseError("Expected response with tag '%s', but " "got '%s' instead" % (self.tag, tag)) return NodeItem(self, root, namespace) def dump(self, item): """Return the C{etree.Element} of the given L{NodeItem}. @param item: The L{NodeItem} to dump. """ return item._root def add(self, child, min_occurs=1): """Add a child node. @param child: The schema for the child node. @param min_occurs: The minimum number of times the child node must occur, if C{None} is given the default is 1. """ if not min_occurs in (0, 1): raise RuntimeError("Unexpected min bound for node schema") self.children[child.tag] = child self.children_min_occurs[child.tag] = min_occurs return child class NodeItem(object): """An inner node item in a tree of response elements. @param schema: The L{NodeSchema} this item must comply to. @param root: The C{etree.Element} this item is rooted at, if C{None} a new one will be created. 
""" def __init__(self, schema, root=None, namespace=None): object.__setattr__(self, "_schema", schema) object.__setattr__(self, "_namespace", namespace) if root is None: tag = self._get_namespace_tag(schema.tag) nsmap = None if namespace is not None: nsmap = {None: namespace} root = etree.Element(tag, nsmap=nsmap) object.__setattr__(self, "_root", root) def __getattr__(self, name): """Get the child item with the given C{name}. @raises L{WSDLParseError}: In the following cases: - The given C{name} is not in the schema. - There is more than one element tagged C{name} in the response. - No matching element is found in the response and C{name} is requred. - A required element is present but empty. """ tag = self._get_tag(name) schema = self._get_schema(tag) child = self._find_child(tag) if child is None: if isinstance(schema, LeafSchema): return self._check_value(tag, None) child = self._create_child(tag) if isinstance(schema, LeafSchema): return self._check_value(tag, child.text) return schema.create(child) def __setattr__(self, name, value): """Set the child item with the given C{name} to the given C{value}. Setting a non-leaf child item to C{None} will make it disappear from the tree completely. @raises L{WSDLParseError}: In the following cases: - The given C{name} is not in the schema. - There is more than one element tagged C{name} in the response. - The given value is C{None} and the element is required. - The given C{name} is associated with a non-leaf node, and the given C{value} is not C{None}. - The given C{name} is associated with a required non-leaf and the given C{value} is C{None}. """ tag = self._get_tag(name) schema = self._get_schema(tag) child = self._find_child(tag) if not isinstance(schema, LeafSchema): if value is not None: raise WSDLParseError("Can't set non-leaf tag '%s'" % tag) if isinstance(schema, NodeSchema): # Setting a node child item to None means removing it. 
self._check_value(tag, None) if child is not None: self._root.remove(child) if isinstance(schema, SequenceSchema): # Setting a sequence child item to None means removing all # its children. if child is None: child = self._create_child(tag) for item in child.getchildren(): child.remove(item) return if child is None: child = self._create_child(tag) child.text = self._check_value(tag, value) if child.text is None: self._root.remove(child) def _create_child(self, tag): """Create a new child element with the given tag.""" return etree.SubElement(self._root, self._get_namespace_tag(tag)) def _find_child(self, tag): """Find the child C{etree.Element} with the matching C{tag}. @raises L{WSDLParseError}: If more than one such elements are found. """ tag = self._get_namespace_tag(tag) children = self._root.findall(tag) if len(children) > 1: raise WSDLParseError("Duplicate tag '%s'" % tag) if len(children) == 0: return None return children[0] def _check_value(self, tag, value): """Ensure that the element matching C{tag} can have the given C{value}. @param tag: The tag to consider. @param value: The value to check @return: The unchanged L{value}, if valid. @raises L{WSDLParseError}: If the value is invalid. """ if value is None: if self._schema.children_min_occurs[tag] > 0: raise WSDLParseError("Missing tag '%s'" % tag) return value return value def _get_tag(self, name): """Get the L{NodeItem} attribute name for the given C{tag}.""" if name.endswith("_"): if name[:-1] in self._schema.reserved: return name[:-1] return name def _get_namespace_tag(self, tag): """Return the given C{tag} with the namespace prefix added, if any.""" if self._namespace is not None: tag = "{%s}%s" % (self._namespace, tag) return tag def _get_schema(self, tag): """Return the child schema for the given C{tag}. @raises L{WSDLParseError}: If the tag doesn't belong to the schema. 
""" schema = self._schema.children.get(tag) if not schema: raise WSDLParseError("Unknown tag '%s'" % tag) return schema def to_xml(self): """Convert the response to bare bones XML.""" return etree.tostring(self._root, encoding="utf-8") class SequenceSchema(object): """Schema for a single XML inner node holding a sequence of other nodes. @param tag: The name of the XML element tag this schema is for. @param _child: Optionally the schema of the items in the sequence, used by tests only. """ def __init__(self, tag, _child=None): self.tag = tag self.child = None if _child: self.set(_child, 0, "unbounded") def create(self, root=None, namespace=None): """Create a sequence element with the given root. @param root: The C{etree.Element} to root the sequence at, if C{None} a new one will be created.. @result: A L{SequenceItem} with the given root. @raises L{ECResponseError}: If the given C{root} has a bad tag. """ if root is not None: tag = root.tag if root.nsmap: namespace = root.nsmap[None] tag = tag[len(namespace) + 2:] if tag != self.tag: raise WSDLParseError("Expected response with tag '%s', but " "got '%s' instead" % (self.tag, tag)) return SequenceItem(self, root, namespace) def dump(self, item): """Return the C{etree.Element} of the given L{SequenceItem}. @param item: The L{SequenceItem} to dump. """ return item._root def set(self, child, min_occurs=1, max_occurs=1): """Set the schema for the sequence children. @param child: The schema that children must match. @param min_occurs: The minimum number of children the sequence must have. @param max_occurs: The maximum number of children the sequence can have. 
""" if isinstance(child, LeafSchema): raise RuntimeError("Sequence can't have leaf children") if self.child is not None: raise RuntimeError("Sequence has already a child") if min_occurs is None or max_occurs is None: raise RuntimeError("Sequence node without min or max") if isinstance(child, LeafSchema): raise RuntimeError("Sequence node with leaf child type") if not child.tag == "item": raise RuntimeError("Sequence node with bad child tag") self.child = child self.min_occurs = min_occurs self.max_occurs = max_occurs return child class SequenceItem(object): """A sequence node item in a tree of response elements. @param schema: The L{SequenceSchema} this item must comply to. @param root: The C{etree.Element} this item is rooted at, if C{None} a new one will be created. """ def __init__(self, schema, root=None, namespace=None): if root is None: root = etree.Element(schema.tag) object.__setattr__(self, "_schema", schema) object.__setattr__(self, "_root", root) object.__setattr__(self, "_namespace", namespace) def __getitem__(self, index): """Get the item with the given C{index} in the sequence. @raises L{WSDLParseError}: In the following cases: - If there is no child element with the given C{index}. - The given C{index} is higher than the allowed max. """ schema = self._schema.child tag = self._schema.tag if (self._schema.max_occurs != "unbounded" and index > self._schema.max_occurs - 1): raise WSDLParseError("Out of range item in tag '%s'" % tag) child = self._get_child(self._root.getchildren(), index) return schema.create(child) def append(self): """Append a new item to the sequence, appending it to the end. @return: The newly created item. @raises L{WSDLParseError}: If the operation would result in having more child elements than the allowed max. 
""" tag = self._schema.tag children = self._root.getchildren() if len(children) >= self._schema.max_occurs: raise WSDLParseError("Too many items in tag '%s'" % tag) schema = self._schema.child tag = "item" if self._namespace is not None: tag = "{%s}%s" % (self._namespace, tag) child = etree.SubElement(self._root, tag) return schema.create(child) def __delitem__(self, index): """Remove the item with the given C{index} from the sequence. @raises L{WSDLParseError}: If the operation would result in having less child elements than the required min_occurs, or if no such index is found. """ tag = self._schema.tag children = self._root.getchildren() if len(children) <= self._schema.min_occurs: raise WSDLParseError("Not enough items in tag '%s'" % tag) self._root.remove(self._get_child(children, index)) def remove(self, item): """Remove the given C{item} from the sequence. @raises L{WSDLParseError}: If the operation would result in having less child elements than the required min_occurs, or if no such index is found. """ for index, child in enumerate(self._root.getchildren()): if child is item._root: del self[index] return item raise WSDLParseError("Non existing item in tag '%s'" % self._schema.tag) def __iter__(self): """Iter all the sequence items in order.""" schema = self._schema.child for child in self._root.iterchildren(): yield schema.create(child) def __len__(self): """Return the length of the sequence.""" return len(self._root.getchildren()) def _get_child(self, children, index): """Return the child with the given index.""" try: return children[index] except IndexError: raise WSDLParseError("Non existing item in tag '%s'" % self._schema.tag) class WSDLParser(object): """Build response schemas out of WSDL definitions""" leaf_types = ["string", "boolean", "dateTime", "int", "long", "double", "integer"] def parse(self, wsdl): """Parse the given C{wsdl} data and build the associated schemas. @param wdsl: A string containing the raw xml of the WDSL definition to parse. 
@return: A C{dict} mapping response type names to their schemas. """ parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) root = etree.fromstring(wsdl, parser=parser) types = {} responses = {} schemas = {} namespace = root.attrib["targetNamespace"] for element in root[0][0]: self._remove_namespace_from_tag(element) if element.tag in ["annotation", "group"]: continue name = element.attrib["name"] if element.tag == "element": if name.endswith("Response"): if name in responses: raise RuntimeError("Schema already defined") responses[name] = element elif element.tag == "complexType": types[name] = [element, False] else: raise RuntimeError("Top-level element with unexpected tag") for name, element in responses.iteritems(): schemas[name] = self._parse_type(element, types) schemas[name].namespace = namespace return schemas def _remove_namespace_from_tag(self, element): tag = element.tag if "}" in tag: tag = tag.split("}", 1)[1] element.tag = tag def _parse_type(self, element, types): """Parse a 'complexType' element. @param element: The top-level complexType element @param types: A map of the elements of all available complexType's. @return: The schema for the complexType. 
""" name = element.attrib["name"] type = element.attrib["type"] if not type.startswith("tns:"): raise RuntimeError("Unexpected element type %s" % type) type = type[4:] [children] = types[type][0] types[type][1] = True self._remove_namespace_from_tag(children) if children.tag not in ("sequence", "choice"): raise RuntimeError("Unexpected children type %s" % children.tag) if children[0].attrib["name"] == "item": schema = SequenceSchema(name) else: schema = NodeSchema(name) for child in children: self._remove_namespace_from_tag(child) if child.tag == "element": name, type, min_occurs, max_occurs = self._parse_child(child) if type in self.leaf_types: if max_occurs != 1: raise RuntimeError("Unexpected max value for leaf") if not isinstance(schema, NodeSchema): raise RuntimeError("Attempt to add leaf to a non-node") schema.add(LeafSchema(name), min_occurs=min_occurs) else: if name == "item": # sequence if not isinstance(schema, SequenceSchema): raise RuntimeError("Attempt to set child for " "non-sequence") schema.set(self._parse_type(child, types), min_occurs=min_occurs, max_occurs=max_occurs) else: if max_occurs != 1: raise RuntimeError("Unexpected max for node") if not isinstance(schema, NodeSchema): raise RuntimeError("Unexpected schema type") schema.add(self._parse_type(child, types), min_occurs=min_occurs) elif child.tag == "choice": pass else: raise RuntimeError("Unexpected child type") return schema def _parse_child(self, child): """Parse a single child element. @param child: The child C{etree.Element} to parse. @return: A tuple C{(name, type, min_occurs, max_occurs)} with the details about the given child. 
""" if set(child.attrib) - set(["name", "type", "minOccurs", "maxOccurs"]): raise RuntimeError("Unexpected attribute in child") name = child.attrib["name"] type = child.attrib["type"].split(":")[1] min_occurs = child.attrib.get("minOccurs") max_occurs = child.attrib.get("maxOccurs") if min_occurs is None: min_occurs = "1" min_occurs = int(min_occurs) if max_occurs is None: max_occurs = "1" if max_occurs != "unbounded": max_occurs = int(max_occurs) return name, type, min_occurs, max_occurs txAWS-0.2.3/LICENSE0000664000175000017500000000232711741311335015222 0ustar oubiwannoubiwann00000000000000Copyright (C) 2008 Tristan Seligmann Copyright (C) 2009 Robert Collins Copyright (C) 2009 Canonical Ltd Copyright (C) 2009 Duncan McGreggor Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.