forked from loafle/openapi-generator-original
Support file downloading in Python codegen.
This commit is contained in:
@@ -17,6 +17,7 @@ import json
|
||||
import datetime
|
||||
import mimetypes
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
# python 2 and python 3 compatibility library
|
||||
from six import iteritems
|
||||
@@ -96,9 +97,13 @@ class ApiClient(object):
|
||||
# request url
|
||||
url = self.host + resource_path
|
||||
|
||||
# perform request and return response
|
||||
response_data = self.request(method, url, query_params=query_params, headers=header_params,
|
||||
post_params=post_params, body=body)
|
||||
if response == "file":
|
||||
# perform request and return response
|
||||
response_data = self.request(method, url, query_params=query_params, headers=header_params,
|
||||
post_params=post_params, body=body, raw=True)
|
||||
else:
|
||||
response_data = self.request(method, url, query_params=query_params, headers=header_params,
|
||||
post_params=post_params, body=body)
|
||||
|
||||
# deserialize response data
|
||||
if response:
|
||||
@@ -173,6 +178,10 @@ class ApiClient(object):
|
||||
sub_class = match.group(2)
|
||||
return {k: self.deserialize(v, sub_class) for k, v in iteritems(obj)}
|
||||
|
||||
# handle file downloading - save response body into a tmp file and return the instance
|
||||
if "file" == obj_class:
|
||||
return self.download_file(obj)
|
||||
|
||||
if obj_class in ['int', 'float', 'dict', 'list', 'str', 'bool', 'datetime', "object"]:
|
||||
obj_class = eval(obj_class)
|
||||
else: # not a native type, must be model class
|
||||
@@ -228,12 +237,12 @@ class ApiClient(object):
|
||||
except ImportError:
|
||||
return string
|
||||
|
||||
def request(self, method, url, query_params=None, headers=None, post_params=None, body=None):
|
||||
def request(self, method, url, query_params=None, headers=None, post_params=None, body=None, raw=False):
|
||||
"""
|
||||
Perform http request using RESTClient.
|
||||
"""
|
||||
if method == "GET":
|
||||
return RESTClient.GET(url, query_params=query_params, headers=headers)
|
||||
return RESTClient.GET(url, query_params=query_params, headers=headers, raw=raw)
|
||||
elif method == "HEAD":
|
||||
return RESTClient.HEAD(url, query_params=query_params, headers=headers)
|
||||
elif method == "POST":
|
||||
@@ -308,3 +317,27 @@ class ApiClient(object):
|
||||
querys[auth_setting['key']] = auth_setting['value']
|
||||
else:
|
||||
raise ValueError('Authentication token must be in `query` or `header`')
|
||||
|
||||
def download_file(self, response):
|
||||
"""
|
||||
Save response body into a file in (the defined) temporary folder, using the filename
|
||||
from the `Content-Disposition` header if provided, otherwise a random filename.
|
||||
|
||||
:param response: RESTResponse
|
||||
:return: file path
|
||||
"""
|
||||
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path)
|
||||
os.close(fd)
|
||||
os.remove(path)
|
||||
|
||||
content_disposition = response.getheader("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename = re.search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition).group(1)
|
||||
path = os.path.join(os.path.dirname(path), filename)
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write(response.data)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
|
||||
@@ -41,4 +41,5 @@ api_key_prefix = {}
|
||||
username = ''
|
||||
password = ''
|
||||
|
||||
|
||||
# Temp foloder for file download
|
||||
temp_folder_path = None
|
||||
|
||||
@@ -75,7 +75,7 @@ class RESTClientObject(object):
|
||||
return self.pool_manager
|
||||
|
||||
def request(self, method, url, query_params=None, headers=None,
|
||||
body=None, post_params=None):
|
||||
body=None, post_params=None, raw=False):
|
||||
"""
|
||||
:param method: http request method
|
||||
:param url: http request url
|
||||
@@ -128,9 +128,12 @@ class RESTClientObject(object):
|
||||
if r.status not in range(200, 206):
|
||||
raise ApiException(r)
|
||||
|
||||
return self.process_response(r)
|
||||
return self.process_response(r, raw)
|
||||
|
||||
def process_response(self, response):
|
||||
def process_response(self, response, raw):
|
||||
if raw:
|
||||
return response
|
||||
|
||||
# In the python 3, the response.data is bytes.
|
||||
# we need to decode it to string.
|
||||
if sys.version_info > (3,):
|
||||
@@ -144,8 +147,8 @@ class RESTClientObject(object):
|
||||
|
||||
return resp
|
||||
|
||||
def GET(self, url, headers=None, query_params=None):
|
||||
return self.request("GET", url, headers=headers, query_params=query_params)
|
||||
def GET(self, url, headers=None, query_params=None, raw=False):
|
||||
return self.request("GET", url, headers=headers, query_params=query_params, raw=raw)
|
||||
|
||||
def HEAD(self, url, headers=None, query_params=None):
|
||||
return self.request("HEAD", url, headers=headers, query_params=query_params)
|
||||
|
||||
@@ -11,14 +11,14 @@ import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import SwaggerPetstore
|
||||
from SwaggerPetstore.rest import ApiException
|
||||
import swagger_client
|
||||
from swagger_client.rest import ApiException
|
||||
|
||||
|
||||
class StoreApiTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.store_api = SwaggerPetstore.StoreApi()
|
||||
self.store_api = swagger_client.StoreApi()
|
||||
|
||||
def tearDown(self):
|
||||
# sleep 1 sec between two every 2 tests
|
||||
Reference in New Issue
Block a user