Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70f7648c48 | ||
|
|
6a1789b3f5 |
@@ -1,623 +0,0 @@
|
||||
---
|
||||
name: api-documentation-generator
|
||||
description: Generate OpenAPI/Swagger specifications and API documentation from code or design. Use when creating API docs, generating OpenAPI specs, or documenting REST APIs.
|
||||
---
|
||||
|
||||
# API Documentation Generator
|
||||
|
||||
Generate OpenAPI/Swagger specifications and comprehensive API documentation.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Create OpenAPI 3.0 specs with paths, schemas, and examples for complete API documentation.
|
||||
|
||||
## Instructions
|
||||
|
||||
### OpenAPI 3.0 Structure
|
||||
|
||||
**Basic structure:**
|
||||
```yaml
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: API Name
|
||||
version: 1.0.0
|
||||
description: API description
|
||||
servers:
|
||||
- url: https://api.example.com/v1
|
||||
paths:
|
||||
/users:
|
||||
get:
|
||||
summary: List users
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
components:
|
||||
schemas:
|
||||
User:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
name:
|
||||
type: string
|
||||
```
|
||||
|
||||
### Info Section
|
||||
|
||||
```yaml
|
||||
info:
|
||||
title: E-commerce API
|
||||
version: 1.0.0
|
||||
description: |
|
||||
REST API for e-commerce platform.
|
||||
|
||||
## Authentication
|
||||
Use Bearer token in Authorization header.
|
||||
|
||||
## Rate Limiting
|
||||
1000 requests per hour per API key.
|
||||
contact:
|
||||
name: API Support
|
||||
email: api@example.com
|
||||
url: https://example.com/support
|
||||
license:
|
||||
name: MIT
|
||||
url: https://opensource.org/licenses/MIT
|
||||
```
|
||||
|
||||
### Servers
|
||||
|
||||
```yaml
|
||||
servers:
|
||||
- url: https://api.example.com/v1
|
||||
description: Production
|
||||
- url: https://staging-api.example.com/v1
|
||||
description: Staging
|
||||
- url: http://localhost:3000/v1
|
||||
description: Development
|
||||
```
|
||||
|
||||
### Paths and Operations
|
||||
|
||||
**GET endpoint:**
|
||||
```yaml
|
||||
paths:
|
||||
/users:
|
||||
get:
|
||||
summary: List users
|
||||
description: Retrieve a paginated list of users
|
||||
tags:
|
||||
- Users
|
||||
parameters:
|
||||
- name: page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 1
|
||||
- name: per_page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 20
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/User'
|
||||
meta:
|
||||
$ref: '#/components/schemas/PaginationMeta'
|
||||
```
|
||||
|
||||
**POST endpoint:**
|
||||
```yaml
|
||||
/users:
|
||||
post:
|
||||
summary: Create user
|
||||
tags:
|
||||
- Users
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateUserRequest'
|
||||
example:
|
||||
name: John Doe
|
||||
email: john@example.com
|
||||
responses:
|
||||
'201':
|
||||
description: User created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
```
|
||||
|
||||
**Path parameters:**
|
||||
```yaml
|
||||
/users/{id}:
|
||||
get:
|
||||
summary: Get user by ID
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
description: User ID
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
```
|
||||
|
||||
### Components - Schemas
|
||||
|
||||
**Simple schema:**
|
||||
```yaml
|
||||
components:
|
||||
schemas:
|
||||
User:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
example: 1
|
||||
name:
|
||||
type: string
|
||||
example: John Doe
|
||||
email:
|
||||
type: string
|
||||
format: email
|
||||
example: john@example.com
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
```
|
||||
|
||||
**Nested schema:**
|
||||
```yaml
|
||||
Order:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
customer:
|
||||
$ref: '#/components/schemas/User'
|
||||
items:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OrderItem'
|
||||
total:
|
||||
type: number
|
||||
format: float
|
||||
```
|
||||
|
||||
**Enum:**
|
||||
```yaml
|
||||
OrderStatus:
|
||||
type: string
|
||||
enum:
|
||||
- pending
|
||||
- processing
|
||||
- shipped
|
||||
- delivered
|
||||
- cancelled
|
||||
```
|
||||
|
||||
**OneOf (union types):**
|
||||
```yaml
|
||||
Payment:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/CreditCardPayment'
|
||||
- $ref: '#/components/schemas/PayPalPayment'
|
||||
discriminator:
|
||||
propertyName: payment_type
|
||||
```
|
||||
|
||||
### Components - Responses
|
||||
|
||||
**Reusable responses:**
|
||||
```yaml
|
||||
components:
|
||||
responses:
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
example:
|
||||
error:
|
||||
code: NOT_FOUND
|
||||
message: Resource not found
|
||||
|
||||
BadRequest:
|
||||
description: Invalid request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ValidationError'
|
||||
|
||||
Unauthorized:
|
||||
description: Authentication required
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
```
|
||||
|
||||
### Components - Parameters
|
||||
|
||||
**Reusable parameters:**
|
||||
```yaml
|
||||
components:
|
||||
parameters:
|
||||
PageParam:
|
||||
name: page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 1
|
||||
|
||||
LimitParam:
|
||||
name: limit
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 20
|
||||
maximum: 100
|
||||
|
||||
IdParam:
|
||||
name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
```
|
||||
|
||||
### Security Schemes
|
||||
|
||||
**Bearer token:**
|
||||
```yaml
|
||||
components:
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
bearerFormat: JWT
|
||||
|
||||
security:
|
||||
- BearerAuth: []
|
||||
```
|
||||
|
||||
**API Key:**
|
||||
```yaml
|
||||
components:
|
||||
securitySchemes:
|
||||
ApiKeyAuth:
|
||||
type: apiKey
|
||||
in: header
|
||||
name: X-API-Key
|
||||
```
|
||||
|
||||
**OAuth 2.0:**
|
||||
```yaml
|
||||
components:
|
||||
securitySchemes:
|
||||
OAuth2:
|
||||
type: oauth2
|
||||
flows:
|
||||
authorizationCode:
|
||||
authorizationUrl: https://example.com/oauth/authorize
|
||||
tokenUrl: https://example.com/oauth/token
|
||||
scopes:
|
||||
read: Read access
|
||||
write: Write access
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
**Multiple examples:**
|
||||
```yaml
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
examples:
|
||||
admin:
|
||||
summary: Admin user
|
||||
value:
|
||||
id: 1
|
||||
name: Admin
|
||||
role: admin
|
||||
regular:
|
||||
summary: Regular user
|
||||
value:
|
||||
id: 2
|
||||
name: John
|
||||
role: user
|
||||
```
|
||||
|
||||
### Tags
|
||||
|
||||
**Organize endpoints:**
|
||||
```yaml
|
||||
tags:
|
||||
- name: Users
|
||||
description: User management
|
||||
- name: Products
|
||||
description: Product catalog
|
||||
- name: Orders
|
||||
description: Order processing
|
||||
|
||||
paths:
|
||||
/users:
|
||||
get:
|
||||
tags:
|
||||
- Users
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```yaml
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: Blog API
|
||||
version: 1.0.0
|
||||
description: RESTful API for blog platform
|
||||
|
||||
servers:
|
||||
- url: https://api.blog.com/v1
|
||||
|
||||
paths:
|
||||
/posts:
|
||||
get:
|
||||
summary: List posts
|
||||
tags:
|
||||
- Posts
|
||||
parameters:
|
||||
- $ref: '#/components/parameters/PageParam'
|
||||
- $ref: '#/components/parameters/LimitParam'
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Post'
|
||||
meta:
|
||||
$ref: '#/components/schemas/PaginationMeta'
|
||||
|
||||
post:
|
||||
summary: Create post
|
||||
tags:
|
||||
- Posts
|
||||
security:
|
||||
- BearerAuth: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- title
|
||||
- content
|
||||
properties:
|
||||
title:
|
||||
type: string
|
||||
content:
|
||||
type: string
|
||||
tags:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
responses:
|
||||
'201':
|
||||
description: Post created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Post'
|
||||
|
||||
/posts/{id}:
|
||||
get:
|
||||
summary: Get post
|
||||
tags:
|
||||
- Posts
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
'200':
|
||||
description: Success
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Post'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
components:
|
||||
schemas:
|
||||
Post:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
title:
|
||||
type: string
|
||||
content:
|
||||
type: string
|
||||
author:
|
||||
$ref: '#/components/schemas/User'
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
User:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
name:
|
||||
type: string
|
||||
email:
|
||||
type: string
|
||||
|
||||
PaginationMeta:
|
||||
type: object
|
||||
properties:
|
||||
total:
|
||||
type: integer
|
||||
page:
|
||||
type: integer
|
||||
per_page:
|
||||
type: integer
|
||||
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
|
||||
parameters:
|
||||
PageParam:
|
||||
name: page
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 1
|
||||
|
||||
LimitParam:
|
||||
name: limit
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 20
|
||||
|
||||
responses:
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
bearerFormat: JWT
|
||||
|
||||
security:
|
||||
- BearerAuth: []
|
||||
```
|
||||
|
||||
## Generating from Code
|
||||
|
||||
### From Express.js
|
||||
|
||||
```javascript
|
||||
// Use swagger-jsdoc
|
||||
/**
|
||||
* @swagger
|
||||
* /users:
|
||||
* get:
|
||||
* summary: List users
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Success
|
||||
*/
|
||||
app.get('/users', (req, res) => {
|
||||
// Handler
|
||||
});
|
||||
```
|
||||
|
||||
### From FastAPI (Python)
|
||||
|
||||
```python
|
||||
# FastAPI auto-generates OpenAPI
|
||||
@app.get("/users", response_model=List[User])
|
||||
async def list_users():
|
||||
return users
|
||||
```
|
||||
|
||||
### From ASP.NET Core
|
||||
|
||||
```csharp
|
||||
// Use Swashbuckle
|
||||
[HttpGet]
|
||||
[ProducesResponseType(typeof(List<User>), 200)]
|
||||
public IActionResult GetUsers()
|
||||
{
|
||||
return Ok(users);
|
||||
}
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
**Swagger Editor**: https://editor.swagger.io
|
||||
**Swagger UI**: Interactive documentation
|
||||
**Redoc**: Alternative documentation UI
|
||||
**Postman**: Import OpenAPI for testing
|
||||
|
||||
## Best Practices
|
||||
|
||||
**Use $ref for reusability:**
|
||||
- Define schemas once
|
||||
- Reference in multiple places
|
||||
- Easier maintenance
|
||||
|
||||
**Include examples:**
|
||||
- Help developers understand
|
||||
- Enable better testing
|
||||
- Show expected formats
|
||||
|
||||
**Document errors:**
|
||||
- All possible status codes
|
||||
- Error response format
|
||||
- Error codes and meanings
|
||||
|
||||
**Version your API:**
|
||||
- Include version in URL or header
|
||||
- Document breaking changes
|
||||
- Maintain old versions
|
||||
|
||||
**Keep it updated:**
|
||||
- Generate from code when possible
|
||||
- Review regularly
|
||||
- Update with API changes
|
||||
@@ -1,187 +0,0 @@
|
||||
---
|
||||
name: pytest
|
||||
description: |
|
||||
Python testing with pytest framework for unit, integration, and API tests.
|
||||
Use when: (1) Writing test cases for Python code, (2) Setting up pytest fixtures,
|
||||
(3) Testing async functions with pytest-asyncio, (4) Mocking dependencies,
|
||||
(5) Parameterizing tests, (6) Testing FastAPI/Flask endpoints, (7) Setting up test coverage,
|
||||
(8) Creating test factories with factory_boy, (9) Configuring CI/CD test pipelines.
|
||||
---
|
||||
|
||||
# Pytest Testing Skill
|
||||
|
||||
Comprehensive testing patterns for Python applications using pytest.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Feature | Reference File |
|
||||
|---------|----------------|
|
||||
| Fixtures, conftest, scopes | [references/fixtures.md](references/fixtures.md) |
|
||||
| Async testing, pytest-asyncio | [references/async-testing.md](references/async-testing.md) |
|
||||
| Mocking, patching, spies | [references/mocking.md](references/mocking.md) |
|
||||
| FastAPI/Flask endpoint testing | [references/api-testing.md](references/api-testing.md) |
|
||||
|
||||
## Dependencies
|
||||
|
||||
```toml
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"httpx>=0.28.0", # Async HTTP client for API tests
|
||||
"factory-boy>=3.3.0", # Test factories
|
||||
"faker>=33.0.0", # Fake data generation
|
||||
]
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### pyproject.toml
|
||||
|
||||
```toml
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_functions = ["test_*"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = ["ignore::DeprecationWarning"]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["app"]
|
||||
omit = ["*/tests/*", "*/__init__.py"]
|
||||
```
|
||||
|
||||
## Basic Test Structure
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
# Simple test function
|
||||
def test_addition():
|
||||
assert 1 + 1 == 2
|
||||
|
||||
# Test class (group related tests)
|
||||
class TestCalculator:
|
||||
def test_add(self):
|
||||
assert add(2, 3) == 5
|
||||
|
||||
def test_subtract(self):
|
||||
assert subtract(5, 3) == 2
|
||||
|
||||
# Expected exceptions
|
||||
def test_division_by_zero():
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
divide(1, 0)
|
||||
|
||||
# Parametrized tests
|
||||
@pytest.mark.parametrize("input,expected", [
|
||||
(1, 1),
|
||||
(2, 4),
|
||||
(3, 9),
|
||||
])
|
||||
def test_square(input, expected):
|
||||
assert square(input) == expected
|
||||
```
|
||||
|
||||
## Fixtures
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user():
|
||||
return {"name": "John", "email": "john@example.com"}
|
||||
|
||||
@pytest.fixture
|
||||
def db_connection():
|
||||
conn = create_connection()
|
||||
yield conn # Test runs here
|
||||
conn.close() # Cleanup after test
|
||||
|
||||
# Use fixture in test
|
||||
def test_user_name(sample_user):
|
||||
assert sample_user["name"] == "John"
|
||||
```
|
||||
|
||||
## Async Testing
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function():
|
||||
result = await async_operation()
|
||||
assert result == "success"
|
||||
|
||||
@pytest.fixture
|
||||
async def async_client():
|
||||
async with AsyncClient() as client:
|
||||
yield client
|
||||
```
|
||||
|
||||
## Mocking
|
||||
|
||||
```python
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
def test_with_mock():
|
||||
with patch("module.external_api") as mock_api:
|
||||
mock_api.return_value = {"data": "mocked"}
|
||||
result = function_using_api()
|
||||
assert result["data"] == "mocked"
|
||||
mock_api.assert_called_once()
|
||||
|
||||
# Async mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mock():
|
||||
with patch("module.async_call", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = "result"
|
||||
result = await function_with_async_call()
|
||||
assert result == "result"
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Verbose output
|
||||
pytest -v
|
||||
|
||||
# Run specific file
|
||||
pytest tests/test_users.py
|
||||
|
||||
# Run specific test
|
||||
pytest tests/test_users.py::test_create_user
|
||||
|
||||
# Run with coverage
|
||||
pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Stop on first failure
|
||||
pytest -x
|
||||
|
||||
# Run last failed tests
|
||||
pytest --lf
|
||||
|
||||
# Run tests matching pattern
|
||||
pytest -k "user and not delete"
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
project/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py
|
||||
│ └── services/
|
||||
├── tests/
|
||||
│ ├── __init__.py
|
||||
│ ├── conftest.py # Shared fixtures
|
||||
│ ├── test_main.py
|
||||
│ └── test_services/
|
||||
└── pyproject.toml
|
||||
```
|
||||
@@ -1,434 +0,0 @@
|
||||
# API Testing
|
||||
|
||||
## Table of Contents
|
||||
- [FastAPI Testing](#fastapi-testing)
|
||||
- [Flask Testing](#flask-testing)
|
||||
- [Test Factories](#test-factories)
|
||||
- [Authentication Testing](#authentication-testing)
|
||||
- [Database Testing](#database-testing)
|
||||
|
||||
## FastAPI Testing
|
||||
|
||||
### Setup
|
||||
|
||||
```python
|
||||
# tests/conftest.py
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_session
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
|
||||
|
||||
test_engine = create_async_engine(TEST_DATABASE_URL)
|
||||
TestSession = async_sessionmaker(test_engine, class_=AsyncSession)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_session():
|
||||
"""Fresh database for each test."""
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async with TestSession() as session:
|
||||
yield session
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_session):
|
||||
"""Test client with dependency override."""
|
||||
async def override_session():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = override_session
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://test",
|
||||
) as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
```
|
||||
|
||||
### Basic Endpoint Tests
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
class TestUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/users",
|
||||
json={"email": "test@example.com", "name": "Test User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["email"] == "test@example.com"
|
||||
assert "id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user(self, client: AsyncClient, test_user):
|
||||
response = await client.get(f"/api/users/{test_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found(self, client: AsyncClient):
|
||||
response = await client.get("/api/users/99999")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user(self, client: AsyncClient, test_user, auth_headers):
|
||||
response = await client.patch(
|
||||
f"/api/users/{test_user.id}",
|
||||
json={"name": "Updated Name"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user(self, client: AsyncClient, test_user, admin_headers):
|
||||
response = await client.delete(
|
||||
f"/api/users/{test_user.id}",
|
||||
headers=admin_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
```
|
||||
|
||||
### Testing Query Parameters
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_pagination(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get(
|
||||
"/api/users",
|
||||
params={"page": 1, "per_page": 10, "sort": "-created_at"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "items" in data
|
||||
assert "total" in data
|
||||
assert data["page"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_users(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get(
|
||||
"/api/users",
|
||||
params={"q": "john", "active": True},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
### Testing File Uploads
|
||||
|
||||
```python
|
||||
import io
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file(self, client: AsyncClient, auth_headers):
|
||||
file_content = b"Hello, World!"
|
||||
files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")}
|
||||
|
||||
response = await client.post(
|
||||
"/api/files/upload",
|
||||
files=files,
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["filename"] == "test.txt"
|
||||
```
|
||||
|
||||
## Flask Testing
|
||||
|
||||
### Setup
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create application for testing."""
|
||||
app = create_app("testing")
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
yield app
|
||||
db.drop_all()
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Test client."""
|
||||
return app.test_client()
|
||||
|
||||
@pytest.fixture
|
||||
def runner(app):
|
||||
"""CLI test runner."""
|
||||
return app.test_cli_runner()
|
||||
```
|
||||
|
||||
### Flask Test Examples
|
||||
|
||||
```python
|
||||
def test_create_user(client):
|
||||
response = client.post(
|
||||
"/api/users",
|
||||
json={"email": "test@example.com", "name": "Test"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_get_user(client, test_user):
|
||||
response = client.get(f"/api/users/{test_user.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["email"] == test_user.email
|
||||
|
||||
def test_protected_route(client, auth_token):
|
||||
response = client.get(
|
||||
"/api/protected",
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
## Test Factories
|
||||
|
||||
### Using factory_boy
|
||||
|
||||
```python
|
||||
# tests/factories.py
|
||||
import factory
|
||||
from faker import Faker
|
||||
from app.models import User, Project
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
class UserFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = User
|
||||
|
||||
email = factory.LazyAttribute(lambda _: fake.email())
|
||||
name = factory.LazyAttribute(lambda _: fake.name())
|
||||
hashed_password = "hashed_password"
|
||||
is_active = True
|
||||
|
||||
|
||||
class ProjectFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = Project
|
||||
|
||||
name = factory.LazyAttribute(lambda _: fake.company())
|
||||
description = factory.LazyAttribute(lambda _: fake.sentence())
|
||||
owner_id = None
|
||||
|
||||
|
||||
# Async helper
|
||||
async def create_user(session, **kwargs) -> User:
|
||||
user = UserFactory(**kwargs)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
async def create_users(session, count: int = 5, **kwargs) -> list[User]:
|
||||
users = [UserFactory(**kwargs) for _ in range(count)]
|
||||
session.add_all(users)
|
||||
await session.commit()
|
||||
return users
|
||||
```
|
||||
|
||||
### Using Factories in Tests
|
||||
|
||||
```python
|
||||
from tests.factories import create_user, create_users
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(client, db_session, auth_headers):
|
||||
# Create 20 users
|
||||
await create_users(db_session, count=20)
|
||||
|
||||
response = await client.get("/api/users", headers=auth_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["total"] >= 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_active_users(client, db_session, auth_headers):
|
||||
await create_user(db_session, is_active=True)
|
||||
await create_user(db_session, is_active=False)
|
||||
|
||||
response = await client.get(
|
||||
"/api/users",
|
||||
params={"active": True},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert all(u["is_active"] for u in data["items"])
|
||||
```
|
||||
|
||||
## Authentication Testing
|
||||
|
||||
### Auth Fixtures
|
||||
|
||||
```python
|
||||
from app.core.security import create_access_token, hash_password
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session):
|
||||
"""Create a regular test user."""
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("password123"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
async def admin_user(db_session):
|
||||
"""Create an admin user."""
|
||||
user = User(
|
||||
email="admin@example.com",
|
||||
hashed_password=hash_password("admin123"),
|
||||
is_active=True,
|
||||
is_admin=True,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers(test_user):
|
||||
"""Headers with valid auth token."""
|
||||
token = create_access_token({"sub": str(test_user.id)})
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
@pytest.fixture
|
||||
def admin_headers(admin_user):
|
||||
"""Headers with admin auth token."""
|
||||
token = create_access_token({"sub": str(admin_user.id)})
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
```
|
||||
|
||||
### Auth Test Examples
|
||||
|
||||
```python
|
||||
class TestAuth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(self, client, test_user):
|
||||
response = await client.post(
|
||||
"/api/auth/login",
|
||||
data={"username": test_user.email, "password": "password123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(self, client, test_user):
|
||||
response = await client.post(
|
||||
"/api/auth/login",
|
||||
data={"username": test_user.email, "password": "wrong"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_protected_endpoint_no_auth(self, client):
|
||||
response = await client.get("/api/protected")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_endpoint_regular_user(self, client, auth_headers):
|
||||
response = await client.delete(
|
||||
"/api/admin/users/1",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
```
|
||||
|
||||
## Database Testing
|
||||
|
||||
### Transaction Rollback Pattern
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def db_session(engine):
|
||||
"""Each test runs in a transaction that rolls back."""
|
||||
async with engine.connect() as conn:
|
||||
await conn.begin()
|
||||
|
||||
async with AsyncSession(bind=conn) as session:
|
||||
yield session
|
||||
|
||||
await conn.rollback()
|
||||
```
|
||||
|
||||
### Testing Repository Layer
|
||||
|
||||
```python
|
||||
from app.repositories import UserRepository
|
||||
|
||||
class TestUserRepository:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create(self, db_session):
|
||||
repo = UserRepository(db_session)
|
||||
|
||||
user = await repo.create({
|
||||
"email": "test@example.com",
|
||||
"hashed_password": "hash",
|
||||
})
|
||||
|
||||
assert user.id is not None
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email(self, db_session, test_user):
|
||||
repo = UserRepository(db_session)
|
||||
|
||||
user = await repo.get_by_email(test_user.email)
|
||||
|
||||
assert user.id == test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete(self, db_session, test_user):
|
||||
repo = UserRepository(db_session)
|
||||
|
||||
await repo.soft_delete(test_user.id)
|
||||
|
||||
user = await repo.get_by_id(test_user.id)
|
||||
assert user.status == 0 # Soft deleted
|
||||
```
|
||||
@@ -1,290 +0,0 @@
|
||||
# Async Testing
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Basic Async Tests](#basic-async-tests)
|
||||
- [Async Fixtures](#async-fixtures)
|
||||
- [Testing Async Generators](#testing-async-generators)
|
||||
- [Timeouts](#timeouts)
|
||||
|
||||
## Setup
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install pytest-asyncio
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto" # Automatically handle async tests
|
||||
# OR
|
||||
asyncio_mode = "strict" # Require explicit @pytest.mark.asyncio
|
||||
```
|
||||
|
||||
### With pytest.ini
|
||||
|
||||
```ini
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
```
|
||||
|
||||
## Basic Async Tests
|
||||
|
||||
### Simple Async Test
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
# With asyncio_mode = "auto", decorator is optional
|
||||
async def test_async_function():
|
||||
result = await async_operation()
|
||||
assert result == "success"
|
||||
|
||||
# With asyncio_mode = "strict", decorator is required
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_with_marker():
|
||||
result = await async_operation()
|
||||
assert result == "success"
|
||||
```
|
||||
|
||||
### Async Test Class
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
class TestAsyncOperations:
|
||||
async def test_fetch_data(self):
|
||||
data = await fetch_data()
|
||||
assert data is not None
|
||||
|
||||
async def test_process_data(self):
|
||||
result = await process_data({"key": "value"})
|
||||
assert result["processed"] is True
|
||||
```
|
||||
|
||||
### Testing Async Exceptions
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_exception():
|
||||
with pytest.raises(ValueError):
|
||||
await async_function_that_raises()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_exception_message():
|
||||
with pytest.raises(ValueError, match="Invalid input"):
|
||||
await validate_input("")
|
||||
```
|
||||
|
||||
## Async Fixtures
|
||||
|
||||
### Basic Async Fixture
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
async def async_client():
|
||||
"""Async fixture for HTTP client."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
yield client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call(async_client):
|
||||
response = await async_client.get("https://api.example.com/data")
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
### Async Database Fixture
|
||||
|
||||
```python
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for session-scoped async fixtures."""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def engine():
|
||||
"""Session-scoped async engine."""
|
||||
engine = create_async_engine("postgresql+asyncpg://...")
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(engine):
|
||||
"""Function-scoped database session with rollback."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
```
|
||||
|
||||
### Fixture with Async Setup/Teardown
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def resource():
|
||||
# Async setup
|
||||
resource = await create_resource()
|
||||
await resource.initialize()
|
||||
|
||||
yield resource
|
||||
|
||||
# Async teardown
|
||||
await resource.cleanup()
|
||||
await resource.close()
|
||||
```
|
||||
|
||||
## Testing Async Generators
|
||||
|
||||
### Async Generator Function
|
||||
|
||||
```python
|
||||
async def async_data_generator():
|
||||
for i in range(5):
|
||||
await asyncio.sleep(0.1)
|
||||
yield i
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generator():
|
||||
results = []
|
||||
async for item in async_data_generator():
|
||||
results.append(item)
|
||||
|
||||
assert results == [0, 1, 2, 3, 4]
|
||||
```
|
||||
|
||||
### Async Context Manager
|
||||
|
||||
```python
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_resource():
|
||||
resource = await create_resource()
|
||||
try:
|
||||
yield resource
|
||||
finally:
|
||||
await resource.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager():
|
||||
async with async_resource() as resource:
|
||||
result = await resource.operation()
|
||||
assert result is not None
|
||||
```
|
||||
|
||||
## Timeouts
|
||||
|
||||
### Test Timeout
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(5) # 5 second timeout
|
||||
async def test_slow_operation():
|
||||
result = await potentially_slow_operation()
|
||||
assert result is not None
|
||||
```
|
||||
|
||||
### Using asyncio.wait_for
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_timeout():
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
slow_operation(),
|
||||
timeout=2.0
|
||||
)
|
||||
assert result is not None
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Operation timed out")
|
||||
```
|
||||
|
||||
### Testing Timeout Behavior
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_is_raised():
|
||||
"""Verify function properly times out."""
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
never_completes(),
|
||||
timeout=0.1
|
||||
)
|
||||
```
|
||||
|
||||
## Concurrent Tests
|
||||
|
||||
### Testing Concurrent Operations
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests():
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Run 10 requests concurrently
|
||||
tasks = [
|
||||
client.get(f"https://api.example.com/items/{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
assert all(r.status_code == 200 for r in responses)
|
||||
```
|
||||
|
||||
### Testing Race Conditions
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_counter_thread_safety():
|
||||
counter = AsyncCounter()
|
||||
|
||||
async def increment():
|
||||
for _ in range(100):
|
||||
await counter.increment()
|
||||
|
||||
# Run 10 concurrent incrementers
|
||||
await asyncio.gather(*[increment() for _ in range(10)])
|
||||
|
||||
assert counter.value == 1000
|
||||
```
|
||||
|
||||
## Event Loop Configuration
|
||||
|
||||
### Custom Event Loop
|
||||
|
||||
```python
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop_policy():
|
||||
"""Use uvloop for faster async tests."""
|
||||
import uvloop
|
||||
return uvloop.EventLoopPolicy()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop(event_loop_policy):
|
||||
asyncio.set_event_loop_policy(event_loop_policy)
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
```
|
||||
@@ -1,296 +0,0 @@
|
||||
# Fixtures
|
||||
|
||||
## Table of Contents
|
||||
- [Basic Fixtures](#basic-fixtures)
|
||||
- [Fixture Scopes](#fixture-scopes)
|
||||
- [Fixture Parameters](#fixture-parameters)
|
||||
- [Conftest.py](#conftestpy)
|
||||
- [Built-in Fixtures](#built-in-fixtures)
|
||||
|
||||
## Basic Fixtures
|
||||
|
||||
### Simple Fixture
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def user_data():
|
||||
"""Return sample user data."""
|
||||
return {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": 30,
|
||||
}
|
||||
|
||||
def test_user_name(user_data):
|
||||
assert user_data["name"] == "John Doe"
|
||||
|
||||
def test_user_email(user_data):
|
||||
assert "@" in user_data["email"]
|
||||
```
|
||||
|
||||
### Fixture with Setup and Teardown
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def database():
|
||||
"""Setup database, yield, then cleanup."""
|
||||
# Setup
|
||||
db = create_database()
|
||||
db.connect()
|
||||
|
||||
yield db # Test runs here
|
||||
|
||||
# Teardown
|
||||
db.clear()
|
||||
db.disconnect()
|
||||
|
||||
def test_insert(database):
|
||||
database.insert({"id": 1, "name": "Test"})
|
||||
assert database.count() == 1
|
||||
```
|
||||
|
||||
### Fixture Returning Factory
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def make_user():
|
||||
"""Return a factory function for creating users."""
|
||||
created_users = []
|
||||
|
||||
def _make_user(name: str, email: str = None):
|
||||
user = User(name=name, email=email or f"{name}@example.com")
|
||||
created_users.append(user)
|
||||
return user
|
||||
|
||||
yield _make_user
|
||||
|
||||
# Cleanup all created users
|
||||
for user in created_users:
|
||||
user.delete()
|
||||
|
||||
def test_multiple_users(make_user):
|
||||
user1 = make_user("Alice")
|
||||
user2 = make_user("Bob")
|
||||
assert user1.name != user2.name
|
||||
```
|
||||
|
||||
## Fixture Scopes
|
||||
|
||||
```python
|
||||
# Function scope (default) - runs for each test
|
||||
@pytest.fixture(scope="function")
|
||||
def fresh_data():
|
||||
return {"count": 0}
|
||||
|
||||
# Class scope - runs once per test class
|
||||
@pytest.fixture(scope="class")
|
||||
def class_resource():
|
||||
return ExpensiveResource()
|
||||
|
||||
# Module scope - runs once per test module
|
||||
@pytest.fixture(scope="module")
|
||||
def module_connection():
|
||||
conn = create_connection()
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
# Session scope - runs once per test session
|
||||
@pytest.fixture(scope="session")
|
||||
def session_config():
|
||||
return load_config()
|
||||
```
|
||||
|
||||
### Scope Hierarchy
|
||||
|
||||
```
|
||||
session (once per test run)
|
||||
└── package (once per test package)
|
||||
└── module (once per test file)
|
||||
└── class (once per test class)
|
||||
└── function (once per test function)
|
||||
```
|
||||
|
||||
## Fixture Parameters
|
||||
|
||||
### Parametrized Fixtures
|
||||
|
||||
```python
|
||||
@pytest.fixture(params=["sqlite", "postgresql", "mysql"])
|
||||
def database_type(request):
|
||||
"""Run tests with each database type."""
|
||||
return request.param
|
||||
|
||||
def test_connection(database_type):
|
||||
# This test runs 3 times, once for each database
|
||||
db = create_database(database_type)
|
||||
assert db.connect()
|
||||
```
|
||||
|
||||
### Fixture with IDs
|
||||
|
||||
```python
|
||||
@pytest.fixture(params=[
|
||||
pytest.param({"admin": True}, id="admin"),
|
||||
pytest.param({"admin": False}, id="regular"),
|
||||
])
|
||||
def user_config(request):
|
||||
return request.param
|
||||
|
||||
# Test output shows: test_permissions[admin], test_permissions[regular]
|
||||
```
|
||||
|
||||
### Indirect Parametrization
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def user(request):
|
||||
"""Create user based on parameter."""
|
||||
role = request.param
|
||||
return User(role=role)
|
||||
|
||||
@pytest.mark.parametrize("user", ["admin", "editor", "viewer"], indirect=True)
|
||||
def test_user_access(user):
|
||||
# user fixture receives each role as request.param
|
||||
assert user.role in ["admin", "editor", "viewer"]
|
||||
```
|
||||
|
||||
## Conftest.py
|
||||
|
||||
### tests/conftest.py
|
||||
|
||||
```python
|
||||
"""Shared fixtures for all tests."""
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Session-scoped database engine
|
||||
@pytest.fixture(scope="session")
|
||||
def engine():
|
||||
return create_engine("sqlite:///:memory:")
|
||||
|
||||
# Function-scoped session with transaction rollback
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session(engine):
|
||||
connection = engine.connect()
|
||||
transaction = connection.begin()
|
||||
Session = sessionmaker(bind=connection)
|
||||
session = Session()
|
||||
|
||||
yield session
|
||||
|
||||
session.close()
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
|
||||
# Shared test data
|
||||
@pytest.fixture
|
||||
def sample_products():
|
||||
return [
|
||||
{"name": "Widget", "price": 9.99},
|
||||
{"name": "Gadget", "price": 19.99},
|
||||
{"name": "Gizmo", "price": 29.99},
|
||||
]
|
||||
```
|
||||
|
||||
### Nested Conftest Files
|
||||
|
||||
```
|
||||
tests/
|
||||
├── conftest.py # Available to all tests
|
||||
├── unit/
|
||||
│ ├── conftest.py # Available to unit tests only
|
||||
│ └── test_models.py
|
||||
└── integration/
|
||||
├── conftest.py # Available to integration tests only
|
||||
└── test_api.py
|
||||
```
|
||||
|
||||
## Built-in Fixtures
|
||||
|
||||
### tmp_path / tmp_path_factory
|
||||
|
||||
```python
|
||||
def test_create_file(tmp_path):
|
||||
"""tmp_path provides unique temporary directory."""
|
||||
file = tmp_path / "test.txt"
|
||||
file.write_text("Hello, World!")
|
||||
assert file.read_text() == "Hello, World!"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def session_temp_dir(tmp_path_factory):
|
||||
"""Create temp dir for entire session."""
|
||||
return tmp_path_factory.mktemp("session_data")
|
||||
```
|
||||
|
||||
### capsys / capfd
|
||||
|
||||
```python
|
||||
def test_print_output(capsys):
|
||||
"""Capture stdout/stderr."""
|
||||
print("Hello")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "Hello\n"
|
||||
|
||||
def test_file_descriptor_output(capfd):
|
||||
"""Capture at file descriptor level."""
|
||||
import os
|
||||
os.system("echo 'Hello'")
|
||||
captured = capfd.readouterr()
|
||||
assert "Hello" in captured.out
|
||||
```
|
||||
|
||||
### monkeypatch
|
||||
|
||||
```python
|
||||
def test_env_variable(monkeypatch):
|
||||
"""Modify environment temporarily."""
|
||||
monkeypatch.setenv("API_KEY", "test-key")
|
||||
assert os.environ["API_KEY"] == "test-key"
|
||||
|
||||
def test_module_attribute(monkeypatch):
|
||||
"""Modify module attribute temporarily."""
|
||||
monkeypatch.setattr("module.CONFIG", {"debug": True})
|
||||
assert module.CONFIG["debug"] is True
|
||||
|
||||
def test_dict_item(monkeypatch):
|
||||
"""Modify dictionary temporarily."""
|
||||
monkeypatch.setitem(app.settings, "DEBUG", True)
|
||||
```
|
||||
|
||||
### request
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def resource(request):
|
||||
"""Access test context."""
|
||||
print(f"Running: {request.node.name}")
|
||||
print(f"Module: {request.module.__name__}")
|
||||
print(f"Function: {request.function.__name__}")
|
||||
|
||||
# Access fixture parameters
|
||||
if hasattr(request, "param"):
|
||||
return create_resource(request.param)
|
||||
|
||||
return create_default_resource()
|
||||
```
|
||||
|
||||
### Autouse Fixtures
|
||||
|
||||
```python
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_logging():
|
||||
"""Automatically runs for every test."""
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
yield
|
||||
logging.shutdown()
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def global_setup():
|
||||
"""Run once at session start."""
|
||||
initialize_system()
|
||||
yield
|
||||
cleanup_system()
|
||||
```
|
||||
@@ -1,354 +0,0 @@
|
||||
# Mocking
|
||||
|
||||
## Table of Contents
|
||||
- [Basic Mocking](#basic-mocking)
|
||||
- [Patching](#patching)
|
||||
- [Mock Objects](#mock-objects)
|
||||
- [Async Mocking](#async-mocking)
|
||||
- [Fixture-Based Mocking](#fixture-based-mocking)
|
||||
- [Common Patterns](#common-patterns)
|
||||
|
||||
## Basic Mocking
|
||||
|
||||
### MagicMock Basics
|
||||
|
||||
```python
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
def test_with_mock():
|
||||
# Create a mock object
|
||||
mock_service = MagicMock()
|
||||
|
||||
# Configure return value
|
||||
mock_service.get_data.return_value = {"id": 1, "name": "Test"}
|
||||
|
||||
# Use mock
|
||||
result = mock_service.get_data()
|
||||
assert result["name"] == "Test"
|
||||
|
||||
# Verify call
|
||||
mock_service.get_data.assert_called_once()
|
||||
```
|
||||
|
||||
### Return Value Configuration
|
||||
|
||||
```python
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock = MagicMock()
|
||||
|
||||
# Simple return value
|
||||
mock.method.return_value = 42
|
||||
|
||||
# Different returns on consecutive calls
|
||||
mock.method.side_effect = [1, 2, 3]
|
||||
assert mock.method() == 1
|
||||
assert mock.method() == 2
|
||||
assert mock.method() == 3
|
||||
|
||||
# Raise exception
|
||||
mock.method.side_effect = ValueError("Error!")
|
||||
|
||||
# Dynamic return value
|
||||
mock.method.side_effect = lambda x: x * 2
|
||||
assert mock.method(5) == 10
|
||||
```
|
||||
|
||||
## Patching
|
||||
|
||||
### patch Decorator
|
||||
|
||||
```python
|
||||
from unittest.mock import patch
|
||||
|
||||
# Patch a function
|
||||
@patch("mymodule.external_api")
|
||||
def test_with_patched_api(mock_api):
|
||||
mock_api.return_value = {"status": "ok"}
|
||||
result = mymodule.call_api()
|
||||
assert result["status"] == "ok"
|
||||
|
||||
# Patch multiple
|
||||
@patch("mymodule.function_a")
|
||||
@patch("mymodule.function_b")
|
||||
def test_multiple_patches(mock_b, mock_a):
|
||||
# Note: decorators apply bottom-up
|
||||
mock_a.return_value = "a"
|
||||
mock_b.return_value = "b"
|
||||
```
|
||||
|
||||
### patch Context Manager
|
||||
|
||||
```python
|
||||
from unittest.mock import patch
|
||||
|
||||
def test_with_context_manager():
|
||||
with patch("mymodule.external_api") as mock_api:
|
||||
mock_api.return_value = {"data": "mocked"}
|
||||
result = mymodule.process()
|
||||
assert result == {"data": "mocked"}
|
||||
# Original function restored after with block
|
||||
```
|
||||
|
||||
### patch.object
|
||||
|
||||
```python
|
||||
from unittest.mock import patch
|
||||
|
||||
class MyService:
|
||||
def fetch_data(self):
|
||||
return "real data"
|
||||
|
||||
def test_patch_instance_method():
|
||||
service = MyService()
|
||||
|
||||
with patch.object(service, "fetch_data", return_value="mocked"):
|
||||
assert service.fetch_data() == "mocked"
|
||||
```
|
||||
|
||||
### Patching Where Used
|
||||
|
||||
```python
|
||||
# mymodule.py
|
||||
from requests import get # 'get' is imported here
|
||||
|
||||
def fetch_url(url):
|
||||
return get(url).text
|
||||
|
||||
# test_mymodule.py
|
||||
# Patch where it's USED, not where it's DEFINED
|
||||
@patch("mymodule.get") # NOT "requests.get"
|
||||
def test_fetch_url(mock_get):
|
||||
mock_get.return_value.text = "mocked response"
|
||||
result = fetch_url("http://example.com")
|
||||
assert result == "mocked response"
|
||||
```
|
||||
|
||||
## Mock Objects
|
||||
|
||||
### Spec and Autospec
|
||||
|
||||
```python
|
||||
from unittest.mock import MagicMock, create_autospec
|
||||
|
||||
class UserService:
|
||||
def get_user(self, user_id: int) -> dict:
|
||||
pass
|
||||
|
||||
def create_user(self, data: dict) -> dict:
|
||||
pass
|
||||
|
||||
# Basic mock (allows any attribute)
|
||||
mock = MagicMock()
|
||||
mock.nonexistent_method() # Works, but shouldn't
|
||||
|
||||
# Spec mock (restricts to real attributes)
|
||||
mock = MagicMock(spec=UserService)
|
||||
# mock.nonexistent_method() # Raises AttributeError
|
||||
|
||||
# Autospec (also checks signatures)
|
||||
mock = create_autospec(UserService)
|
||||
# mock.get_user() # Raises TypeError (missing user_id)
|
||||
mock.get_user(123) # Works
|
||||
```
|
||||
|
||||
### Assertion Methods
|
||||
|
||||
```python
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
mock = MagicMock()
|
||||
mock.method(1, 2, key="value")
|
||||
mock.method(3, 4)
|
||||
|
||||
# Verify calls
|
||||
mock.method.assert_called()
|
||||
mock.method.assert_called_once() # Fails - called twice
|
||||
mock.method.assert_called_with(3, 4)
|
||||
mock.method.assert_any_call(1, 2, key="value")
|
||||
|
||||
# Check call count
|
||||
assert mock.method.call_count == 2
|
||||
|
||||
# Check all calls
|
||||
assert mock.method.call_args_list == [
|
||||
call(1, 2, key="value"),
|
||||
call(3, 4),
|
||||
]
|
||||
|
||||
# Reset mock
|
||||
mock.reset_mock()
|
||||
assert mock.method.call_count == 0
|
||||
```
|
||||
|
||||
## Async Mocking
|
||||
|
||||
### AsyncMock
|
||||
|
||||
```python
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mock():
|
||||
mock = AsyncMock(return_value={"data": "mocked"})
|
||||
result = await mock()
|
||||
assert result["data"] == "mocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("mymodule.async_fetch", new_callable=AsyncMock)
|
||||
async def test_patched_async(mock_fetch):
|
||||
mock_fetch.return_value = {"status": "ok"}
|
||||
result = await mymodule.process()
|
||||
assert result["status"] == "ok"
|
||||
```
|
||||
|
||||
### Async Side Effects
|
||||
|
||||
```python
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_side_effect():
|
||||
mock = AsyncMock()
|
||||
|
||||
# Return different values
|
||||
mock.side_effect = [1, 2, 3]
|
||||
assert await mock() == 1
|
||||
assert await mock() == 2
|
||||
|
||||
# Async function as side effect
|
||||
async def async_side_effect(x):
|
||||
return x * 2
|
||||
|
||||
mock.side_effect = async_side_effect
|
||||
assert await mock(5) == 10
|
||||
```
|
||||
|
||||
## Fixture-Based Mocking
|
||||
|
||||
### Mock Fixtures
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database():
|
||||
"""Fixture providing mock database."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value = [{"id": 1}, {"id": 2}]
|
||||
return mock_db
|
||||
|
||||
def test_with_mock_fixture(mock_database):
|
||||
result = mock_database.query("SELECT * FROM users")
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_api():
|
||||
"""Fixture with patching."""
|
||||
with patch("mymodule.external_api") as mock:
|
||||
mock.return_value = {"status": "ok"}
|
||||
yield mock
|
||||
```
|
||||
|
||||
### Monkeypatch Fixture
|
||||
|
||||
```python
|
||||
def test_with_monkeypatch(monkeypatch):
|
||||
# Patch function
|
||||
monkeypatch.setattr("mymodule.get_config", lambda: {"debug": True})
|
||||
|
||||
# Patch environment variable
|
||||
monkeypatch.setenv("API_KEY", "test-key")
|
||||
|
||||
# Patch dictionary item
|
||||
monkeypatch.setitem(mymodule.settings, "DEBUG", True)
|
||||
|
||||
# Patch attribute
|
||||
monkeypatch.setattr(mymodule.client, "timeout", 5)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Mock HTTP Responses
|
||||
|
||||
```python
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
@patch("requests.get")
|
||||
def test_http_request(mock_get):
|
||||
# Configure mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"data": "test"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = fetch_data("http://api.example.com")
|
||||
|
||||
assert result == {"data": "test"}
|
||||
mock_get.assert_called_once_with("http://api.example.com")
|
||||
```
|
||||
|
||||
### Mock File Operations
|
||||
|
||||
```python
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
def test_read_file():
|
||||
mock_file_content = "Hello, World!"
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=mock_file_content)):
|
||||
result = read_file("test.txt")
|
||||
assert result == "Hello, World!"
|
||||
|
||||
def test_write_file():
|
||||
m = mock_open()
|
||||
|
||||
with patch("builtins.open", m):
|
||||
write_file("test.txt", "content")
|
||||
|
||||
m().write.assert_called_once_with("content")
|
||||
```
|
||||
|
||||
### Mock Datetime
|
||||
|
||||
```python
|
||||
from unittest.mock import patch
|
||||
from datetime import datetime
|
||||
|
||||
@patch("mymodule.datetime")
|
||||
def test_with_frozen_time(mock_datetime):
|
||||
mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 0, 0)
|
||||
mock_datetime.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs)
|
||||
|
||||
result = mymodule.get_timestamp()
|
||||
assert result == "2024-01-15 12:00:00"
|
||||
```
|
||||
|
||||
### Mock Class Instance
|
||||
|
||||
```python
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
class EmailService:
|
||||
def send(self, to, subject, body):
|
||||
# Real implementation
|
||||
pass
|
||||
|
||||
@patch("mymodule.EmailService")
|
||||
def test_email_sent(MockEmailService):
|
||||
# Configure mock instance
|
||||
mock_instance = MagicMock()
|
||||
MockEmailService.return_value = mock_instance
|
||||
|
||||
# Call function that uses EmailService
|
||||
send_notification("user@example.com", "Hello!")
|
||||
|
||||
# Verify email was sent
|
||||
mock_instance.send.assert_called_once_with(
|
||||
"user@example.com",
|
||||
"Notification",
|
||||
"Hello!",
|
||||
)
|
||||
```
|
||||
@@ -1,399 +0,0 @@
|
||||
---
|
||||
name: write-unit-tests
|
||||
description: Writing unit and integration tests for the tldraw SDK. Use when creating new tests, adding test coverage, or fixing failing tests in packages/editor or packages/tldraw. Covers Vitest patterns, TestEditor usage, and test file organization.
|
||||
---
|
||||
|
||||
# Writing tests
|
||||
|
||||
Unit and integration tests use Vitest. Tests run from workspace directories, not the repo root.
|
||||
|
||||
## Test file locations
|
||||
|
||||
**Unit tests** - alongside source files:
|
||||
|
||||
```
|
||||
packages/editor/src/lib/primitives/Vec.ts
|
||||
packages/editor/src/lib/primitives/Vec.test.ts # Same directory
|
||||
```
|
||||
|
||||
**Integration tests** - in `src/test/` directory:
|
||||
|
||||
```
|
||||
packages/tldraw/src/test/SelectTool.test.ts
|
||||
packages/tldraw/src/test/commands/createShape.test.ts
|
||||
```
|
||||
|
||||
**Shape/tool tests** - alongside the implementation:
|
||||
|
||||
```
|
||||
packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts
|
||||
packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts
|
||||
```
|
||||
|
||||
## Which workspace to test in
|
||||
|
||||
- **packages/editor**: Core primitives, geometry, managers, base editor functionality
|
||||
- **packages/tldraw**: Anything needing default shapes/tools (most integration tests)
|
||||
|
||||
```bash
|
||||
cd packages/tldraw && yarn test run
|
||||
cd packages/tldraw && yarn test run --grep "SelectTool"
|
||||
```
|
||||
|
||||
## TestEditor vs Editor
|
||||
|
||||
Use `TestEditor` for integration tests (includes default shapes/tools):
|
||||
|
||||
```typescript
|
||||
import { createShapeId } from '@tldraw/editor'
|
||||
import { TestEditor } from './TestEditor'
|
||||
|
||||
let editor: TestEditor
|
||||
|
||||
beforeEach(() => {
|
||||
editor = new TestEditor()
|
||||
editor.selectAll().deleteShapes(editor.getSelectedShapeIds())
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
editor?.dispose()
|
||||
})
|
||||
```
|
||||
|
||||
Use raw `Editor` when testing editor setup or custom configurations:
|
||||
|
||||
```typescript
|
||||
import { Editor, createTLStore } from '@tldraw/editor'
|
||||
|
||||
beforeEach(() => {
|
||||
editor = new Editor({
|
||||
shapeUtils: [CustomShape],
|
||||
bindingUtils: [],
|
||||
tools: [CustomTool],
|
||||
store: createTLStore({ shapeUtils: [CustomShape], bindingUtils: [] }),
|
||||
getContainer: () => document.body,
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Common TestEditor methods
|
||||
|
||||
```typescript
|
||||
// Pointer simulation
|
||||
editor.pointerDown(x, y, options?)
|
||||
editor.pointerMove(x, y, options?)
|
||||
editor.pointerUp(x, y, options?)
|
||||
editor.click(x, y, shapeId?)
|
||||
editor.doubleClick(x, y, shapeId?)
|
||||
|
||||
// Keyboard simulation
|
||||
editor.keyDown(key, options?)
|
||||
editor.keyUp(key, options?)
|
||||
|
||||
// State assertions
|
||||
editor.expectToBeIn('select.idle')
|
||||
editor.expectToBeIn('select.crop.idle')
|
||||
|
||||
// Shape assertions
|
||||
editor.expectShapeToMatch({ id, x, y, props: { ... } })
|
||||
|
||||
// Shape operations
|
||||
editor.createShapes([{ id, type, x, y, props }])
|
||||
editor.updateShapes([{ id, type, props }])
|
||||
editor.getShape(id)
|
||||
editor.select(id1, id2)
|
||||
editor.selectAll()
|
||||
editor.selectNone()
|
||||
editor.getSelectedShapeIds()
|
||||
editor.getOnlySelectedShape()
|
||||
|
||||
// Tool operations
|
||||
editor.setCurrentTool('arrow')
|
||||
editor.getCurrentToolId()
|
||||
|
||||
// Undo/redo
|
||||
editor.undo()
|
||||
editor.redo()
|
||||
```
|
||||
|
||||
## Pointer event options
|
||||
|
||||
```typescript
|
||||
editor.pointerDown(100, 100, {
|
||||
target: 'shape', // 'canvas' | 'shape' | 'handle' | 'selection'
|
||||
shape: editor.getShape(id),
|
||||
})
|
||||
|
||||
editor.pointerDown(150, 300, {
|
||||
target: 'selection',
|
||||
handle: 'bottom', // 'top' | 'bottom' | 'left' | 'right' | corners
|
||||
})
|
||||
|
||||
editor.doubleClick(550, 550, {
|
||||
target: 'selection',
|
||||
handle: 'bottom_right',
|
||||
})
|
||||
```
|
||||
|
||||
## Setup patterns
|
||||
|
||||
### Standard setup with shape IDs
|
||||
|
||||
```typescript
|
||||
const ids = {
|
||||
box1: createShapeId('box1'),
|
||||
box2: createShapeId('box2'),
|
||||
arrow1: createShapeId('arrow1'),
|
||||
}
|
||||
|
||||
vi.useFakeTimers()
|
||||
|
||||
beforeEach(() => {
|
||||
editor = new TestEditor()
|
||||
editor.selectAll().deleteShapes(editor.getSelectedShapeIds())
|
||||
editor.createShapes([
|
||||
{ id: ids.box1, type: 'geo', x: 100, y: 100, props: { w: 100, h: 100 } },
|
||||
{ id: ids.box2, type: 'geo', x: 300, y: 300, props: { w: 100, h: 100 } },
|
||||
])
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
editor?.dispose()
|
||||
})
|
||||
```
|
||||
|
||||
### Reusable props
|
||||
|
||||
```typescript
|
||||
const imageProps = {
|
||||
assetId: null,
|
||||
playing: true,
|
||||
url: '',
|
||||
w: 1200,
|
||||
h: 800,
|
||||
}
|
||||
|
||||
editor.createShapes([
|
||||
{ id: ids.imageA, type: 'image', x: 100, y: 100, props: imageProps },
|
||||
{ id: ids.imageB, type: 'image', x: 500, y: 500, props: { ...imageProps, w: 600, h: 400 } },
|
||||
])
|
||||
```
|
||||
|
||||
### Helper functions
|
||||
|
||||
```typescript
|
||||
function arrow(id = ids.arrow1) {
|
||||
return editor.getShape(id) as TLArrowShape
|
||||
}
|
||||
|
||||
function bindings(id = ids.arrow1) {
|
||||
return getArrowBindings(editor, arrow(id))
|
||||
}
|
||||
```
|
||||
|
||||
## Mocking with vi.spyOn
|
||||
|
||||
```typescript
|
||||
// Mock return value
|
||||
vi.spyOn(editor, 'getIsReadonly').mockReturnValue(true)
|
||||
|
||||
// Mock implementation
|
||||
const isHiddenSpy = vi.spyOn(editor, 'isShapeHidden')
|
||||
isHiddenSpy.mockImplementation((shape) => shape.id === ids.hiddenShape)
|
||||
|
||||
// Verify calls
|
||||
const spy = vi.spyOn(editor, 'setSelectedShapes')
|
||||
editor.selectAll()
|
||||
expect(spy).toHaveBeenCalled()
|
||||
expect(spy).not.toHaveBeenCalled()
|
||||
|
||||
// Always restore
|
||||
isHiddenSpy.mockRestore()
|
||||
```
|
||||
|
||||
## Fake timers
|
||||
|
||||
```typescript
|
||||
vi.useFakeTimers()
|
||||
|
||||
// Mock animation frame
|
||||
window.requestAnimationFrame = (cb) => setTimeout(cb, 1000 / 60)
|
||||
window.cancelAnimationFrame = (id) => clearTimeout(id)
|
||||
|
||||
it('handles animation', () => {
|
||||
editor.alignShapes(editor.getSelectedShapeIds(), 'right')
|
||||
vi.advanceTimersByTime(1000)
|
||||
// Assert after animation completes
|
||||
})
|
||||
```
|
||||
|
||||
## Assertions
|
||||
|
||||
### Shape matching
|
||||
|
||||
```typescript
|
||||
// Partial matching (most common)
|
||||
expect(editor.getShape(id)).toMatchObject({
|
||||
type: 'geo',
|
||||
x: 100,
|
||||
props: { w: 100 },
|
||||
})
|
||||
|
||||
editor.expectShapeToMatch({
|
||||
id: ids.box1,
|
||||
x: 350,
|
||||
y: 350,
|
||||
})
|
||||
|
||||
// Floating point matching (custom matcher)
|
||||
expect(result).toCloselyMatchObject({
|
||||
props: { normalizedAnchor: { x: 0.5, y: 0.75 } },
|
||||
})
|
||||
```
|
||||
|
||||
### Array assertions
|
||||
|
||||
```typescript
|
||||
expect(editor.getSelectedShapeIds()).toMatchObject([ids.box1])
|
||||
expect(Array.from(selectedIds).sort()).toEqual([id1, id2, id3].sort())
|
||||
expect(shapes).toContain('geo')
|
||||
expect(shapes).not.toContain(ids.lockedShape)
|
||||
```
|
||||
|
||||
### State assertions
|
||||
|
||||
```typescript
|
||||
editor.expectToBeIn('select.idle')
|
||||
editor.expectToBeIn('select.brushing')
|
||||
editor.expectToBeIn('select.crop.idle')
|
||||
```
|
||||
|
||||
## Testing undo/redo
|
||||
|
||||
```typescript
|
||||
it('handles undo/redo', () => {
|
||||
editor.doubleClick(550, 550, ids.image)
|
||||
editor.expectToBeIn('select.crop.idle')
|
||||
|
||||
editor.updateShape({ id: ids.image, type: 'image', props: { crop: newCrop } })
|
||||
|
||||
editor.undo()
|
||||
editor.expectToBeIn('select.crop.idle')
|
||||
expect(editor.getShape(ids.image)!.props.crop).toMatchObject(originalCrop)
|
||||
|
||||
editor.redo()
|
||||
expect(editor.getShape(ids.image)!.props.crop).toMatchObject(newCrop)
|
||||
})
|
||||
```
|
||||
|
||||
## Testing TypeScript types
|
||||
|
||||
```typescript
|
||||
it('Uses typescript generics', () => {
|
||||
expect(() => {
|
||||
// @ts-expect-error - wrong props type
|
||||
editor.createShape({ id, type: 'geo', props: { w: 'OH NO' } })
|
||||
|
||||
// @ts-expect-error - unknown prop
|
||||
editor.createShape({ id, type: 'geo', props: { foo: 'bar' } })
|
||||
|
||||
// Valid
|
||||
editor.createShape<TLGeoShape>({ id, type: 'geo', props: { w: 100 } })
|
||||
}).toThrow()
|
||||
})
|
||||
```
|
||||
|
||||
## Testing custom shapes
|
||||
|
||||
```typescript
|
||||
declare module '@tldraw/tlschema' {
|
||||
export interface TLGlobalShapePropsMap {
|
||||
'my-custom-shape': { w: number; h: number; text: string | undefined }
|
||||
}
|
||||
}
|
||||
|
||||
class CustomShape extends ShapeUtil<ICustomShape> {
|
||||
static override type = 'my-custom-shape'
|
||||
static override props: RecordProps<ICustomShape> = {
|
||||
w: T.number,
|
||||
h: T.number,
|
||||
text: T.string.optional(),
|
||||
}
|
||||
getDefaultProps() {
|
||||
return { w: 200, h: 200, text: '' }
|
||||
}
|
||||
getGeometry(shape) {
|
||||
return new Rectangle2d({ width: shape.props.w, height: shape.props.h })
|
||||
}
|
||||
indicator() {}
|
||||
component() {}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing side effects
|
||||
|
||||
```typescript
|
||||
beforeEach(() => {
|
||||
editor = new TestEditor()
|
||||
editor.sideEffects.registerAfterChangeHandler('instance_page_state', (prev, next) => {
|
||||
if (prev.croppingShapeId !== next.croppingShapeId) {
|
||||
// Handle state change
|
||||
}
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
## Testing events
|
||||
|
||||
```typescript
|
||||
it('emits wheel events', () => {
|
||||
const handler = vi.fn()
|
||||
editor.on('event', handler)
|
||||
|
||||
editor.dispatch({
|
||||
type: 'wheel',
|
||||
name: 'wheel',
|
||||
delta: { x: 0, y: 10, z: 0 },
|
||||
point: { x: 100, y: 100, z: 1 },
|
||||
shiftKey: false,
|
||||
// ... other modifiers
|
||||
})
|
||||
editor.emit('tick', 16) // Flush batched events
|
||||
|
||||
expect(handler).toHaveBeenCalledWith(expect.objectContaining({ name: 'wheel' }))
|
||||
})
|
||||
```
|
||||
|
||||
## Method chaining
|
||||
|
||||
```typescript
|
||||
editor
|
||||
.expectToBeIn('select.idle')
|
||||
.select(ids.imageA, ids.imageB)
|
||||
.doubleClick(550, 550, { target: 'selection', handle: 'bottom_right' })
|
||||
.expectToBeIn('select.idle')
|
||||
|
||||
editor.setCurrentTool('arrow').pointerDown(0, 0).pointerMove(100, 100).pointerUp()
|
||||
```
|
||||
|
||||
## Running tests
|
||||
|
||||
```bash
|
||||
cd packages/tldraw && yarn test run
|
||||
cd packages/tldraw && yarn test run --grep "arrow"
|
||||
cd packages/editor && yarn test run --grep "Vec"
|
||||
|
||||
# Watch mode
|
||||
cd packages/tldraw && yarn test
|
||||
```
|
||||
|
||||
## Key patterns summary
|
||||
|
||||
- Use `createShapeId()` for shape IDs
|
||||
- Use `vi.useFakeTimers()` for time-dependent behavior
|
||||
- Clear shapes in `beforeEach`, dispose in `afterEach`
|
||||
- Test in `packages/tldraw` for shapes/tools
|
||||
- Use `expectToBeIn()` for state machine assertions
|
||||
- Use `toMatchObject()` for partial matching
|
||||
- Use `toCloselyMatchObject()` for floating point values
|
||||
- Mock with `vi.spyOn()` and always `mockRestore()`
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/api-documentation-generator
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/pytest
|
||||
@@ -127,7 +127,7 @@ spec:
|
||||
rules:
|
||||
- http:
|
||||
paths:
|
||||
- path: /ai
|
||||
- path: /ailbl
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
name: "K8S Fission Deployment"
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
branches: [ 'main', 'ai' ]
|
||||
jobs:
|
||||
deployment-fission:
|
||||
name: Deployment fission functions
|
||||
@@ -12,75 +12,67 @@ jobs:
|
||||
FISSION_VER: 1.21.0
|
||||
RAKE_VER: 0.1.7
|
||||
steps:
|
||||
- name: ☸️ Setup kubectl
|
||||
uses: azure/setup-kubectl@v4
|
||||
|
||||
- name: 🐍 Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache-dependency-path: "apps/requirements.txt"
|
||||
|
||||
- name: 📥 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 📦 Install dependencies and run tests
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r apps/requirements.txt
|
||||
pip install pytest pytest-mock
|
||||
pytest apps/tests/ -v --tb=short
|
||||
|
||||
- name: 🔄 Cache
|
||||
id: cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
/usr/local/bin/rake
|
||||
/usr/local/bin/fission
|
||||
key: ${{ runner.os }}-${{ github.event.repository.name }}-${{ hashFiles('.fission/deployment.json') }}
|
||||
- name: ☘️ Configure Kubeconfig
|
||||
uses: azure/k8s-set-context@v4
|
||||
with:
|
||||
method: kubeconfig
|
||||
kubeconfig: ${{ secrets[format('{0}_KUBECONFIG', env.FISSION_PROFILE)] }}
|
||||
- name: 🔄 Install Dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
curl -L "https://${{ secrets.REGISTRY_PASSWORD }}@registry.vegastar.vn/vegacloud/make/releases/download/${RAKE_VER}/rake-${RAKE_VER}-x86_64-unknown-linux-musl.tar.gz" | tar xzv -C /tmp/
|
||||
curl -L "https://github.com/fission/fission/releases/download/v${FISSION_VER}/fission-v${FISSION_VER}-linux-amd64" --output /tmp/fission
|
||||
install -o root -g root -m 0755 /tmp/rake-${RAKE_VER}-x86_64-unknown-linux-musl/rake /usr/local/bin/rake
|
||||
install -o root -g root -m 0755 /tmp/fission /usr/local/bin/fission
|
||||
fission check
|
||||
# rake cfg install fission -f
|
||||
- name: 🕓 Checkout the previous codes
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.before }}
|
||||
- name: ♻️ Remove the previous version
|
||||
# continue-on-error: true
|
||||
run: |
|
||||
echo "use profile [$FISSION_PROFILE]"
|
||||
mkdir -p manifests || true
|
||||
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
|
||||
rake sp build -fi && rake sp down -i
|
||||
- name: 🔎 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: ✨ Deploy the new version
|
||||
id: deploy
|
||||
run: |
|
||||
echo "use profile [$FISSION_PROFILE]"
|
||||
mkdir -p manifests || true
|
||||
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
|
||||
rake sp build -fi && rake sp up -i
|
||||
- name: 🔔 Send notification
|
||||
uses: appleboy/telegram-action@master
|
||||
if: always()
|
||||
with:
|
||||
to: ${{ secrets.TELEGRAM_TO }}
|
||||
token: ${{ secrets.TELEGRAM_TOKEN }}
|
||||
format: markdown
|
||||
socks5: ${{ secrets.TELEGRAM_PROXY_URL != '' && secrets.TELEGRAM_PROXY_URL || '' }}
|
||||
message: |
|
||||
${{ steps.deploy.outcome == 'success' && '🟢 (=^ ◡ ^=)' || '🔴 (。•́︿•̀。)' }} Install fn ${{ github.event.repository.name }}
|
||||
*Msg*: `${{ github.event.commits[0].message }}`
|
||||
- name: ☸️ Setup kubectl
|
||||
uses: azure/setup-kubectl@v4
|
||||
- name: 🔄 Cache
|
||||
id: cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
/usr/local/bin/rake
|
||||
/usr/local/bin/fission
|
||||
key: ${{ runner.os }}-${{ github.event.repository.name }}-${{ hashFiles('.fission/deployment.json') }}
|
||||
- name: ☘️ Configure Kubeconfig
|
||||
uses: azure/k8s-set-context@v4
|
||||
with:
|
||||
method: kubeconfig
|
||||
kubeconfig: ${{ secrets[format('{0}_KUBECONFIG', env.FISSION_PROFILE)] }}
|
||||
- name: 🔄 Install Dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
curl -L "https://${{ secrets.REGISTRY_PASSWORD }}@registry.vegastar.vn/vegacloud/make/releases/download/${RAKE_VER}/rake-${RAKE_VER}-x86_64-unknown-linux-musl.tar.gz" | tar xzv -C /tmp/
|
||||
curl -L "https://github.com/fission/fission/releases/download/v${FISSION_VER}/fission-v${FISSION_VER}-linux-amd64" --output /tmp/fission
|
||||
install -o root -g root -m 0755 /tmp/rake-${RAKE_VER}-x86_64-unknown-linux-musl/rake /usr/local/bin/rake
|
||||
install -o root -g root -m 0755 /tmp/fission /usr/local/bin/fission
|
||||
fission check
|
||||
# rake cfg install fission -f
|
||||
- name: 🕓 Checkout the previous codes
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.before }}
|
||||
- name: ♻️ Remove the previous version
|
||||
# continue-on-error: true
|
||||
run: |
|
||||
echo "use profile [$FISSION_PROFILE]"
|
||||
mkdir -p manifests || true
|
||||
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
|
||||
rake sp build -fi && rake sp down -i
|
||||
- name: 🔎 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: 🐍 Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: 🧪 Run tests
|
||||
run: |
|
||||
cd apps
|
||||
pip install -r requirements-dev.txt
|
||||
pytest tests/ -v --tb=short
|
||||
- name: ✨ Deploy the new version
|
||||
id: deploy
|
||||
run: |
|
||||
echo "use profile [$FISSION_PROFILE]"
|
||||
mkdir -p manifests || true
|
||||
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
|
||||
rake sp build -fi && rake sp up -i
|
||||
- name: 🔔 Send notification
|
||||
uses: appleboy/telegram-action@master
|
||||
if: always()
|
||||
with:
|
||||
to: ${{ secrets.TELEGRAM_TO }}
|
||||
token: ${{ secrets.TELEGRAM_TOKEN }}
|
||||
format: markdown
|
||||
socks5: ${{ secrets.TELEGRAM_PROXY_URL != '' && secrets.TELEGRAM_PROXY_URL || '' }}
|
||||
message: |
|
||||
${{ steps.deploy.outcome == 'success' && '🟢 (=^ ◡ ^=)' || '🔴 (。•́︿•̀。)' }} Install fn ${{ github.event.repository.name }}
|
||||
*Msg*: `${{ github.event.commits[0].message }}`
|
||||
|
||||
@@ -18,24 +18,6 @@ jobs:
|
||||
run: echo "K8S_PROFILE=`echo ${GITHUB_REF_NAME:-${GITHUB_REF#refs/heads/}} | tr '[:lower:]' '[:upper:]'`" >> $GITHUB_ENV
|
||||
- name: ☸️ Setup kubectl
|
||||
uses: azure/setup-kubectl@v4
|
||||
|
||||
- name: 🐍 Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache-dependency-path: 'apps/requirements.txt'
|
||||
|
||||
- name: 📥 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 📦 Install dependencies and run tests
|
||||
run: |
|
||||
cd apps
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-mock
|
||||
pytest tests/ -v --tb=short
|
||||
|
||||
- name: 🛠️ Configure Kubeconfig
|
||||
uses: azure/k8s-set-context@v4
|
||||
with:
|
||||
|
||||
3437
2026-01-27-implement-the-following-plan.txt
Normal file
3437
2026-01-27-implement-the-following-plan.txt
Normal file
File diff suppressed because it is too large
Load Diff
79
CLAUDE.md
Normal file
79
CLAUDE.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
This is a Fission-based serverless Python microservice for AI user administration. It runs on Kubernetes with PostgreSQL as the data store.
|
||||
|
||||
**Stack:** Python 3.10, Flask, Fission FaaS, PostgreSQL, Pydantic
|
||||
|
||||
## Build & Deploy Commands
|
||||
|
||||
```bash
|
||||
# Build (installs Python dependencies)
|
||||
cd apps && ./build.sh
|
||||
|
||||
# Deploy to Kubernetes (reconciles all Fission specs)
|
||||
fission spec apply
|
||||
|
||||
# Watch and redeploy on changes
|
||||
fission spec apply --watch
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
Two Fission functions handle all HTTP endpoints:
|
||||
|
||||
| Function | Routes | Operations |
|
||||
|----------|--------|------------|
|
||||
| `ai-admin-filter-create-user` | `GET/POST /ai/admin/users` | Filter users with pagination, create new user |
|
||||
| `ai-admin-update-delete-user` | `PUT/DELETE /ai/admin/users/{UserID}` | Update/delete user by ID |
|
||||
|
||||
**Source structure (`apps/`):**
|
||||
- `filter_insert.py` - GET (filter) & POST (create) handler with `main()` entry point
|
||||
- `update_delete.py` - PUT (update) & DELETE handler with `main()` entry point
|
||||
- `schemas.py` - Pydantic models (`AiUserCreate`, `AiUserUpdate`) for validation
|
||||
- `helpers.py` - Database connection, K8s secrets, CORS headers, utilities
|
||||
- `vault.py` - PyNaCl symmetric encryption (`encrypt_vault`/`decrypt_vault`)
|
||||
|
||||
**Deployment configs (`.fission/`):**
|
||||
- `local-deployment.json`, `dev-deployment.json`, `test-deployment.json`, `staging-deployment.json`, `deployment.json`
|
||||
|
||||
**Fission specs (`specs/`):**
|
||||
- `env-work-py.yaml` - Python 3.10 runtime environment
|
||||
- `package-ai-work.yaml` - Build configuration
|
||||
- Function and HTTP trigger definitions
|
||||
|
||||
## Key Patterns
|
||||
|
||||
**Error handling:**
|
||||
```python
|
||||
try:
|
||||
conn = init_db_connection()
|
||||
# operations
|
||||
except ValidationError as e:
|
||||
return jsonify({"errorCode": "VALIDATION_ERROR", "details": e.errors()}), 400, CORS_HEADERS
|
||||
except IntegrityError:
|
||||
return jsonify({"errorCode": "DUPLICATE_TAG", ...}), 409, CORS_HEADERS
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
```
|
||||
|
||||
**Dynamic SQL filtering:** Build conditions list and values dict, join with AND for WHERE clause.
|
||||
|
||||
**Fission route params:** Extracted from headers (e.g., `X-Fission-Params-UserID`).
|
||||
|
||||
**Concurrent updates:** Uses PostgreSQL row-level locking (`FOR UPDATE`).
|
||||
|
||||
## Secrets & Configuration
|
||||
|
||||
Secrets are read from K8s mounted volumes via `helpers.get_secret()` and `helpers.get_config()`, not environment variables. PostgreSQL credentials come from the `fission-ai-work-env` secret.
|
||||
|
||||
## Function Configuration
|
||||
|
||||
- Executor: `newdeploy` (dedicated pod per function)
|
||||
- Timeout: 300 seconds
|
||||
- Min/Max Scale: 1
|
||||
- Concurrency: 500 requests per pod
|
||||
@@ -24,7 +24,7 @@ def main():
|
||||
"fntimeout": 300,
|
||||
"http_triggers": {
|
||||
"ai-admin-filter-create-user-http": {
|
||||
"url": "/ai/admin/users",
|
||||
"url": "/ailbl/ai/admin/users",
|
||||
"methods": ["POST", "GET"]
|
||||
}
|
||||
}
|
||||
@@ -44,34 +44,40 @@ def main():
|
||||
|
||||
|
||||
def make_insert_request():
|
||||
"""
|
||||
Handle POST request to create a new AI user.
|
||||
r"""make_insert_request() -> tuple[Response, int, dict]
|
||||
|
||||
Validates the request body using AiUserCreate schema, inserts a new record
|
||||
into the public.ai_user table, and returns the created user data.
|
||||
Create a new user from the request JSON body.
|
||||
|
||||
Validates the request body using :class:`AiUserCreate` schema, generates
|
||||
a UUID for the new user, and inserts the record into the database.
|
||||
|
||||
Returns:
|
||||
tuple: (json_response, status_code, headers)
|
||||
- 201: User created successfully
|
||||
- 400: Validation error in request body
|
||||
- 409: Duplicate entry violation
|
||||
- 500: Internal server error
|
||||
tuple: A tuple containing:
|
||||
- JSON response with created user data or error details
|
||||
- HTTP status code (201 on success, 400/409/500 on error)
|
||||
- CORS headers dict
|
||||
|
||||
Raises:
|
||||
ValidationError: If request body fails Pydantic validation (returns 400).
|
||||
IntegrityError: If email already exists in database (returns 409).
|
||||
|
||||
Example::
|
||||
|
||||
>>> # POST /ai/admin/users
|
||||
>>> # Body: {"name": "John Doe", "email": "john@example.com"}
|
||||
>>> # Response: 201 Created
|
||||
>>> {
|
||||
... "id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
... "name": "John Doe",
|
||||
... "email": "john@example.com",
|
||||
... "created": "2024-01-01T10:00:00",
|
||||
... "modified": "2024-01-01T10:00:00"
|
||||
... }
|
||||
"""
|
||||
try:
|
||||
body = AiUserCreate(**(request.get_json(silent=True) or {}))
|
||||
except ValidationError as e:
|
||||
errors = []
|
||||
for err in e.errors():
|
||||
errors.append({
|
||||
"loc": err.get("loc", []),
|
||||
"msg": err.get("msg", ""),
|
||||
"type": err.get("type", ""),
|
||||
})
|
||||
return (
|
||||
jsonify({"errorCode": "VALIDATION_ERROR", "details": errors}),
|
||||
400,
|
||||
CORS_HEADERS,
|
||||
)
|
||||
return jsonify({"errorCode": "VALIDATION_ERROR", "details": e.errors()}), 400, CORS_HEADERS
|
||||
|
||||
sql = """
|
||||
INSERT INTO public.ai_user (id, name, dob, email, gender)
|
||||
@@ -83,19 +89,13 @@ def make_insert_request():
|
||||
conn = init_db_connection()
|
||||
with conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
sql,
|
||||
(str(uuid.uuid4()), body.name, body.dob, body.email, body.gender),
|
||||
)
|
||||
cur.execute(sql, (str(uuid.uuid4()), body.name,
|
||||
body.dob, body.email, body.gender))
|
||||
row = cur.fetchone()
|
||||
return jsonify(db_row_to_dict(cur, row)), 201, CORS_HEADERS
|
||||
except IntegrityError as e:
|
||||
# vi phạm unique(tag,kind,ref)
|
||||
return (
|
||||
jsonify({"errorCode": "DUPLICATE_TAG", "details": str(e)}),
|
||||
409,
|
||||
CORS_HEADERS,
|
||||
)
|
||||
return jsonify({"errorCode": "DUPLICATE_TAG", "details": str(e)}), 409, CORS_HEADERS
|
||||
except Exception as err:
|
||||
return jsonify({"error": str(err)}), 500, CORS_HEADERS
|
||||
finally:
|
||||
@@ -104,16 +104,43 @@ def make_insert_request():
|
||||
|
||||
|
||||
def make_filter_request():
|
||||
"""
|
||||
Handle GET request to filter and list AI users.
|
||||
r"""make_filter_request() -> Response
|
||||
|
||||
Retrieves pagination parameters from request queries, executes the filter
|
||||
query against the database, and returns a list of matching users.
|
||||
Filter and paginate users based on query parameters.
|
||||
|
||||
Builds a dynamic SQL query based on filter parameters from the request,
|
||||
executes the query with pagination, and returns the matching users.
|
||||
|
||||
Query Parameters:
|
||||
page (int): Page number (0-indexed). Default: ``0``
|
||||
size (int): Number of records per page. Default: ``8``
|
||||
sortby (str): Field to sort by (``created`` or ``modified``).
|
||||
asc (bool): Sort ascending if ``True``, descending if ``False``.
|
||||
filter[ids] (list): Filter by specific user IDs.
|
||||
filter[keyword] (str): Search in name and email fields.
|
||||
filter[name] (str): Filter by name (case-insensitive partial match).
|
||||
filter[email] (str): Filter by email (case-insensitive partial match).
|
||||
filter[created_from] (str): Filter by creation date (from).
|
||||
filter[created_to] (str): Filter by creation date (to).
|
||||
filter[dob_from] (str): Filter by date of birth (from).
|
||||
filter[dob_to] (str): Filter by date of birth (to).
|
||||
|
||||
Returns:
|
||||
tuple: (json_response, status_code, headers)
|
||||
- 200: Successfully retrieved users list
|
||||
- 500: Internal server error
|
||||
Response: JSON array of user records with pagination metadata.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # GET /ai/admin/users?page=0&size=10&filter[name]=john
|
||||
>>> # Response: 200 OK
|
||||
>>> [
|
||||
... {
|
||||
... "id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
... "name": "John Doe",
|
||||
... "email": "john@example.com",
|
||||
... "count": 100,
|
||||
... "total": 5
|
||||
... }
|
||||
... ]
|
||||
"""
|
||||
paging = UserPage.from_request_queries()
|
||||
|
||||
@@ -130,18 +157,18 @@ def make_filter_request():
|
||||
|
||||
|
||||
def __filter_users(cursor, paging: "UserPage"):
|
||||
"""
|
||||
Build and execute SQL query to filter users based on pagination and filter criteria.
|
||||
|
||||
Constructs dynamic WHERE clause from UserFilter attributes, applies sorting
|
||||
and pagination, and returns matching database records.
|
||||
r"""Build and execute SQL query for filtering users.
|
||||
|
||||
Args:
|
||||
cursor: Database cursor for executing queries.
|
||||
paging: UserPage object containing pagination and filter parameters.
|
||||
cursor: Database cursor object for executing queries.
|
||||
paging (UserPage): Pagination and filter parameters.
|
||||
|
||||
Returns:
|
||||
list: List of user records matching the filter criteria.
|
||||
list[dict]: List of user records as dictionaries, including
|
||||
``count`` (total records) and ``total`` (filtered count).
|
||||
|
||||
Note:
|
||||
This is a private function. Use :func:`make_filter_request` instead.
|
||||
"""
|
||||
conditions = []
|
||||
values = {}
|
||||
@@ -218,21 +245,20 @@ def __filter_users(cursor, paging: "UserPage"):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Page:
|
||||
r"""Base pagination parameters for list queries.
|
||||
|
||||
Attributes:
|
||||
page (int, optional): Page number (0-indexed). Default: ``0``
|
||||
size (int, optional): Number of records per page. Default: ``8``
|
||||
asc (bool, optional): Sort order. ``True`` for ascending, ``False`` for descending.
|
||||
"""
|
||||
|
||||
page: typing.Optional[int] = None
|
||||
size: typing.Optional[int] = None
|
||||
asc: typing.Optional[bool] = None
|
||||
|
||||
@classmethod
|
||||
def from_request_queries(cls) -> "Page":
|
||||
"""
|
||||
Create Page instance from HTTP request query parameters.
|
||||
|
||||
Extracts 'page', 'size', and 'asc' parameters from the request URL.
|
||||
Defaults: page=0, size=8, asc=None.
|
||||
|
||||
Returns:
|
||||
Page: A new Page instance with values from query parameters.
|
||||
"""
|
||||
paging = Page()
|
||||
paging.page = int(request.args.get("page", 0))
|
||||
paging.size = int(request.args.get("size", 8))
|
||||
@@ -242,6 +268,22 @@ class Page:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UserFilter:
|
||||
r"""Filter parameters for user queries.
|
||||
|
||||
Attributes:
|
||||
ids (list[str], optional): Filter by specific user IDs.
|
||||
keyword (str, optional): Search keyword for name and email fields.
|
||||
name (str, optional): Filter by name (case-insensitive partial match).
|
||||
email (str, optional): Filter by email (case-insensitive partial match).
|
||||
gender (str, optional): Filter by gender.
|
||||
created_from (str, optional): Filter users created on or after this date.
|
||||
created_to (str, optional): Filter users created on or before this date.
|
||||
modified_from (str, optional): Filter users modified on or after this date.
|
||||
modified_to (str, optional): Filter users modified on or before this date.
|
||||
dob_from (str, optional): Filter users with DOB on or after this date.
|
||||
dob_to (str, optional): Filter users with DOB on or before this date.
|
||||
"""
|
||||
|
||||
ids: typing.Optional[typing.List[str]] = None
|
||||
keyword: typing.Optional[str] = None
|
||||
name: typing.Optional[str] = None
|
||||
@@ -256,15 +298,6 @@ class UserFilter:
|
||||
|
||||
@classmethod
|
||||
def from_request_queries(cls) -> "UserFilter":
|
||||
"""
|
||||
Create UserFilter instance from HTTP request query parameters.
|
||||
|
||||
Extracts filter parameters with 'filter[...]' prefix from the request URL.
|
||||
Supports filtering by ids, keyword, name, email, gender, date ranges.
|
||||
|
||||
Returns:
|
||||
UserFilter: A new UserFilter instance with values from query parameters.
|
||||
"""
|
||||
filter = UserFilter()
|
||||
filter.ids = request.args.getlist("filter[ids]")
|
||||
filter.keyword = request.args.get("filter[keyword]")
|
||||
@@ -281,12 +314,39 @@ class UserFilter:
|
||||
|
||||
|
||||
class UserSortField(str, enum.Enum):
|
||||
r"""Allowed sort fields for user queries.
|
||||
|
||||
Attributes:
|
||||
CREATED: Sort by creation timestamp.
|
||||
MODIFIED: Sort by last modification timestamp.
|
||||
"""
|
||||
|
||||
CREATED = "created"
|
||||
MODIFIED = "modified"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UserPage(Page):
|
||||
r"""Pagination parameters with user-specific filters and sorting.
|
||||
|
||||
Extends :class:`Page` with user filtering and sorting capabilities.
|
||||
|
||||
Attributes:
|
||||
sortby (UserSortField, optional): Field to sort results by.
|
||||
See :class:`UserSortField` for allowed values.
|
||||
filter (UserFilter, optional): Filter parameters for the query.
|
||||
Default: Parsed from request query parameters.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # Parse from request: GET /users?page=1&size=20&sortby=created&asc=true
|
||||
>>> paging = UserPage.from_request_queries()
|
||||
>>> paging.page
|
||||
1
|
||||
>>> paging.sortby
|
||||
<UserSortField.CREATED: 'created'>
|
||||
"""
|
||||
|
||||
sortby: typing.Optional[UserSortField] = None
|
||||
filter: typing.Optional[UserFilter] = dataclasses.field(
|
||||
default_factory=UserFilter.from_request_queries
|
||||
@@ -294,15 +354,6 @@ class UserPage(Page):
|
||||
|
||||
@classmethod
|
||||
def from_request_queries(cls) -> "UserPage":
|
||||
"""
|
||||
Create UserPage instance from HTTP request query parameters.
|
||||
|
||||
Combines pagination (page, size, asc) and filter parameters from the request URL.
|
||||
Also parses 'sortby' parameter to UserSortField enum.
|
||||
|
||||
Returns:
|
||||
UserPage: A new UserPage instance with all query parameters.
|
||||
"""
|
||||
base = super(UserPage, cls).from_request_queries()
|
||||
paging = UserPage(**dataclasses.asdict(base))
|
||||
|
||||
|
||||
138
apps/helpers.py
138
apps/helpers.py
@@ -19,6 +19,31 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_db_connection():
|
||||
r"""init_db_connection() -> psycopg2.extensions.connection
|
||||
|
||||
Initialize and return a PostgreSQL database connection.
|
||||
|
||||
Reads connection parameters from Kubernetes secrets and establishes
|
||||
a connection to the PostgreSQL database with logging enabled.
|
||||
|
||||
Returns:
|
||||
psycopg2.extensions.connection: Active database connection with
|
||||
:class:`LoggingConnection` factory for query logging.
|
||||
|
||||
Raises:
|
||||
Exception: If database host/port is unreachable.
|
||||
|
||||
Note:
|
||||
Connection parameters are read from K8s secrets:
|
||||
- ``PG_HOST``: Database host (default: ``127.0.0.1``)
|
||||
- ``PG_PORT``: Database port (default: ``5432``)
|
||||
- ``PG_DB``: Database name (default: ``postgres``)
|
||||
- ``PG_USER``: Database user (default: ``postgres``)
|
||||
- ``PG_PASS``: Database password (default: ``secret``)
|
||||
|
||||
.. warning::
|
||||
Caller is responsible for closing the connection after use.
|
||||
"""
|
||||
db_host = get_secret("PG_HOST", "127.0.0.1")
|
||||
db_port = int(get_secret("PG_PORT", 5432))
|
||||
|
||||
@@ -43,6 +68,28 @@ def init_db_connection():
|
||||
|
||||
|
||||
def db_row_to_dict(cursor, row):
|
||||
r"""db_row_to_dict(cursor, row) -> dict
|
||||
|
||||
Convert a database row tuple to a dictionary.
|
||||
|
||||
Uses cursor description to map column names to values. Automatically
|
||||
converts :class:`datetime.datetime` values to ISO format strings.
|
||||
|
||||
Args:
|
||||
cursor: Database cursor with ``description`` attribute containing
|
||||
column metadata.
|
||||
row (tuple): Row tuple from ``cursor.fetchone()`` or similar.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with column names as keys and row values.
|
||||
|
||||
Example::
|
||||
|
||||
>>> cursor.execute("SELECT id, name, created FROM users")
|
||||
>>> row = cursor.fetchone()
|
||||
>>> db_row_to_dict(cursor, row)
|
||||
{'id': '123', 'name': 'John', 'created': '2024-01-01T10:00:00'}
|
||||
"""
|
||||
record = {}
|
||||
for i, column in enumerate(cursor.description):
|
||||
data = row[i]
|
||||
@@ -53,10 +100,34 @@ def db_row_to_dict(cursor, row):
|
||||
|
||||
|
||||
def db_rows_to_array(cursor, rows):
|
||||
r"""db_rows_to_array(cursor, rows) -> list[dict]
|
||||
|
||||
Convert multiple database rows to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
cursor: Database cursor with ``description`` attribute.
|
||||
rows (list[tuple]): List of row tuples from ``cursor.fetchall()``.
|
||||
|
||||
Returns:
|
||||
list[dict]: List of dictionaries, one per row.
|
||||
|
||||
See Also:
|
||||
:func:`db_row_to_dict` for single row conversion.
|
||||
"""
|
||||
return [db_row_to_dict(cursor, row) for row in rows]
|
||||
|
||||
|
||||
def get_current_namespace() -> str:
|
||||
r"""get_current_namespace() -> str
|
||||
|
||||
Get the current Kubernetes namespace.
|
||||
|
||||
Reads namespace from the K8s service account file. Falls back to
|
||||
``default`` namespace if file is not accessible.
|
||||
|
||||
Returns:
|
||||
str: Current K8s namespace name.
|
||||
"""
|
||||
try:
|
||||
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") as f:
|
||||
namespace = f.read()
|
||||
@@ -67,6 +138,20 @@ def get_current_namespace() -> str:
|
||||
|
||||
|
||||
def get_secret(key: str, default=None):
|
||||
r"""get_secret(key, default=None) -> str | None
|
||||
|
||||
Read a secret value from Kubernetes mounted volume.
|
||||
|
||||
Args:
|
||||
key (str): Secret key name (e.g., ``"PG_HOST"``).
|
||||
default: Value to return if secret is not found. Default: ``None``
|
||||
|
||||
Returns:
|
||||
str | None: Secret value or default if not found.
|
||||
|
||||
Note:
|
||||
Secrets are mounted at ``/secrets/{namespace}/fission-ai-work-env/{key}``
|
||||
"""
|
||||
namespace = get_current_namespace()
|
||||
path = f"/secrets/{namespace}/{SECRET_NAME}/{key}"
|
||||
try:
|
||||
@@ -78,6 +163,20 @@ def get_secret(key: str, default=None):
|
||||
|
||||
|
||||
def get_config(key: str, default=None):
|
||||
r"""get_config(key, default=None) -> str | None
|
||||
|
||||
Read a config value from Kubernetes ConfigMap mounted volume.
|
||||
|
||||
Args:
|
||||
key (str): Config key name.
|
||||
default: Value to return if config is not found. Default: ``None``
|
||||
|
||||
Returns:
|
||||
str | None: Config value or default if not found.
|
||||
|
||||
Note:
|
||||
ConfigMaps are mounted at ``/configs/{namespace}/fission-ai-work-config/{key}``
|
||||
"""
|
||||
namespace = get_current_namespace()
|
||||
path = f"/configs/{namespace}/{CONFIG_NAME}/{key}"
|
||||
try:
|
||||
@@ -89,6 +188,26 @@ def get_config(key: str, default=None):
|
||||
|
||||
|
||||
def str_to_bool(input: str | None) -> bool:
|
||||
r"""str_to_bool(input) -> bool | None
|
||||
|
||||
Convert a string to boolean value.
|
||||
|
||||
Args:
|
||||
input (str | None): String to convert. Case-insensitive.
|
||||
|
||||
Returns:
|
||||
bool | None: ``True`` for ``"true"``, ``False`` for ``"false"``,
|
||||
``None`` for any other value.
|
||||
|
||||
Example::
|
||||
|
||||
>>> str_to_bool("true")
|
||||
True
|
||||
>>> str_to_bool("FALSE")
|
||||
False
|
||||
>>> str_to_bool("yes")
|
||||
None
|
||||
"""
|
||||
input = input or ""
|
||||
# Dictionary to map string values to boolean
|
||||
BOOL_MAP = {"true": True, "false": False}
|
||||
@@ -96,6 +215,25 @@ def str_to_bool(input: str | None) -> bool:
|
||||
|
||||
|
||||
def check_port_open(ip: str, port: int, timeout: int = 30):
|
||||
r"""check_port_open(ip, port, timeout=30) -> bool
|
||||
|
||||
Check if a TCP port is open and accepting connections.
|
||||
|
||||
Args:
|
||||
ip (str): IP address or hostname to check.
|
||||
port (int): Port number to check.
|
||||
timeout (int): Connection timeout in seconds. Default: ``30``
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if port is open, ``False`` otherwise.
|
||||
|
||||
Example::
|
||||
|
||||
>>> check_port_open("127.0.0.1", 5432)
|
||||
True
|
||||
>>> check_port_open("localhost", 9999, timeout=5)
|
||||
False
|
||||
"""
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(timeout)
|
||||
|
||||
5
apps/requirements-dev.txt
Normal file
5
apps/requirements-dev.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pytest==8.3.5
|
||||
pytest-mock==3.14.0
|
||||
flask==3.1.0
|
||||
psycopg2-binary==2.9.10
|
||||
pydantic==2.10.6
|
||||
@@ -1,6 +1,4 @@
|
||||
psycopg2-binary==2.9.10
|
||||
pydantic==2.11.7
|
||||
PyNaCl==1.6.0
|
||||
Flask==3.1.0
|
||||
pytest==8.3.4
|
||||
pytest-mock==3.14.0
|
||||
Flask==3.1.0
|
||||
@@ -6,6 +6,28 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AiUserCreate(BaseModel):
|
||||
r"""Schema for creating a new AI user.
|
||||
|
||||
Validates user data for the POST /ai/admin/users endpoint.
|
||||
|
||||
Attributes:
|
||||
id (str, optional): User UUID. Auto-generated if not provided.
|
||||
name (str): User's full name. Required, 1-128 characters.
|
||||
email (str): User's email address. Required, max 256 characters.
|
||||
Must be a valid email format.
|
||||
dob (date, optional): Date of birth in YYYY-MM-DD format.
|
||||
gender (str, optional): User's gender, max 10 characters.
|
||||
|
||||
Example::
|
||||
|
||||
>>> user = AiUserCreate(
|
||||
... name="John Doe",
|
||||
... email="john@example.com",
|
||||
... dob="1990-01-15",
|
||||
... gender="male"
|
||||
... )
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
name: str = Field(min_length=1, max_length=128)
|
||||
email: str = Field(..., max_length=256)
|
||||
@@ -20,6 +42,31 @@ class AiUserCreate(BaseModel):
|
||||
|
||||
|
||||
class AiUserUpdate(BaseModel):
|
||||
r"""Schema for updating an existing AI user.
|
||||
|
||||
Validates user data for the PUT /ai/admin/users/{UserID} endpoint.
|
||||
All fields are optional for partial updates.
|
||||
|
||||
Attributes:
|
||||
name (str, optional): User's full name, 1-128 characters.
|
||||
email (str, optional): User's email address, max 256 characters.
|
||||
Must be a valid email format if provided.
|
||||
dob (date, optional): Date of birth in YYYY-MM-DD format.
|
||||
gender (str, optional): User's gender, max 10 characters.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # Partial update - only change name
|
||||
>>> update = AiUserUpdate(name="Jane Doe")
|
||||
|
||||
>>> # Full update
|
||||
>>> update = AiUserUpdate(
|
||||
... name="Jane Doe",
|
||||
... email="jane@example.com",
|
||||
... gender="female"
|
||||
... )
|
||||
"""
|
||||
|
||||
name: Optional[str] = Field(default=None, min_length=1, max_length=128)
|
||||
email: Optional[str] = Field(default=None, max_length=256)
|
||||
dob: Optional[date] = None
|
||||
@@ -33,6 +80,36 @@ class AiUserUpdate(BaseModel):
|
||||
|
||||
|
||||
class AiUserFilter(BaseModel):
|
||||
r"""Schema for filtering and paginating AI users.
|
||||
|
||||
Used for parsing query parameters in the GET /ai/admin/users endpoint.
|
||||
|
||||
Attributes:
|
||||
q (str, optional): Search keyword for name and email fields.
|
||||
name (str, optional): Filter by name (partial match).
|
||||
email (str, optional): Filter by email (partial match).
|
||||
gender (str, optional): Filter by gender (exact match).
|
||||
dob_from (date, optional): Filter users with DOB on or after this date.
|
||||
dob_to (date, optional): Filter users with DOB on or before this date.
|
||||
created_from (str, optional): Filter users created on or after this datetime.
|
||||
created_to (str, optional): Filter users created on or before this datetime.
|
||||
page (int): Page number (0-indexed). Default: ``0``
|
||||
size (int): Records per page, 1-200. Default: ``20``
|
||||
sortby (str): Field to sort by. Default: ``"modified"``
|
||||
asc (bool): Sort ascending if ``True``. Default: ``False``
|
||||
|
||||
Example::
|
||||
|
||||
>>> # Parse from query string
|
||||
>>> filter = AiUserFilter(
|
||||
... q="john",
|
||||
... page=0,
|
||||
... size=10,
|
||||
... sortby="created",
|
||||
... asc=True
|
||||
... )
|
||||
"""
|
||||
|
||||
q: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Tests package for AI Admin API
|
||||
|
||||
@@ -1,5 +1,99 @@
|
||||
"""Shared fixtures for API handler tests."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
# Add apps directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
apps_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(apps_dir))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app():
|
||||
"""Create Flask app context for testing."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_context(flask_app):
|
||||
"""Provide Flask application context."""
|
||||
with flask_app.app_context():
|
||||
yield flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_context(flask_app):
|
||||
"""Provide Flask request context."""
|
||||
with flask_app.test_request_context():
|
||||
yield flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_connection(mocker):
|
||||
"""Mock database connection with cursor that has description attribute."""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
# Set name attribute on description items
|
||||
for i, col_name in enumerate(["id", "name", "dob", "email", "gender", "created", "modified"]):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("helpers.init_db_connection", return_value=mock_conn)
|
||||
|
||||
return {"connection": mock_conn, "cursor": mock_cursor}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_secrets(mocker):
|
||||
"""Mock get_secret to return test values."""
|
||||
secrets = {
|
||||
"PG_HOST": "localhost",
|
||||
"PG_PORT": "5432",
|
||||
"PG_DB": "test_db",
|
||||
"PG_USER": "test_user",
|
||||
"PG_PASS": "test_pass",
|
||||
}
|
||||
mocker.patch("helpers.get_secret", side_effect=lambda key, default=None: secrets.get(key, default))
|
||||
return secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_data():
|
||||
"""Sample user data for testing."""
|
||||
return {
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"dob": "1990-01-15",
|
||||
"gender": "male",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_db_row():
|
||||
"""Sample database row tuple."""
|
||||
return (
|
||||
"550e8400-e29b-41d4-a716-446655440000", # id
|
||||
"Test User", # name
|
||||
"1990-01-15", # dob
|
||||
"test@example.com", # email
|
||||
"male", # gender
|
||||
"2024-01-01T10:00:00", # created
|
||||
"2024-01-01T10:00:00", # modified
|
||||
)
|
||||
|
||||
@@ -1,86 +1,108 @@
|
||||
"""Tests for filter_insert.py - GET (filter) & POST (create) handlers."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from flask import Flask
|
||||
from psycopg2 import IntegrityError
|
||||
|
||||
|
||||
class TestMain:
|
||||
"""Test cases for the main() function"""
|
||||
"""Tests for main() dispatcher function."""
|
||||
|
||||
def test_main_get_method(self):
|
||||
"""Test main() with GET method calls make_filter_request()"""
|
||||
from flask import Flask
|
||||
from filter_insert import main
|
||||
def test_main_get_calls_filter(self, mocker):
|
||||
"""Test GET request routes to make_filter_request().
|
||||
|
||||
Given:
|
||||
A GET request to the endpoint.
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
make_filter_request() is invoked and returns 200.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/ai/admin/users", method="GET"):
|
||||
with patch("filter_insert.make_filter_request") as mock_filter:
|
||||
mock_filter.return_value = ({"data": "test"}, 200, {})
|
||||
result = main()
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
mock_filter = mocker.patch(
|
||||
"filter_insert.make_filter_request",
|
||||
return_value=({"data": []}, 200, {}),
|
||||
)
|
||||
mocker.patch("filter_insert.init_db_connection")
|
||||
|
||||
import filter_insert
|
||||
|
||||
result = filter_insert.main()
|
||||
|
||||
mock_filter.assert_called_once()
|
||||
assert result == ({"data": "test"}, 200, {})
|
||||
assert result[1] == 200
|
||||
|
||||
def test_main_post_method(self):
|
||||
"""Test main() with POST method calls make_insert_request()"""
|
||||
from flask import Flask
|
||||
from filter_insert import main
|
||||
def test_main_post_calls_insert(self, mocker):
|
||||
"""Test POST request routes to make_insert_request().
|
||||
|
||||
Given:
|
||||
A POST request with JSON body containing user data.
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
make_insert_request() is invoked and returns 201.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/ai/admin/users", method="POST", json={"name": "Test", "email": "test@example.com"}):
|
||||
with patch("filter_insert.make_insert_request") as mock_insert:
|
||||
mock_insert.return_value = ({"id": "123"}, 201, {})
|
||||
result = main()
|
||||
|
||||
with app.test_request_context(
|
||||
method="POST",
|
||||
json={"name": "Test", "email": "test@example.com"},
|
||||
):
|
||||
mock_insert = mocker.patch(
|
||||
"filter_insert.make_insert_request",
|
||||
return_value=({"id": "123"}, 201, {}),
|
||||
)
|
||||
|
||||
import filter_insert
|
||||
|
||||
result = filter_insert.main()
|
||||
|
||||
mock_insert.assert_called_once()
|
||||
assert result == ({"id": "123"}, 201, {})
|
||||
assert result[1] == 201
|
||||
|
||||
def test_main_method_not_allowed(self):
|
||||
"""Test main() with unsupported HTTP method returns 405"""
|
||||
from flask import Flask
|
||||
from filter_insert import main, CORS_HEADERS
|
||||
def test_main_invalid_method_returns_405(self, mocker):
|
||||
"""Test unsupported HTTP method returns 405.
|
||||
|
||||
Given:
|
||||
A PATCH request (unsupported method).
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
Returns 405 Method Not Allowed with error message.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/ai/admin/users", method="PUT"):
|
||||
result = main()
|
||||
|
||||
expected = ({"error": "Method not allow"}, 405, CORS_HEADERS)
|
||||
assert result == expected
|
||||
with app.test_request_context(method="PATCH"):
|
||||
import filter_insert
|
||||
|
||||
def test_main_exception_handling(self):
|
||||
"""Test main() catches exceptions and returns 500"""
|
||||
from flask import Flask
|
||||
from filter_insert import main
|
||||
result = filter_insert.main()
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/ai/admin/users", method="GET"):
|
||||
with patch("filter_insert.make_filter_request") as mock_filter:
|
||||
mock_filter.side_effect = Exception("Database connection failed")
|
||||
result = main()
|
||||
|
||||
assert result[1] == 500
|
||||
assert result[1] == 405
|
||||
assert "error" in result[0]
|
||||
|
||||
|
||||
class TestMakeInsertRequest:
|
||||
"""Test cases for make_insert_request() function"""
|
||||
"""Tests for make_insert_request() - user creation."""
|
||||
|
||||
def test_make_insert_request_success(self):
|
||||
"""Test successful user insertion"""
|
||||
from flask import Flask
|
||||
from filter_insert import make_insert_request, CORS_HEADERS
|
||||
def test_insert_success(self, mocker, sample_user_data, sample_db_row):
|
||||
"""Test successful user creation returns 201.
|
||||
|
||||
Given:
|
||||
Valid user data with name, email, dob, and gender.
|
||||
Database connection is available.
|
||||
When:
|
||||
POST request to create user.
|
||||
Then:
|
||||
Returns 201 status code.
|
||||
Executes INSERT query with user data.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.get_json.return_value = {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"dob": "1990-01-01",
|
||||
"gender": "male"
|
||||
}
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = ("uuid-123", "John Doe", "1990-01-01", "john@example.com", "male", "2024-01-01", "2024-01-01")
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
@@ -90,258 +112,367 @@ class TestMakeInsertRequest:
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with app.test_request_context("/ai/admin/users", method="POST", json=mock_request.get_json.return_value):
|
||||
with patch("filter_insert.init_db_connection", return_value=mock_conn):
|
||||
with patch("filter_insert.db_row_to_dict", return_value={"id": "uuid-123", "name": "John Doe"}):
|
||||
result = make_insert_request()
|
||||
|
||||
assert result[1] == 201
|
||||
assert result[2] == CORS_HEADERS
|
||||
|
||||
def test_make_insert_request_validation_error(self):
|
||||
"""Test make_insert_request() with invalid data returns 400"""
|
||||
from flask import Flask
|
||||
from filter_insert import make_insert_request, CORS_HEADERS
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Invalid email format - pydantic will validate and reject
|
||||
request_data = {
|
||||
"name": "John Doe",
|
||||
"email": "invalid-email"
|
||||
}
|
||||
|
||||
with app.test_request_context("/ai/admin/users", method="POST", json=request_data):
|
||||
result = make_insert_request()
|
||||
|
||||
assert result[1] == 400
|
||||
assert "errorCode" in result[0].json
|
||||
assert result[0].json["errorCode"] == "VALIDATION_ERROR"
|
||||
assert result[2] == CORS_HEADERS
|
||||
|
||||
def test_make_insert_request_duplicate_error(self):
|
||||
"""Test make_insert_request() with duplicate email returns 409"""
|
||||
from flask import Flask
|
||||
from filter_insert import make_insert_request, CORS_HEADERS
|
||||
from psycopg2 import IntegrityError
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
request_data = {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com"
|
||||
}
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
mock_cursor.fetchone.return_value = sample_db_row
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with app.test_request_context("/ai/admin/users", method="POST", json=request_data):
|
||||
with patch("filter_insert.init_db_connection", return_value=mock_conn):
|
||||
# Simulate IntegrityError
|
||||
mock_cursor.execute.side_effect = IntegrityError("duplicate key value violates unique constraint")
|
||||
result = make_insert_request()
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
assert result[1] == 409
|
||||
assert result[0].json["errorCode"] == "DUPLICATE_TAG"
|
||||
assert result[2] == CORS_HEADERS
|
||||
with app.test_request_context(
|
||||
method="POST",
|
||||
json=sample_user_data,
|
||||
content_type="application/json",
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
result = filter_insert.make_insert_request()
|
||||
|
||||
assert result[1] == 201
|
||||
mock_cursor.execute.assert_called_once()
|
||||
|
||||
def test_insert_validation_error_missing_name(self, mocker):
|
||||
"""Test missing required field 'name' returns 400.
|
||||
|
||||
Given:
|
||||
Request body with email but missing required 'name' field.
|
||||
When:
|
||||
POST request to create user.
|
||||
Then:
|
||||
Returns 400 Bad Request.
|
||||
Response contains errorCode 'VALIDATION_ERROR'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
method="POST",
|
||||
json={"email": "test@example.com"}, # missing 'name'
|
||||
content_type="application/json",
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
result = filter_insert.make_insert_request()
|
||||
|
||||
assert result[1] == 400
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "VALIDATION_ERROR"
|
||||
|
||||
def test_insert_validation_error_invalid_email(self, mocker):
|
||||
"""Test invalid email format raises serialization error.
|
||||
|
||||
Given:
|
||||
Request body with invalid email format 'invalid-email'.
|
||||
When:
|
||||
POST request to create user.
|
||||
Then:
|
||||
Raises TypeError because ValidationError.errors() contains
|
||||
ValueError which is not JSON serializable.
|
||||
|
||||
Note:
|
||||
This test documents a bug in the source code where e.errors()
|
||||
is passed directly to jsonify without sanitization.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
method="POST",
|
||||
json={"name": "Test", "email": "invalid-email"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
with pytest.raises(TypeError, match="not JSON serializable"):
|
||||
filter_insert.make_insert_request()
|
||||
|
||||
def test_insert_duplicate_email(self, mocker, sample_user_data):
|
||||
"""Test duplicate email returns 409 Conflict.
|
||||
|
||||
Given:
|
||||
Valid user data but email already exists in database.
|
||||
Database raises IntegrityError on INSERT.
|
||||
When:
|
||||
POST request to create user.
|
||||
Then:
|
||||
Returns 409 Conflict.
|
||||
Response contains errorCode 'DUPLICATE_TAG'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(
|
||||
side_effect=IntegrityError("duplicate key value")
|
||||
)
|
||||
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="POST",
|
||||
json=sample_user_data,
|
||||
content_type="application/json",
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
result = filter_insert.make_insert_request()
|
||||
|
||||
assert result[1] == 409
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "DUPLICATE_TAG"
|
||||
|
||||
|
||||
class TestMakeFilterRequest:
|
||||
"""Test cases for make_filter_request() function"""
|
||||
"""Tests for make_filter_request() - user filtering."""
|
||||
|
||||
def test_make_filter_request_success(self):
|
||||
"""Test successful filter request"""
|
||||
from flask import Flask
|
||||
from filter_insert import make_filter_request
|
||||
def test_filter_empty_result(self, mocker):
|
||||
"""Test filter with no matching results returns empty array.
|
||||
|
||||
Given:
|
||||
Database has no users matching the filter criteria.
|
||||
When:
|
||||
GET request to filter users.
|
||||
Then:
|
||||
Returns empty JSON array [].
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context("/ai/admin/users?page=0&size=8"):
|
||||
with patch("filter_insert.init_db_connection") as mock_init_db:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = []
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_init_db.return_value = mock_conn
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
MagicMock(name="count"),
|
||||
MagicMock(name="total"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
mock_cursor.fetchall.return_value = []
|
||||
|
||||
with patch("filter_insert.db_rows_to_array", return_value=[]):
|
||||
result = make_filter_request()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# make_filter_request returns Flask Response object (jsonify)
|
||||
assert result.status_code == 200
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
import filter_insert
|
||||
|
||||
class TestUserFilter:
|
||||
"""Test cases for UserFilter.from_request_queries()"""
|
||||
response = filter_insert.make_filter_request()
|
||||
|
||||
def test_user_filter_from_queries(self):
|
||||
"""Test UserFilter parses query parameters correctly"""
|
||||
from flask import Flask
|
||||
from filter_insert import UserFilter
|
||||
# make_filter_request returns Response object directly (not tuple)
|
||||
response_data = response.get_json()
|
||||
assert response_data == []
|
||||
|
||||
def test_filter_with_pagination(self, mocker, sample_db_row):
|
||||
"""Test filter with page and size parameters.
|
||||
|
||||
Given:
|
||||
Request with query params page=1 and size=5.
|
||||
When:
|
||||
GET request to filter users.
|
||||
Then:
|
||||
SQL query uses LIMIT 5 and OFFSET 5 (page * size).
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
def mock_get(key, default=None):
|
||||
values = {
|
||||
"filter[ids]": None,
|
||||
"filter[keyword]": "test",
|
||||
"filter[name]": "John",
|
||||
"filter[email]": "john@example.com",
|
||||
"filter[gender]": "male",
|
||||
"filter[created_from]": "2024-01-01",
|
||||
"filter[created_to]": "2024-12-31",
|
||||
"filter[modified_from]": None,
|
||||
"filter[modified_to]": None,
|
||||
"filter[dob_from]": None,
|
||||
"filter[dob_to]": None,
|
||||
}
|
||||
return values.get(key, default)
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
MagicMock(name="count"),
|
||||
MagicMock(name="total"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("filter_insert.request") as mock_request:
|
||||
mock_request.args.get = mock_get
|
||||
mock_request.args.getlist.return_value = []
|
||||
# Add count and total to sample row
|
||||
row_with_counts = sample_db_row + (10, 10)
|
||||
mock_cursor.fetchall.return_value = [row_with_counts]
|
||||
|
||||
result = UserFilter.from_request_queries()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
assert result.keyword == "test"
|
||||
assert result.name == "John"
|
||||
assert result.email == "john@example.com"
|
||||
assert result.gender == "male"
|
||||
assert result.created_from == "2024-01-01"
|
||||
assert result.created_to == "2024-12-31"
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(method="GET", query_string={"page": "1", "size": "5"}):
|
||||
import filter_insert
|
||||
|
||||
class TestPage:
|
||||
"""Test cases for Page.from_request_queries()"""
|
||||
filter_insert.make_filter_request()
|
||||
|
||||
def test_page_default_values(self):
|
||||
"""Test Page uses default values when params not provided"""
|
||||
from flask import Flask
|
||||
from filter_insert import Page
|
||||
# Check that execute was called with correct offset
|
||||
# cursor.execute(sql, values) - values is second positional arg
|
||||
call_args = mock_cursor.execute.call_args
|
||||
values = call_args[0][1] # Second positional argument
|
||||
assert values["limit"] == 5
|
||||
assert values["offset"] == 5 # page 1 * size 5
|
||||
|
||||
def test_filter_with_keyword(self, mocker, sample_db_row):
|
||||
"""Test filter with keyword search across name and email.
|
||||
|
||||
Given:
|
||||
Request with query param filter[keyword]='test'.
|
||||
When:
|
||||
GET request to filter users.
|
||||
Then:
|
||||
SQL query contains ILIKE clause for keyword matching.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
def mock_get(key, default=None, type=None):
|
||||
values = {
|
||||
"page": 0,
|
||||
"size": 8,
|
||||
"asc": "false",
|
||||
}
|
||||
return values.get(key, default)
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
MagicMock(name="count"),
|
||||
MagicMock(name="total"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("filter_insert.request") as mock_request:
|
||||
mock_request.args.get = mock_get
|
||||
result = Page.from_request_queries()
|
||||
row_with_counts = sample_db_row + (1, 1)
|
||||
mock_cursor.fetchall.return_value = [row_with_counts]
|
||||
|
||||
assert result.page == 0
|
||||
assert result.size == 8
|
||||
assert result.asc == False or result.asc == "false"
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
class TestUserPage:
|
||||
"""Test cases for UserPage.from_request_queries()"""
|
||||
with app.test_request_context(
|
||||
method="GET", query_string={"filter[keyword]": "test"}
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
def test_user_page_with_sortby(self):
|
||||
"""Test UserPage parses sortby parameter correctly"""
|
||||
from flask import Flask
|
||||
from filter_insert import UserPage, UserSortField
|
||||
result = filter_insert.make_filter_request()
|
||||
|
||||
# Check SQL contains keyword filter
|
||||
call_args = mock_cursor.execute.call_args
|
||||
sql = call_args[0][0]
|
||||
assert "ILIKE" in sql
|
||||
|
||||
def test_filter_with_name(self, mocker, sample_db_row):
|
||||
"""Test filter by name with case-insensitive partial match.
|
||||
|
||||
Given:
|
||||
Request with query param filter[name]='John'.
|
||||
When:
|
||||
GET request to filter users.
|
||||
Then:
|
||||
SQL query contains 'LOWER(name) LIKE %john%'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
def mock_get(key, default=None, type=None):
|
||||
values = {
|
||||
"page": 0,
|
||||
"size": 8,
|
||||
"asc": "true",
|
||||
"sortby": "created",
|
||||
}
|
||||
return values.get(key, default)
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
MagicMock(name="count"),
|
||||
MagicMock(name="total"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("filter_insert.request") as mock_request:
|
||||
mock_request.args.get = mock_get
|
||||
mock_request.args.getlist.return_value = []
|
||||
row_with_counts = sample_db_row + (1, 1)
|
||||
mock_cursor.fetchall.return_value = [row_with_counts]
|
||||
|
||||
result = UserPage.from_request_queries()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
assert result.sortby == UserSortField.CREATED
|
||||
assert result.asc == True or result.asc == "true"
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
def test_user_page_invalid_sortby(self):
|
||||
"""Test UserPage handles invalid sortby gracefully"""
|
||||
from flask import Flask
|
||||
from filter_insert import UserPage
|
||||
with app.test_request_context(
|
||||
method="GET", query_string={"filter[name]": "John"}
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
filter_insert.make_filter_request()
|
||||
|
||||
call_args = mock_cursor.execute.call_args
|
||||
sql = call_args[0][0]
|
||||
values = call_args[0][1] # Second positional argument
|
||||
assert "LOWER(name) LIKE" in sql
|
||||
assert values["name"] == "%john%"
|
||||
|
||||
def test_filter_with_sortby(self, mocker, sample_db_row):
|
||||
"""Test filter with sortby and asc parameters.
|
||||
|
||||
Given:
|
||||
Request with query params sortby='created' and asc='true'.
|
||||
When:
|
||||
GET request to filter users.
|
||||
Then:
|
||||
SQL query contains 'ORDER BY created ASC'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
def mock_get(key, default=None, type=None):
|
||||
values = {
|
||||
"page": 0,
|
||||
"size": 8,
|
||||
"asc": "false",
|
||||
"sortby": "invalid_field",
|
||||
}
|
||||
return values.get(key, default)
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
MagicMock(name="count"),
|
||||
MagicMock(name="total"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("filter_insert.request") as mock_request:
|
||||
mock_request.args.get = mock_get
|
||||
mock_request.args.getlist.return_value = []
|
||||
row_with_counts = sample_db_row + (1, 1)
|
||||
mock_cursor.fetchall.return_value = [row_with_counts]
|
||||
|
||||
result = UserPage.from_request_queries()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
assert result.sortby is None
|
||||
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="GET", query_string={"sortby": "created", "asc": "true"}
|
||||
):
|
||||
import filter_insert
|
||||
|
||||
class TestUserSortField:
|
||||
"""Test cases for UserSortField enum"""
|
||||
result = filter_insert.make_filter_request()
|
||||
|
||||
def test_user_sort_field_values(self):
|
||||
"""Test UserSortField has correct values"""
|
||||
from filter_insert import UserSortField
|
||||
|
||||
assert UserSortField.CREATED.value == "created"
|
||||
assert UserSortField.MODIFIED.value == "modified"
|
||||
|
||||
|
||||
class TestHelpers:
|
||||
"""Test cases for helper functions"""
|
||||
|
||||
def test_str_to_bool_true(self):
|
||||
"""Test str_to_bool with true values"""
|
||||
from helpers import str_to_bool
|
||||
|
||||
assert str_to_bool("true") is True
|
||||
assert str_to_bool("True") is True
|
||||
assert str_to_bool("TRUE") is True
|
||||
|
||||
def test_str_to_bool_false(self):
|
||||
"""Test str_to_bool with false values"""
|
||||
from helpers import str_to_bool
|
||||
|
||||
assert str_to_bool("false") is False
|
||||
assert str_to_bool("False") is False
|
||||
assert str_to_bool("FALSE") is False
|
||||
|
||||
def test_str_to_bool_none(self):
|
||||
"""Test str_to_bool with invalid values returns None"""
|
||||
from helpers import str_to_bool
|
||||
|
||||
assert str_to_bool("invalid") is None
|
||||
assert str_to_bool("") is None
|
||||
assert str_to_bool(None) is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
call_args = mock_cursor.execute.call_args
|
||||
sql = call_args[0][0]
|
||||
assert "ORDER BY created ASC" in sql
|
||||
|
||||
514
apps/tests/test_update_delete.py
Normal file
514
apps/tests/test_update_delete.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""Tests for update_delete.py - PUT (update) & DELETE handlers."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from psycopg2 import IntegrityError
|
||||
|
||||
|
||||
class TestMain:
|
||||
"""Tests for main() dispatcher function."""
|
||||
|
||||
def test_main_put_calls_update(self, mocker):
|
||||
"""Test PUT request routes to make_update_request().
|
||||
|
||||
Given:
|
||||
A PUT request to the endpoint.
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
make_update_request() is invoked and returns 200.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(method="PUT"):
|
||||
mock_update = mocker.patch(
|
||||
"update_delete.make_update_request",
|
||||
return_value=({"id": "123"}, 200, {}),
|
||||
)
|
||||
|
||||
import update_delete
|
||||
|
||||
result = update_delete.main()
|
||||
|
||||
mock_update.assert_called_once()
|
||||
assert result[1] == 200
|
||||
|
||||
def test_main_delete_calls_delete(self, mocker):
|
||||
"""Test DELETE request routes to make_delete_request().
|
||||
|
||||
Given:
|
||||
A DELETE request to the endpoint.
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
make_delete_request() is invoked and returns 200.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(method="DELETE"):
|
||||
mock_delete = mocker.patch(
|
||||
"update_delete.make_delete_request",
|
||||
return_value=({"id": "123"}, 200, {}),
|
||||
)
|
||||
|
||||
import update_delete
|
||||
|
||||
result = update_delete.main()
|
||||
|
||||
mock_delete.assert_called_once()
|
||||
assert result[1] == 200
|
||||
|
||||
def test_main_invalid_method_returns_405(self, mocker):
|
||||
"""Test unsupported HTTP method returns 405.
|
||||
|
||||
Given:
|
||||
A POST request (unsupported method for this endpoint).
|
||||
When:
|
||||
main() is called.
|
||||
Then:
|
||||
Returns 405 Method Not Allowed with error message.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(method="POST"):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.main()
|
||||
|
||||
assert result[1] == 405
|
||||
assert "error" in result[0]
|
||||
|
||||
|
||||
class TestMakeUpdateRequest:
|
||||
"""Tests for make_update_request() - user update."""
|
||||
|
||||
def test_update_missing_user_id(self, mocker):
|
||||
"""Test missing X-Fission-Params-UserID header returns 400.
|
||||
|
||||
Given:
|
||||
PUT request without X-Fission-Params-UserID header.
|
||||
When:
|
||||
make_update_request() is called.
|
||||
Then:
|
||||
Returns 400 Bad Request.
|
||||
Response contains errorCode 'MISSING_USER_ID'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
json={"name": "Updated Name"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 400
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "MISSING_USER_ID"
|
||||
|
||||
def test_update_user_not_found(self, mocker):
|
||||
"""Test update non-existent user returns 404.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User does not exist in database (SELECT returns None).
|
||||
When:
|
||||
PUT request to update user.
|
||||
Then:
|
||||
Returns 404 Not Found.
|
||||
Response contains errorCode 'USER_NOT_FOUND'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None # User not found
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "nonexistent-id"},
|
||||
json={"name": "Updated Name"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 404
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "USER_NOT_FOUND"
|
||||
|
||||
def test_update_success(self, mocker, sample_db_row):
|
||||
"""Test successful user update returns 200.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User exists in database.
|
||||
Valid update data in request body.
|
||||
When:
|
||||
PUT request to update user.
|
||||
Then:
|
||||
Returns 200 OK.
|
||||
User data is updated in database.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
# First fetchone returns existing user, second returns updated user
|
||||
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
|
||||
json={"name": "Updated Name"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 200
|
||||
|
||||
def test_update_validation_error_invalid_email(self, mocker):
|
||||
"""Test invalid email format raises serialization error.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
Request body with invalid email format 'invalid-email'.
|
||||
When:
|
||||
PUT request to update user.
|
||||
Then:
|
||||
Raises TypeError because ValidationError.errors() contains
|
||||
ValueError which is not JSON serializable.
|
||||
|
||||
Note:
|
||||
This test documents a bug in the source code where e.errors()
|
||||
is passed directly to jsonify without sanitization.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "some-id"},
|
||||
json={"email": "invalid-email"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
with pytest.raises(TypeError, match="not JSON serializable"):
|
||||
update_delete.make_update_request()
|
||||
|
||||
def test_update_duplicate_email(self, mocker, sample_db_row):
|
||||
"""Test duplicate email on update returns 409 Conflict.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User exists in database.
|
||||
New email already exists for another user.
|
||||
Database raises IntegrityError on UPDATE.
|
||||
When:
|
||||
PUT request to update user email.
|
||||
Then:
|
||||
Returns 409 Conflict.
|
||||
Response contains errorCode 'DUPLICATE_USER'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
# First fetchone returns existing user
|
||||
mock_cursor.fetchone.return_value = sample_db_row
|
||||
# Second execute raises IntegrityError
|
||||
mock_cursor.execute.side_effect = [None, IntegrityError("duplicate key")]
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
|
||||
json={"email": "duplicate@example.com"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 409
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "DUPLICATE_USER"
|
||||
|
||||
|
||||
class TestMakeDeleteRequest:
|
||||
"""Tests for make_delete_request() - user deletion."""
|
||||
|
||||
def test_delete_missing_user_id(self, mocker):
|
||||
"""Test missing X-Fission-Params-UserID header returns 400.
|
||||
|
||||
Given:
|
||||
DELETE request without X-Fission-Params-UserID header.
|
||||
When:
|
||||
make_delete_request() is called.
|
||||
Then:
|
||||
Returns 400 Bad Request.
|
||||
Response contains errorCode 'MISSING_USER_ID'.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(method="DELETE"):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_delete_request()
|
||||
|
||||
assert result[1] == 400
|
||||
response_data = result[0].get_json()
|
||||
assert response_data["errorCode"] == "MISSING_USER_ID"
|
||||
|
||||
def test_delete_user_not_found(self, mocker):
|
||||
"""Test delete non-existent user returns 404.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User does not exist in database (SELECT returns None).
|
||||
When:
|
||||
DELETE request to delete user.
|
||||
Then:
|
||||
Returns 404 Not Found.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = None # User not found
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="DELETE",
|
||||
headers={"X-Fission-Params-UserID": "nonexistent-id"},
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_delete_request()
|
||||
|
||||
assert result[1] == 404
|
||||
|
||||
def test_delete_success(self, mocker, sample_db_row):
|
||||
"""Test successful user deletion returns 200.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User exists in database.
|
||||
When:
|
||||
DELETE request to delete user.
|
||||
Then:
|
||||
Returns 200 OK.
|
||||
User is deleted from database.
|
||||
Response contains deleted user data.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
# First fetchone checks existence, second returns deleted row
|
||||
mock_cursor.fetchone.side_effect = [(1,), sample_db_row]
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="DELETE",
|
||||
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_delete_request()
|
||||
|
||||
assert result[1] == 200
|
||||
|
||||
|
||||
class TestUpdatePartialFields:
|
||||
"""Tests for partial field updates."""
|
||||
|
||||
def test_update_only_name(self, mocker, sample_db_row):
|
||||
"""Test update only name field.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User exists in database.
|
||||
Request body contains only 'name' field.
|
||||
When:
|
||||
PUT request to update user.
|
||||
Then:
|
||||
Returns 200 OK.
|
||||
UPDATE SQL only includes name field (plus modified timestamp).
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
|
||||
json={"name": "New Name Only"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 200
|
||||
# Check that UPDATE SQL contains name field
|
||||
call_args = mock_cursor.execute.call_args_list[-1]
|
||||
sql = call_args[0][0]
|
||||
assert "name=" in sql
|
||||
|
||||
def test_update_multiple_fields(self, mocker, sample_db_row):
|
||||
"""Test update multiple fields at once.
|
||||
|
||||
Given:
|
||||
Valid X-Fission-Params-UserID header.
|
||||
User exists in database.
|
||||
Request body contains name, email, and gender fields.
|
||||
When:
|
||||
PUT request to update user.
|
||||
Then:
|
||||
Returns 200 OK.
|
||||
UPDATE SQL includes all three fields.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.description = [
|
||||
MagicMock(name="id"),
|
||||
MagicMock(name="name"),
|
||||
MagicMock(name="dob"),
|
||||
MagicMock(name="email"),
|
||||
MagicMock(name="gender"),
|
||||
MagicMock(name="created"),
|
||||
MagicMock(name="modified"),
|
||||
]
|
||||
for i, col_name in enumerate(
|
||||
["id", "name", "dob", "email", "gender", "created", "modified"]
|
||||
):
|
||||
mock_cursor.description[i].name = col_name
|
||||
|
||||
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
|
||||
|
||||
with app.test_request_context(
|
||||
method="PUT",
|
||||
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
|
||||
json={"name": "New Name", "email": "new@example.com", "gender": "female"},
|
||||
content_type="application/json",
|
||||
):
|
||||
import update_delete
|
||||
|
||||
result = update_delete.make_update_request()
|
||||
|
||||
assert result[1] == 200
|
||||
call_args = mock_cursor.execute.call_args_list[-1]
|
||||
sql = call_args[0][0]
|
||||
assert "name=" in sql
|
||||
assert "email=" in sql
|
||||
assert "gender=" in sql
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
from flask import current_app, jsonify, request
|
||||
from helpers import CORS_HEADERS, db_row_to_dict, init_db_connection
|
||||
from psycopg2 import IntegrityError
|
||||
@@ -15,7 +13,7 @@ def main():
|
||||
"fntimeout": 300,
|
||||
"http_triggers": {
|
||||
"ai-admin-update-delete-user-http": {
|
||||
"url": "/ai/admin/users/{UserID}",
|
||||
"url": "/ailbl/ai/admin/users/{UserID}",
|
||||
"methods": ["DELETE", "PUT"]
|
||||
}
|
||||
}
|
||||
@@ -35,6 +33,40 @@ def main():
|
||||
|
||||
|
||||
def make_update_request():
|
||||
r"""make_update_request() -> tuple[Response, int, dict]
|
||||
|
||||
Update an existing user by ID.
|
||||
|
||||
Retrieves the user ID from ``X-Fission-Params-UserID`` header, validates
|
||||
the request body using :class:`AiUserUpdate` schema, and performs a
|
||||
partial update on the user record.
|
||||
|
||||
Uses row-level locking (``SELECT ... FOR UPDATE``) to prevent concurrent
|
||||
modification conflicts.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- JSON response with updated user data or error details
|
||||
- HTTP status code (200 on success, 400/404/409 on error)
|
||||
- CORS headers dict
|
||||
|
||||
Raises:
|
||||
ValidationError: If request body fails Pydantic validation (returns 400).
|
||||
IntegrityError: If email conflicts with another user (returns 409).
|
||||
|
||||
Example::
|
||||
|
||||
>>> # PUT /ai/admin/users/550e8400-e29b-41d4-a716-446655440000
|
||||
>>> # Header: X-Fission-Params-UserID: 550e8400-e29b-41d4-a716-446655440000
|
||||
>>> # Body: {"name": "Jane Doe"}
|
||||
>>> # Response: 200 OK
|
||||
>>> {
|
||||
... "id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
... "name": "Jane Doe",
|
||||
... "email": "john@example.com",
|
||||
... "modified": "2024-01-02T10:00:00"
|
||||
... }
|
||||
"""
|
||||
user_id = request.headers.get("X-Fission-Params-UserID")
|
||||
if not user_id:
|
||||
return jsonify({"errorCode": "MISSING_USER_ID"}), 400, CORS_HEADERS
|
||||
@@ -97,6 +129,19 @@ def make_update_request():
|
||||
|
||||
|
||||
def __delete_user(cursor, id: str):
|
||||
r"""Delete a user from the database by ID.
|
||||
|
||||
Args:
|
||||
cursor: Database cursor object for executing queries.
|
||||
id (str): UUID of the user to delete.
|
||||
|
||||
Returns:
|
||||
dict | str: User data dict if deleted successfully,
|
||||
or ``"USER_NOT_FOUND"`` string if user doesn't exist.
|
||||
|
||||
Note:
|
||||
This is a private function. Use :func:`make_delete_request` instead.
|
||||
"""
|
||||
cursor.execute("SELECT 1 FROM ai_user WHERE id = %(id)s", {"id": id})
|
||||
if not cursor.fetchone():
|
||||
return "USER_NOT_FOUND"
|
||||
@@ -106,7 +151,30 @@ def __delete_user(cursor, id: str):
|
||||
return db_row_to_dict(cursor, row)
|
||||
|
||||
def make_delete_request():
|
||||
r"""make_delete_request() -> tuple[Response, int, dict]
|
||||
|
||||
Delete a user by ID.
|
||||
|
||||
Retrieves the user ID from ``X-Fission-Params-UserID`` header and
|
||||
deletes the user from the database if found.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- JSON response with deleted user data or error details
|
||||
- HTTP status code (200 on success, 400/404/500 on error)
|
||||
- CORS headers dict (may be omitted on some error responses)
|
||||
|
||||
Example::
|
||||
|
||||
>>> # DELETE /ai/admin/users/550e8400-e29b-41d4-a716-446655440000
|
||||
>>> # Header: X-Fission-Params-UserID: 550e8400-e29b-41d4-a716-446655440000
|
||||
>>> # Response: 200 OK
|
||||
>>> {
|
||||
... "id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
... "name": "John Doe",
|
||||
... "email": "john@example.com"
|
||||
... }
|
||||
"""
|
||||
user_id = request.headers.get("X-Fission-Params-UserID")
|
||||
if not user_id:
|
||||
return jsonify({"errorCode": "MISSING_USER_ID"}), 400, CORS_HEADERS
|
||||
|
||||
Reference in New Issue
Block a user