import json
# import inflect
import requests
from openapi_core import OpenAPI
from openapi_spec_validator.validation.exceptions import OpenAPIValidationError
from referencing.exceptions import PointerToNowhere
from .request import Request, JSON_HEADER_KEY
from .parser import to_int, delete_nested_value, get_nested_value, find
from .custom_exceptions import UnrespectedSchema, InvalidFieldInSchema, AWrongStatus
# NOTE(review): STATUS_KEY, SCHEMA_KEY and EXTRA_CHECK_KEY are not defined or
# imported anywhere visible in this module; they must be brought into scope
# before this line or importing the module fails with NameError — confirm.
# Result categories recorded for every tested (path, method, status) entry.
SUCCESS_FIELDS = [STATUS_KEY, SCHEMA_KEY, EXTRA_CHECK_KEY]

# Payload field names. They are not referenced in this module's visible code;
# presumably consumed by callers.
PASSWORD_KEY = "password"
FIRST_NAME_KEY = "first_name"
LAST_NAME_KEY = "last_name"
USER_ID_KEY = "user_id"
LOGIN_TOKEN_KEY = "login_token"
JOB_TOKEN_KEY = "job_token"
API_KEY_KEY = "api_key"

# Key holding the human-readable failure text of a result category.
DESCRIPTION_KEY = 'description'

# Template for one result category: STATUS_KEY holds the outcome, and
# DESCRIPTION_KEY holds the failure message. Callers store a `.copy()` of it so
# that entries never share state.
REQUESTS_TESTED_EMPTY_DICT = {STATUS_KEY: None, DESCRIPTION_KEY: None}
class Api:
    """One OpenAPI-described API under test.

    Loads the spec (``load_swagger``), optionally patches spec defects that
    prevent ``openapi_core`` from accepting it, validates responses against the
    spec, and records per-endpoint outcomes in ``requests_tested``. That
    attribute is keyed ``path -> method -> expected status -> result category``.
    """

    # Reference substituted for unresolved pointers by the PointerToNowhere patch.
    OPEN_API_PATCH_KEY = '#/components/schemas/DummySchema'

    # Placeholder schema registered as components.schemas.DummySchema when
    # patching. It is deep-copied before insertion so the class attribute is
    # never mutated through a spec.
    OPEN_API_DUMMY_SCHEMA = {'properties': {'dummyMessage': {'type': 'string'}}, 'type': 'object'}

    def __init__(self, api_type: str, url: str, swagger_file: str):
        """Store configuration; nothing is loaded until ``load_swagger``.

        :param api_type: stored as ``self.api_type``.
        :param url: stored as ``self.url`` (not otherwise used by this class).
        :param swagger_file: local path, or URL starting with ``http``, of the
            OpenAPI JSON document.
        """
        self.api_type = api_type
        self.url = url
        self.swagger_file = swagger_file
        # Parsed spec as plain JSON data; set by load_swagger().
        self.schemas: dict | None = None
        # openapi_core.OpenAPI built from self.schemas; set by load_swagger().
        self.openapi: OpenAPI | None = None
        # path -> method -> expected status -> result category -> {status, description}
        self.requests_tested = {}

    @staticmethod
    def _log(logger, message):
        """Report ``message`` via ``logger.error``, or print it when no logger is given."""
        if logger:
            logger.error(message)
        else:
            print(message)

    def _del_tags_duplicate(self, logger=None):
        """Remove tags whose name already appears later in ``schemas['tags']``.

        The list is scanned from the end, so the LAST occurrence of each name is
        kept and earlier duplicates are deleted in place.
        """
        tags = self.schemas['tags']
        seen = set()
        # Walk backwards so deleting index i never shifts indices still to visit.
        for i in range(len(tags) - 1, -1, -1):
            name = tags[i]['name']
            if name in seen:
                del tags[i]
                self._log(
                    logger,
                    f'Error while loading {self.swagger_file}, '
                    f'duplicate tag are forbidden, deleting tag "{name}".')
            else:
                seen.add(name)

    @staticmethod
    def _get_keys_from_json_path(json_path: str):
        """Split a ``$.a.b[0].c`` style JSON path into keys, e.g. ``['a', 'b', 0, 'c']``.

        The leading two characters (``$.``) are skipped. Segments inside
        brackets are converted with ``to_int`` when it yields a value. The empty
        segment between ``]`` and a following ``.`` or ``[`` is a delimiter
        artifact and is not emitted.
        """
        keys = []
        segment = ''
        after_bracket = False
        for c in json_path[2:]:
            if c in '.[]':
                if c == ']':
                    index = to_int(segment)
                    if index is not None:
                        segment = index
                if not (segment == '' and after_bracket):
                    keys.append(segment)
                after_bracket = c == ']'
                segment = ''
                continue
            after_bracket = False
            segment += c
        if segment:
            keys.append(segment)
        return keys

    def _generic_OpenAPIValidationError_patch(self, json_path: str, logger):
        """Repair the spec element at ``json_path`` that failed validation.

        Two cases are handled. An entry of a ``parameters`` list is deleted. A
        schema under a ``schemas`` mapping has its null-valued property
        ``description`` fields removed.

        :raises NotImplementedError: when neither case applies or nothing changed.
        """
        keys = Api._get_keys_from_json_path(json_path)
        # The key of the containing collection tells which kind of element failed.
        parent_key = keys[-2] if len(keys) >= 2 else None
        patched_api = False
        if parent_key == 'parameters':
            my_dict = get_nested_value(keys, self.schemas)
            # A malformed parameter may lack "name"; reporting must not crash.
            name = my_dict.get("name") if isinstance(my_dict, dict) else None
            self._log(
                logger,
                f'Error while loading {self.swagger_file}, for json path {json_path},'
                f' improperly format parameter {name}, deleting entry.')
            delete_nested_value(keys, self.schemas)
            patched_api = True
        elif parent_key == "schemas":
            my_dict = get_nested_value(keys, self.schemas)
            properties = my_dict.get("properties") if isinstance(my_dict, dict) else None
            if isinstance(properties, dict):
                for key, value in properties.items():
                    if isinstance(value, dict) and "description" in value and value["description"] is None:
                        self._log(
                            logger,
                            f'Error while loading {self.swagger_file}, for schema {json_path},'
                            f' property {key} have null description, removing description from property field.')
                        # Mutates `value`, not `properties`, so iteration stays valid.
                        delete_nested_value(keys + ["properties", key, "description"], self.schemas)
                        patched_api = True
        if not patched_api:
            raise NotImplementedError(f"Cannot handle OpenAPIValidationError from path: {json_path}")

    @staticmethod
    def _replace_ref_strings(node, old_ref: str, new_ref: str) -> int:
        """Rewrite every string ending with ``old_ref`` so it ends with ``new_ref``.

        The document is walked in place (dict values and list items). The match
        is anchored at the end of the string, so ``.../Foo`` never rewrites
        ``.../FooBar``. Returns the number of strings rewritten.
        """
        count = 0
        stack = [node]
        while stack:
            container = stack.pop()
            entries = container.items() if isinstance(container, dict) else enumerate(container)
            for key, value in list(entries):
                if isinstance(value, str) and value.endswith(old_ref):
                    container[key] = value[:len(value) - len(old_ref)] + new_ref
                    count += 1
                elif isinstance(value, (dict, list)):
                    stack.append(value)
        return count

    def _generic_PointerToNowhere_patch(self, ref: str, logger):
        """Point every reference to the missing ``#<ref>`` at the dummy schema.

        :raises NotImplementedError: when no reference could be rewritten. Without
            this check, ``_patch_swagger`` would retry the same failure forever.
        """
        old_ref = '#' + ref
        self._log(logger, f"Error while loading {self.swagger_file}, PointerToNowhere, "
                          f"trying to replace {old_ref} occurrences by a dummy schema")
        replaced = 0
        if old_ref != Api.OPEN_API_PATCH_KEY:
            replaced = Api._replace_ref_strings(self.schemas, old_ref, Api.OPEN_API_PATCH_KEY)
        if not replaced:
            raise NotImplementedError(f"Cannot handle PointerToNowhere for reference: {old_ref}")

    def _patch_swagger(self, logger):
        """Build ``self.openapi``, repairing known spec defects until it loads.

        Each failed load applies one fix and retries. Unfixable errors surface
        as NotImplementedError from the generic patchers.
        """
        import copy

        components = self.schemas.setdefault('components', {})
        components.setdefault('schemas', {})['DummySchema'] = copy.deepcopy(Api.OPEN_API_DUMMY_SCHEMA)
        while True:
            try:
                self.openapi = OpenAPI.from_dict(self.schemas)
                return
            except OpenAPIValidationError as e:
                if e.json_path == '$.tags' and e.validator == 'uniqueItems':
                    self._del_tags_duplicate(logger)
                    continue
                if e.json_path == '$.components.schemas.HTTPError' and e.validator == 'oneOf':
                    self.schemas['components']['schemas']['HTTPError'] = {
                        "properties": {"error": {"type": "string"}, "detail": {"type": "object"}}}
                    self._log(
                        logger,
                        f"Error while loading {self.swagger_file}, $.components.schemas.HTTPError has errors, "
                        f"replacing it by {self.schemas['components']['schemas']['HTTPError']}")
                    continue
                self._generic_OpenAPIValidationError_patch(e.json_path, logger)
            except PointerToNowhere as e:
                self._generic_PointerToNowhere_patch(e.ref, logger)

    def load_swagger(self, patch_open_api: bool = False, logger=None, timeout: float | None = 30):
        """Load the spec into ``self.schemas`` and build ``self.openapi``.

        :param patch_open_api: when true, repair known spec defects via
            ``_patch_swagger`` instead of failing on them.
        :param logger: optional logger; ``print`` is used when absent.
        :param timeout: seconds passed to ``requests.get`` for remote specs;
            ``None`` waits indefinitely.
        :raises ValueError: when a remote spec is not served with HTTP 200.
        """
        if self.swagger_file.startswith("http"):
            resp = requests.get(self.swagger_file, timeout=timeout)
            if resp.status_code != 200:
                raise ValueError("Invalid status code for openapi.json request")
            self.schemas = resp.json()
        else:
            with open(self.swagger_file, "r", encoding='utf-8') as f:
                self.schemas = json.load(f)
        info = self.schemas.setdefault('info', {})
        # A missing or null version is replaced so the spec can still be loaded.
        if info.get('version') is None:
            self._log(logger, f"ERROR for swagger {self.swagger_file}, version was None")
            info['version'] = 'ERROR, version was None'
        if patch_open_api:
            self._patch_swagger(logger)
        else:
            self.openapi = OpenAPI.from_dict(self.schemas)
        message = f"Successfully load {self.swagger_file}"
        if logger:
            logger.info(message)
        else:
            print(message)

    def get_schema_from_request(self, request: Request):
        """Return the JSON response schema documented for ``request``'s path, method and actual status."""
        return self.get_schema(request.original_path_url, request.method.lower(), request.response.status_code)

    def get_schema(self, url: str, method: str, status_code: int | str):
        """Return the JSON response schema for ``url``/``method``/``status_code``.

        :raises KeyError: if any level of the lookup is absent or the schema is empty.
        """
        schema = self.schemas['paths'][url][method.lower()]['responses'][str(
            status_code)]['content'][JSON_HEADER_KEY]['schema']
        if not schema:
            raise KeyError("Missing schema")
        return schema

    def _item_match_schema(self, response, schema) -> list[str]:
        """List differences between one JSON object and an object schema.

        Every non-nullable property is treated as required. Keys absent from
        ``properties`` are reported unless ``additionalProperties`` is true or a
        schema object (including ``{}``).
        """
        if not isinstance(response, dict):
            return [f"type error, expect a dict got a {type(response)}"]
        properties = schema['properties']
        errors = [f"missing field '{key}'" for key, value in properties.items()
                  if not value.get('nullable') and key not in response]
        additional = schema.get("additionalProperties")
        # Missing or false forbids extra keys; true or any schema object (even {}) allows them.
        if additional is None or additional is False:
            errors.extend(f"extra field '{key}'" for key in response if key not in properties)
        return errors

    def _resolve_schema(self, schema: dict) -> dict:
        """Follow a local ``$ref`` (``#/...``) via ``find``; inline schemas are returned as-is."""
        if '$ref' in schema:
            return find(schema['$ref'][2:], self.schemas)
        return schema

    def get_json_schema_exception(self, request) -> BaseException | None:
        """Return an ``UnrespectedSchema`` when the response body mismatches its schema, else None."""
        errors = []
        schema = self.get_schema_from_request(request)
        if 'items' in schema:
            good_schema = self._resolve_schema(schema['items'])
            response = request.response.json()
            if isinstance(response, list):
                for item in response:
                    errors.extend(self._item_match_schema(item, good_schema))
            else:
                errors.append(f"type error, expect a list got a {type(response)}")
        else:
            good_schema = self._resolve_schema(schema)
            errors.extend(self._item_match_schema(request.response.json(), good_schema))
        if errors:
            return UnrespectedSchema(request, schema, errors)
        return None

    def get_invalid_data_exception(self, request, invalid_data) -> BaseException:
        """Wrap a response-data validation failure into ``InvalidFieldInSchema``."""
        errors = []
        # schema_errors come from the chained cause; tolerate a missing cause.
        for error in getattr(invalid_data.__cause__, 'schema_errors', ()):
            example = error.schema.get('example') if isinstance(error.schema, dict) else None
            errors.append(
                f"For field {error.json_path}, {error.message}. Expected something like {example}")
        # NOTE(review): `errors` is built but never passed on (the original behaved
        # the same). Confirm whether InvalidFieldInSchema accepts an errors
        # argument like UnrespectedSchema does.
        return InvalidFieldInSchema(
            request,
            self.get_schema_from_request(request)
        )

    def get_request_tested_dict(self, request):
        """Return, creating if needed, the result entry for request's path/method/expected status."""
        method_results = self.requests_tested.setdefault(
            request.original_path_url, {}).setdefault(request.method.lower(), {})
        status = request.expected_status_code
        if status not in method_results:
            method_results[status] = {field: REQUESTS_TESTED_EMPTY_DICT.copy() for field in SUCCESS_FIELDS}
        return method_results[status]

    def check_schema(self, request, exceptions: list[BaseException]):
        """Record the schema check outcome, then raise the first exception if any.

        A success clears any description left by an earlier failure.
        """
        api_dict = self.get_request_tested_dict(request)[SCHEMA_KEY]
        success = not exceptions
        api_dict[STATUS_KEY] = success
        if success:
            api_dict[DESCRIPTION_KEY] = None
            return
        api_dict[DESCRIPTION_KEY] = '\n'.join(exception.get_message() for exception in exceptions)
        raise exceptions[0]

    def check_status(self, request, exception: AWrongStatus):
        """Record the status check outcome, then call ``exception()``.

        NOTE(review): a truthy ``exception.is_raised()`` is recorded as success
        here. Confirm this against AWrongStatus.
        """
        api_dict = self.get_request_tested_dict(request)[STATUS_KEY]
        api_dict[STATUS_KEY] = exception.is_raised()
        # Keep the description consistent with the latest outcome.
        api_dict[DESCRIPTION_KEY] = None if api_dict[STATUS_KEY] else exception.get_message()
        exception()
# def check_api_schemas(self):
# should_be_renamed = {}
# # should_be_created = {}
# should_be_hidden = [] # /page
# bad_named_crud = [] # forbidden endpoints
# bad_named_plural = []
# bad_order = []
# post_should_be_put = []
# probably_badly_named = []
# bad_get_unique_object_content = []
# missing_schema = {}
# pagination_endpoint = []
# p = inflect.engine()
# for path, path_content in self.schemas['paths'].items():
# path: str = path
# path_content: dict = path_content
# if all(['deprecated' in method_content for method_content in path_content.values()]):
# continue
# if path.endswith("/page"):
# expected_path = path[:-len("/page")]
# if expected_path in path:
# should_be_hidden.append(expected_path)
# continue
# else:
# should_be_renamed[path] = expected_path
# for method, method_content in path_content.items():
# if 'deprecated' in method_content:
# continue
# for status in method_content['responses']:
# try:
# schema = self.get_schema(path, method, status)
# except KeyError:
# set_nested_value([path, method, status], missing_schema, True)
# continue
# if 'items' in schema:
# schema_name = schema['items']['$ref'][2:]
# else:
# schema_name = schema['$ref'][2:]
# true_schema = find(schema_name, schemas)
# # if all([key in true_schema['properties'] for key in PAGINATION_RESULT_KEYS]):
# # pagination_endpoint.append(path)
# # continue
# for crud in ["create", "read", "update", "delete", "post", "get", "put", "delete"]:
# if crud in path:
# bad_named_crud.append(path)
# break
# if path.count("{") >= 1:
# path_elems = path.split("/")
# for index, elem in enumerate(path_elems):
# if "{" in elem:
# object_name = elem[1:-1].split("_")[0]
# break
# plural_object_name = p.plural_noun(object_name)
# if path_elems[index-1] == plural_object_name:
# pass
# elif path_elems[index-1] == object_name:
# new_path = path_elems
# new_path[index-1] = plural_object_name
# new_path = "/".join(new_path)
# should_be_renamed[path] = new_path
# bad_named_plural.append(path)
# elif plural_object_name in path_elems:
# # index_to_insert = path_elems.find(plural_object_name)
# bad_order.append(path)
# elif object_name in path_elems:
# bad_named_plural.append(path)
# bad_order.append(path)
# else:
# probably_badly_named.append(path)
# if "post" in path_content and index == len(path_elems)-1:
# post_should_be_put.append(path) # resource is already created so it's a modification
# if path in missing_schema and "get" in missing_schema[path] and "200" in missing_schema[path]['get']:
# continue
# if "get" in path_content:
# schema = self.get_schema(path, 'get', 200)
# if 'items' in schema:
# bad_get_unique_object_content.append(path)
# Keys of an OpenAPI Path Item Object that denote operations; the other keys
# ("parameters", "summary", "description", "servers", "$ref") do not.
_OPENAPI_OPERATION_METHODS = frozenset(
    ("get", "put", "post", "delete", "options", "head", "patch", "trace"))


def create_all_requests_tested(context):
    """Pre-create an empty result entry for every documented response of every API.

    The entries are written into ``context.requests_tested[api_name][path][method][status]``.
    Existing entries are left untouched.

    Non-operation path-item keys are skipped. So are response keys that are not
    integer status codes (``default``, ``2XX``): requests are recorded under
    integer expected statuses, so such keys are never matched.
    """
    for api_name, api in context.apis.items():
        api_results = context.requests_tested.setdefault(api_name, {})
        for path, path_item in api.schemas['paths'].items():
            path_results = api_results.setdefault(path, {})
            if not isinstance(path_item, dict):
                continue
            for method, operation in path_item.items():
                if method not in _OPENAPI_OPERATION_METHODS:
                    continue
                method_results = path_results.setdefault(method, {})
                for status in operation.get('responses', {}):
                    try:
                        status = int(status)
                    except ValueError:
                        continue
                    if status not in method_results:
                        method_results[status] = {
                            field: REQUESTS_TESTED_EMPTY_DICT.copy() for field in SUCCESS_FIELDS
                        }