Support file downloading in Python codegen.

This commit is contained in:
geekerzp
2015-06-29 17:08:03 +08:00
parent 80303b524d
commit 6df6c079ee
8 changed files with 100 additions and 25 deletions

View File

@@ -17,6 +17,7 @@ import json
import datetime
import mimetypes
import random
import tempfile
# python 2 and python 3 compatibility library
from six import iteritems
@@ -96,9 +97,13 @@ class ApiClient(object):
# request url
url = self.host + resource_path
# perform request and return response
response_data = self.request(method, url, query_params=query_params, headers=header_params,
post_params=post_params, body=body)
if response == "file":
# perform request and return response
response_data = self.request(method, url, query_params=query_params, headers=header_params,
post_params=post_params, body=body, raw=True)
else:
response_data = self.request(method, url, query_params=query_params, headers=header_params,
post_params=post_params, body=body)
# deserialize response data
if response:
@@ -173,6 +178,10 @@ class ApiClient(object):
sub_class = match.group(2)
return {k: self.deserialize(v, sub_class) for k, v in iteritems(obj)}
# handle file downloading - save response body into a tmp file and return the instance
if "file" == obj_class:
return self.download_file(obj)
if obj_class in ['int', 'float', 'dict', 'list', 'str', 'bool', 'datetime', "object"]:
obj_class = eval(obj_class)
else: # not a native type, must be model class
@@ -228,12 +237,12 @@ class ApiClient(object):
except ImportError:
return string
def request(self, method, url, query_params=None, headers=None, post_params=None, body=None):
def request(self, method, url, query_params=None, headers=None, post_params=None, body=None, raw=False):
"""
Perform http request using RESTClient.
"""
if method == "GET":
return RESTClient.GET(url, query_params=query_params, headers=headers)
return RESTClient.GET(url, query_params=query_params, headers=headers, raw=raw)
elif method == "HEAD":
return RESTClient.HEAD(url, query_params=query_params, headers=headers)
elif method == "POST":
@@ -308,3 +317,27 @@ class ApiClient(object):
querys[auth_setting['key']] = auth_setting['value']
else:
raise ValueError('Authentication token must be in `query` or `header`')
def download_file(self, response):
"""
Save response body into a file in (the defined) temporary folder, using the filename
from the `Content-Disposition` header if provided, otherwise a random filename.
:param response: RESTResponse
:return: file path
"""
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path)
os.close(fd)
os.remove(path)
content_disposition = response.getheader("Content-Disposition")
if content_disposition:
filename = re.search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition).group(1)
path = os.path.join(os.path.dirname(path), filename)
with open(path, "w") as f:
f.write(response.data)
return path

View File

@@ -41,4 +41,5 @@ api_key_prefix = {}
username = ''
password = ''
# Temp foloder for file download
temp_folder_path = None

View File

@@ -75,7 +75,7 @@ class RESTClientObject(object):
return self.pool_manager
def request(self, method, url, query_params=None, headers=None,
body=None, post_params=None):
body=None, post_params=None, raw=False):
"""
:param method: http request method
:param url: http request url
@@ -128,9 +128,12 @@ class RESTClientObject(object):
if r.status not in range(200, 206):
raise ApiException(r)
return self.process_response(r)
return self.process_response(r, raw)
def process_response(self, response):
def process_response(self, response, raw):
if raw:
return response
# In the python 3, the response.data is bytes.
# we need to decode it to string.
if sys.version_info > (3,):
@@ -144,8 +147,8 @@ class RESTClientObject(object):
return resp
def GET(self, url, headers=None, query_params=None):
return self.request("GET", url, headers=headers, query_params=query_params)
def GET(self, url, headers=None, query_params=None, raw=False):
return self.request("GET", url, headers=headers, query_params=query_params, raw=raw)
def HEAD(self, url, headers=None, query_params=None):
return self.request("HEAD", url, headers=headers, query_params=query_params)

View File

@@ -11,14 +11,14 @@ import os
import time
import unittest
import SwaggerPetstore
from SwaggerPetstore.rest import ApiException
import swagger_client
from swagger_client.rest import ApiException
class StoreApiTests(unittest.TestCase):
def setUp(self):
self.store_api = SwaggerPetstore.StoreApi()
self.store_api = swagger_client.StoreApi()
def tearDown(self):
# sleep 1 sec between two every 2 tests