Compare commits

6 Commits
ai ... main

Author SHA1 Message Date
Nguyen Duc Thao
ab38579758 fix
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 17s
2026-01-26 23:35:47 +07:00
Nguyen Duc Thao
68cb816208 fix
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 22s
2026-01-26 23:29:16 +07:00
Nguyen Duc Thao
86899d9593 fix
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 22s
2026-01-26 23:23:38 +07:00
Nguyen Duc Thao
f2696c2f75 fix
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 10s
2026-01-26 23:22:14 +07:00
Nguyen Duc Thao
aaebbfee76 remove
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 10s
2026-01-26 23:08:47 +07:00
Nguyen Duc Thao
3861b027b2 add new
Some checks failed
K8S Fission Deployment / Deployment fission functions (push) Failing after 12s
2026-01-26 23:07:28 +07:00
24 changed files with 3000 additions and 4982 deletions

View File

@@ -0,0 +1,623 @@
---
name: api-documentation-generator
description: Generate OpenAPI/Swagger specifications and API documentation from code or design. Use when creating API docs, generating OpenAPI specs, or documenting REST APIs.
---
# API Documentation Generator
Generate OpenAPI/Swagger specifications and comprehensive API documentation.
## Quick Start
Create OpenAPI 3.0 specs with paths, schemas, and examples for complete API documentation.
## Instructions
### OpenAPI 3.0 Structure
**Basic structure:**
```yaml
openapi: 3.0.0
info:
title: API Name
version: 1.0.0
description: API description
servers:
- url: https://api.example.com/v1
paths:
/users:
get:
summary: List users
responses:
'200':
description: Success
components:
schemas:
User:
type: object
properties:
id:
type: integer
name:
type: string
```
### Info Section
```yaml
info:
title: E-commerce API
version: 1.0.0
description: |
REST API for e-commerce platform.
## Authentication
Use Bearer token in Authorization header.
## Rate Limiting
1000 requests per hour per API key.
contact:
name: API Support
email: api@example.com
url: https://example.com/support
license:
name: MIT
url: https://opensource.org/licenses/MIT
```
### Servers
```yaml
servers:
- url: https://api.example.com/v1
description: Production
- url: https://staging-api.example.com/v1
description: Staging
- url: http://localhost:3000/v1
description: Development
```
### Paths and Operations
**GET endpoint:**
```yaml
paths:
/users:
get:
summary: List users
description: Retrieve a paginated list of users
tags:
- Users
parameters:
- name: page
in: query
schema:
type: integer
default: 1
- name: per_page
in: query
schema:
type: integer
default: 20
responses:
'200':
description: Successful response
content:
application/json:
schema:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/User'
meta:
$ref: '#/components/schemas/PaginationMeta'
```
**POST endpoint:**
```yaml
/users:
post:
summary: Create user
tags:
- Users
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/CreateUserRequest'
example:
name: John Doe
email: john@example.com
responses:
'201':
description: User created
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
$ref: '#/components/responses/BadRequest'
```
**Path parameters:**
```yaml
/users/{id}:
get:
summary: Get user by ID
parameters:
- name: id
in: path
required: true
schema:
type: integer
description: User ID
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'404':
$ref: '#/components/responses/NotFound'
```
### Components - Schemas
**Simple schema:**
```yaml
components:
schemas:
User:
type: object
required:
- id
- email
properties:
id:
type: integer
example: 1
name:
type: string
example: John Doe
email:
type: string
format: email
example: john@example.com
created_at:
type: string
format: date-time
```
**Nested schema:**
```yaml
Order:
type: object
properties:
id:
type: integer
customer:
$ref: '#/components/schemas/User'
items:
type: array
items:
$ref: '#/components/schemas/OrderItem'
total:
type: number
format: float
```
**Enum:**
```yaml
OrderStatus:
type: string
enum:
- pending
- processing
- shipped
- delivered
- cancelled
```
**OneOf (union types):**
```yaml
Payment:
oneOf:
- $ref: '#/components/schemas/CreditCardPayment'
- $ref: '#/components/schemas/PayPalPayment'
discriminator:
propertyName: payment_type
```
### Components - Responses
**Reusable responses:**
```yaml
components:
responses:
NotFound:
description: Resource not found
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
example:
error:
code: NOT_FOUND
message: Resource not found
BadRequest:
description: Invalid request
content:
application/json:
schema:
$ref: '#/components/schemas/ValidationError'
Unauthorized:
description: Authentication required
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
```
### Components - Parameters
**Reusable parameters:**
```yaml
components:
parameters:
PageParam:
name: page
in: query
schema:
type: integer
default: 1
LimitParam:
name: limit
in: query
schema:
type: integer
default: 20
maximum: 100
IdParam:
name: id
in: path
required: true
schema:
type: integer
```
### Security Schemes
**Bearer token:**
```yaml
components:
securitySchemes:
BearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
security:
- BearerAuth: []
```
**API Key:**
```yaml
components:
securitySchemes:
ApiKeyAuth:
type: apiKey
in: header
name: X-API-Key
```
**OAuth 2.0:**
```yaml
components:
securitySchemes:
OAuth2:
type: oauth2
flows:
authorizationCode:
authorizationUrl: https://example.com/oauth/authorize
tokenUrl: https://example.com/oauth/token
scopes:
read: Read access
write: Write access
```
### Examples
**Multiple examples:**
```yaml
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/User'
examples:
admin:
summary: Admin user
value:
id: 1
name: Admin
role: admin
regular:
summary: Regular user
value:
id: 2
name: John
role: user
```
### Tags
**Organize endpoints:**
```yaml
tags:
- name: Users
description: User management
- name: Products
description: Product catalog
- name: Orders
description: Order processing
paths:
/users:
get:
tags:
- Users
```
## Complete Example
```yaml
openapi: 3.0.0
info:
title: Blog API
version: 1.0.0
description: RESTful API for blog platform
servers:
- url: https://api.blog.com/v1
paths:
/posts:
get:
summary: List posts
tags:
- Posts
parameters:
- $ref: '#/components/parameters/PageParam'
- $ref: '#/components/parameters/LimitParam'
responses:
'200':
description: Success
content:
application/json:
schema:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Post'
meta:
$ref: '#/components/schemas/PaginationMeta'
post:
summary: Create post
tags:
- Posts
security:
- BearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- title
- content
properties:
title:
type: string
content:
type: string
tags:
type: array
items:
type: string
responses:
'201':
description: Post created
content:
application/json:
schema:
$ref: '#/components/schemas/Post'
/posts/{id}:
get:
summary: Get post
tags:
- Posts
parameters:
- name: id
in: path
required: true
schema:
type: integer
responses:
'200':
description: Success
content:
application/json:
schema:
$ref: '#/components/schemas/Post'
'404':
$ref: '#/components/responses/NotFound'
components:
schemas:
Post:
type: object
properties:
id:
type: integer
title:
type: string
content:
type: string
author:
$ref: '#/components/schemas/User'
created_at:
type: string
format: date-time
User:
type: object
properties:
id:
type: integer
name:
type: string
email:
type: string
PaginationMeta:
type: object
properties:
total:
type: integer
page:
type: integer
per_page:
type: integer
Error:
type: object
properties:
error:
type: object
properties:
code:
type: string
message:
type: string
parameters:
PageParam:
name: page
in: query
schema:
type: integer
default: 1
LimitParam:
name: limit
in: query
schema:
type: integer
default: 20
responses:
NotFound:
description: Resource not found
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
securitySchemes:
BearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
security:
- BearerAuth: []
```
## Generating from Code
### From Express.js
```javascript
// Use swagger-jsdoc
/**
* @swagger
* /users:
* get:
* summary: List users
* responses:
* 200:
* description: Success
*/
app.get('/users', (req, res) => {
// Handler
});
```
### From FastAPI (Python)
```python
# FastAPI auto-generates OpenAPI
@app.get("/users", response_model=List[User])
async def list_users():
return users
```
### From ASP.NET Core
```csharp
// Use Swashbuckle
[HttpGet]
[ProducesResponseType(typeof(List<User>), 200)]
public IActionResult GetUsers()
{
return Ok(users);
}
```
## Tools
**Swagger Editor**: https://editor.swagger.io
**Swagger UI**: Interactive documentation
**Redoc**: Alternative documentation UI
**Postman**: Import OpenAPI for testing
## Best Practices
**Use $ref for reusability:**
- Define schemas once
- Reference in multiple places
- Easier maintenance
**Include examples:**
- Help developers understand
- Enable better testing
- Show expected formats
**Document errors:**
- All possible status codes
- Error response format
- Error codes and meanings
**Version your API:**
- Include version in URL or header
- Document breaking changes
- Maintain old versions
**Keep it updated:**
- Generate from code when possible
- Review regularly
- Update with API changes

View File

@@ -0,0 +1,187 @@
---
name: pytest
description: |
Python testing with pytest framework for unit, integration, and API tests.
Use when: (1) Writing test cases for Python code, (2) Setting up pytest fixtures,
(3) Testing async functions with pytest-asyncio, (4) Mocking dependencies,
(5) Parameterizing tests, (6) Testing FastAPI/Flask endpoints, (7) Setting up test coverage,
(8) Creating test factories with factory_boy, (9) Configuring CI/CD test pipelines.
---
# Pytest Testing Skill
Comprehensive testing patterns for Python applications using pytest.
## Quick Reference
| Feature | Reference File |
|---------|----------------|
| Fixtures, conftest, scopes | [references/fixtures.md](references/fixtures.md) |
| Async testing, pytest-asyncio | [references/async-testing.md](references/async-testing.md) |
| Mocking, patching, spies | [references/mocking.md](references/mocking.md) |
| FastAPI/Flask endpoint testing | [references/api-testing.md](references/api-testing.md) |
## Dependencies
```toml
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=4.1.0",
"httpx>=0.28.0", # Async HTTP client for API tests
"factory-boy>=3.3.0", # Test factories
"faker>=33.0.0", # Fake data generation
]
```
## Configuration
### pyproject.toml
```toml
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = ["ignore::DeprecationWarning"]
[tool.coverage.run]
source = ["app"]
omit = ["*/tests/*", "*/__init__.py"]
```
## Basic Test Structure
```python
import pytest
# Simple test function
def test_addition():
assert 1 + 1 == 2
# Test class (group related tests)
class TestCalculator:
def test_add(self):
assert add(2, 3) == 5
def test_subtract(self):
assert subtract(5, 3) == 2
# Expected exceptions
def test_division_by_zero():
with pytest.raises(ZeroDivisionError):
divide(1, 0)
# Parametrized tests
@pytest.mark.parametrize("input,expected", [
(1, 1),
(2, 4),
(3, 9),
])
def test_square(input, expected):
assert square(input) == expected
```
## Fixtures
```python
import pytest
@pytest.fixture
def sample_user():
return {"name": "John", "email": "john@example.com"}
@pytest.fixture
def db_connection():
conn = create_connection()
yield conn # Test runs here
conn.close() # Cleanup after test
# Use fixture in test
def test_user_name(sample_user):
assert sample_user["name"] == "John"
```
## Async Testing
```python
import pytest
@pytest.mark.asyncio
async def test_async_function():
result = await async_operation()
assert result == "success"
@pytest.fixture
async def async_client():
async with AsyncClient() as client:
yield client
```
## Mocking
```python
from unittest.mock import patch, MagicMock, AsyncMock
def test_with_mock():
with patch("module.external_api") as mock_api:
mock_api.return_value = {"data": "mocked"}
result = function_using_api()
assert result["data"] == "mocked"
mock_api.assert_called_once()
# Async mock
@pytest.mark.asyncio
async def test_async_mock():
with patch("module.async_call", new_callable=AsyncMock) as mock:
mock.return_value = "result"
result = await function_with_async_call()
assert result == "result"
```
## Running Tests
```bash
# Run all tests
pytest
# Verbose output
pytest -v
# Run specific file
pytest tests/test_users.py
# Run specific test
pytest tests/test_users.py::test_create_user
# Run with coverage
pytest --cov=app --cov-report=term-missing
# Stop on first failure
pytest -x
# Run last failed tests
pytest --lf
# Run tests matching pattern
pytest -k "user and not delete"
```
## Project Structure
```
project/
├── app/
│ ├── __init__.py
│ ├── main.py
│ └── services/
├── tests/
│ ├── __init__.py
│ ├── conftest.py # Shared fixtures
│ ├── test_main.py
│ └── test_services/
└── pyproject.toml
```

View File

@@ -0,0 +1,434 @@
# API Testing
## Table of Contents
- [FastAPI Testing](#fastapi-testing)
- [Flask Testing](#flask-testing)
- [Test Factories](#test-factories)
- [Authentication Testing](#authentication-testing)
- [Database Testing](#database-testing)
## FastAPI Testing
### Setup
```python
# tests/conftest.py
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlmodel import SQLModel
from app.main import app
from app.core.database import get_session
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
test_engine = create_async_engine(TEST_DATABASE_URL)
TestSession = async_sessionmaker(test_engine, class_=AsyncSession)
@pytest.fixture(scope="function")
async def db_session():
"""Fresh database for each test."""
async with test_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with TestSession() as session:
yield session
async with test_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
@pytest.fixture
async def client(db_session):
"""Test client with dependency override."""
async def override_session():
yield db_session
app.dependency_overrides[get_session] = override_session
async with AsyncClient(
transport=ASGITransport(app=app),
base_url="http://test",
) as ac:
yield ac
app.dependency_overrides.clear()
```
### Basic Endpoint Tests
```python
import pytest
from httpx import AsyncClient
class TestUsers:
@pytest.mark.asyncio
async def test_create_user(self, client: AsyncClient):
response = await client.post(
"/api/users",
json={"email": "test@example.com", "name": "Test User"},
)
assert response.status_code == 201
data = response.json()
assert data["email"] == "test@example.com"
assert "id" in data
@pytest.mark.asyncio
async def test_get_user(self, client: AsyncClient, test_user):
response = await client.get(f"/api/users/{test_user.id}")
assert response.status_code == 200
assert response.json()["id"] == test_user.id
@pytest.mark.asyncio
async def test_get_user_not_found(self, client: AsyncClient):
response = await client.get("/api/users/99999")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_user(self, client: AsyncClient, test_user, auth_headers):
response = await client.patch(
f"/api/users/{test_user.id}",
json={"name": "Updated Name"},
headers=auth_headers,
)
assert response.status_code == 200
assert response.json()["name"] == "Updated Name"
@pytest.mark.asyncio
async def test_delete_user(self, client: AsyncClient, test_user, admin_headers):
response = await client.delete(
f"/api/users/{test_user.id}",
headers=admin_headers,
)
assert response.status_code == 204
```
### Testing Query Parameters
```python
@pytest.mark.asyncio
async def test_list_users_pagination(self, client: AsyncClient, auth_headers):
response = await client.get(
"/api/users",
params={"page": 1, "per_page": 10, "sort": "-created_at"},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "items" in data
assert "total" in data
assert data["page"] == 1
@pytest.mark.asyncio
async def test_search_users(self, client: AsyncClient, auth_headers):
response = await client.get(
"/api/users",
params={"q": "john", "active": True},
headers=auth_headers,
)
assert response.status_code == 200
```
### Testing File Uploads
```python
import io
@pytest.mark.asyncio
async def test_upload_file(self, client: AsyncClient, auth_headers):
file_content = b"Hello, World!"
files = {"file": ("test.txt", io.BytesIO(file_content), "text/plain")}
response = await client.post(
"/api/files/upload",
files=files,
headers=auth_headers,
)
assert response.status_code == 201
assert response.json()["filename"] == "test.txt"
```
## Flask Testing
### Setup
```python
import pytest
from app import create_app
from app.extensions import db
@pytest.fixture
def app():
"""Create application for testing."""
app = create_app("testing")
app.config["TESTING"] = True
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
with app.app_context():
db.create_all()
yield app
db.drop_all()
@pytest.fixture
def client(app):
"""Test client."""
return app.test_client()
@pytest.fixture
def runner(app):
"""CLI test runner."""
return app.test_cli_runner()
```
### Flask Test Examples
```python
def test_create_user(client):
response = client.post(
"/api/users",
json={"email": "test@example.com", "name": "Test"},
)
assert response.status_code == 201
def test_get_user(client, test_user):
response = client.get(f"/api/users/{test_user.id}")
assert response.status_code == 200
assert response.json["email"] == test_user.email
def test_protected_route(client, auth_token):
response = client.get(
"/api/protected",
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == 200
```
## Test Factories
### Using factory_boy
```python
# tests/factories.py
import factory
from faker import Faker
from app.models import User, Project
fake = Faker()
class UserFactory(factory.Factory):
class Meta:
model = User
email = factory.LazyAttribute(lambda _: fake.email())
name = factory.LazyAttribute(lambda _: fake.name())
hashed_password = "hashed_password"
is_active = True
class ProjectFactory(factory.Factory):
class Meta:
model = Project
name = factory.LazyAttribute(lambda _: fake.company())
description = factory.LazyAttribute(lambda _: fake.sentence())
owner_id = None
# Async helper
async def create_user(session, **kwargs) -> User:
user = UserFactory(**kwargs)
session.add(user)
await session.commit()
await session.refresh(user)
return user
async def create_users(session, count: int = 5, **kwargs) -> list[User]:
users = [UserFactory(**kwargs) for _ in range(count)]
session.add_all(users)
await session.commit()
return users
```
### Using Factories in Tests
```python
from tests.factories import create_user, create_users
@pytest.mark.asyncio
async def test_list_users(client, db_session, auth_headers):
# Create 20 users
await create_users(db_session, count=20)
response = await client.get("/api/users", headers=auth_headers)
assert response.status_code == 200
assert response.json()["total"] >= 20
@pytest.mark.asyncio
async def test_filter_active_users(client, db_session, auth_headers):
await create_user(db_session, is_active=True)
await create_user(db_session, is_active=False)
response = await client.get(
"/api/users",
params={"active": True},
headers=auth_headers,
)
data = response.json()
assert all(u["is_active"] for u in data["items"])
```
## Authentication Testing
### Auth Fixtures
```python
from app.core.security import create_access_token, hash_password
@pytest.fixture
async def test_user(db_session):
"""Create a regular test user."""
user = User(
email="test@example.com",
hashed_password=hash_password("password123"),
is_active=True,
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
return user
@pytest.fixture
async def admin_user(db_session):
"""Create an admin user."""
user = User(
email="admin@example.com",
hashed_password=hash_password("admin123"),
is_active=True,
is_admin=True,
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
return user
@pytest.fixture
def auth_headers(test_user):
"""Headers with valid auth token."""
token = create_access_token({"sub": str(test_user.id)})
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def admin_headers(admin_user):
"""Headers with admin auth token."""
token = create_access_token({"sub": str(admin_user.id)})
return {"Authorization": f"Bearer {token}"}
```
### Auth Test Examples
```python
class TestAuth:
@pytest.mark.asyncio
async def test_login_success(self, client, test_user):
response = await client.post(
"/api/auth/login",
data={"username": test_user.email, "password": "password123"},
)
assert response.status_code == 200
assert "access_token" in response.json()
@pytest.mark.asyncio
async def test_login_wrong_password(self, client, test_user):
response = await client.post(
"/api/auth/login",
data={"username": test_user.email, "password": "wrong"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_protected_endpoint_no_auth(self, client):
response = await client.get("/api/protected")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_admin_endpoint_regular_user(self, client, auth_headers):
response = await client.delete(
"/api/admin/users/1",
headers=auth_headers,
)
assert response.status_code == 403
```
## Database Testing
### Transaction Rollback Pattern
```python
@pytest.fixture
async def db_session(engine):
"""Each test runs in a transaction that rolls back."""
async with engine.connect() as conn:
await conn.begin()
async with AsyncSession(bind=conn) as session:
yield session
await conn.rollback()
```
### Testing Repository Layer
```python
from app.repositories import UserRepository
class TestUserRepository:
@pytest.mark.asyncio
async def test_create(self, db_session):
repo = UserRepository(db_session)
user = await repo.create({
"email": "test@example.com",
"hashed_password": "hash",
})
assert user.id is not None
assert user.email == "test@example.com"
@pytest.mark.asyncio
async def test_get_by_email(self, db_session, test_user):
repo = UserRepository(db_session)
user = await repo.get_by_email(test_user.email)
assert user.id == test_user.id
@pytest.mark.asyncio
async def test_soft_delete(self, db_session, test_user):
repo = UserRepository(db_session)
await repo.soft_delete(test_user.id)
user = await repo.get_by_id(test_user.id)
assert user.status == 0 # Soft deleted
```

View File

@@ -0,0 +1,290 @@
# Async Testing
## Table of Contents
- [Setup](#setup)
- [Basic Async Tests](#basic-async-tests)
- [Async Fixtures](#async-fixtures)
- [Testing Async Generators](#testing-async-generators)
- [Timeouts](#timeouts)
## Setup
### Installation
```bash
pip install pytest-asyncio
```
### Configuration
```toml
# pyproject.toml
[tool.pytest.ini_options]
asyncio_mode = "auto" # Automatically handle async tests
# OR
asyncio_mode = "strict" # Require explicit @pytest.mark.asyncio
```
### With pytest.ini
```ini
[pytest]
asyncio_mode = auto
```
## Basic Async Tests
### Simple Async Test
```python
import pytest
# With asyncio_mode = "auto", decorator is optional
async def test_async_function():
result = await async_operation()
assert result == "success"
# With asyncio_mode = "strict", decorator is required
@pytest.mark.asyncio
async def test_async_with_marker():
result = await async_operation()
assert result == "success"
```
### Async Test Class
```python
@pytest.mark.asyncio
class TestAsyncOperations:
async def test_fetch_data(self):
data = await fetch_data()
assert data is not None
async def test_process_data(self):
result = await process_data({"key": "value"})
assert result["processed"] is True
```
### Testing Async Exceptions
```python
import pytest
@pytest.mark.asyncio
async def test_async_exception():
with pytest.raises(ValueError):
await async_function_that_raises()
@pytest.mark.asyncio
async def test_async_exception_message():
with pytest.raises(ValueError, match="Invalid input"):
await validate_input("")
```
## Async Fixtures
### Basic Async Fixture
```python
import pytest
@pytest.fixture
async def async_client():
"""Async fixture for HTTP client."""
async with httpx.AsyncClient() as client:
yield client
@pytest.mark.asyncio
async def test_api_call(async_client):
response = await async_client.get("https://api.example.com/data")
assert response.status_code == 200
```
### Async Database Fixture
```python
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for session-scoped async fixtures."""
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def engine():
"""Session-scoped async engine."""
engine = create_async_engine("postgresql+asyncpg://...")
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(engine):
"""Function-scoped database session with rollback."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with AsyncSession(engine) as session:
yield session
await session.rollback()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
```
### Fixture with Async Setup/Teardown
```python
@pytest.fixture
async def resource():
# Async setup
resource = await create_resource()
await resource.initialize()
yield resource
# Async teardown
await resource.cleanup()
await resource.close()
```
## Testing Async Generators
### Async Generator Function
```python
async def async_data_generator():
for i in range(5):
await asyncio.sleep(0.1)
yield i
@pytest.mark.asyncio
async def test_async_generator():
results = []
async for item in async_data_generator():
results.append(item)
assert results == [0, 1, 2, 3, 4]
```
### Async Context Manager
```python
from contextlib import asynccontextmanager
@asynccontextmanager
async def async_resource():
resource = await create_resource()
try:
yield resource
finally:
await resource.close()
@pytest.mark.asyncio
async def test_async_context_manager():
async with async_resource() as resource:
result = await resource.operation()
assert result is not None
```
## Timeouts
### Test Timeout
```python
import asyncio
import pytest
@pytest.mark.asyncio
@pytest.mark.timeout(5) # 5 second timeout
async def test_slow_operation():
result = await potentially_slow_operation()
assert result is not None
```
### Using asyncio.wait_for
```python
@pytest.mark.asyncio
async def test_with_timeout():
try:
result = await asyncio.wait_for(
slow_operation(),
timeout=2.0
)
assert result is not None
except asyncio.TimeoutError:
pytest.fail("Operation timed out")
```
### Testing Timeout Behavior
```python
@pytest.mark.asyncio
async def test_timeout_is_raised():
"""Verify function properly times out."""
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(
never_completes(),
timeout=0.1
)
```
## Concurrent Tests
### Testing Concurrent Operations
```python
@pytest.mark.asyncio
async def test_concurrent_requests():
async with httpx.AsyncClient() as client:
# Run 10 requests concurrently
tasks = [
client.get(f"https://api.example.com/items/{i}")
for i in range(10)
]
responses = await asyncio.gather(*tasks)
assert all(r.status_code == 200 for r in responses)
```
### Testing Race Conditions
```python
@pytest.mark.asyncio
async def test_counter_thread_safety():
counter = AsyncCounter()
async def increment():
for _ in range(100):
await counter.increment()
# Run 10 concurrent incrementers
await asyncio.gather(*[increment() for _ in range(10)])
assert counter.value == 1000
```
## Event Loop Configuration
### Custom Event Loop
```python
import pytest
import asyncio
@pytest.fixture(scope="session")
def event_loop_policy():
"""Use uvloop for faster async tests."""
import uvloop
return uvloop.EventLoopPolicy()
@pytest.fixture(scope="session")
def event_loop(event_loop_policy):
asyncio.set_event_loop_policy(event_loop_policy)
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
```

View File

@@ -0,0 +1,296 @@
# Fixtures
## Table of Contents
- [Basic Fixtures](#basic-fixtures)
- [Fixture Scopes](#fixture-scopes)
- [Fixture Parameters](#fixture-parameters)
- [Conftest.py](#conftestpy)
- [Built-in Fixtures](#built-in-fixtures)
## Basic Fixtures
### Simple Fixture
```python
import pytest
@pytest.fixture
def user_data():
"""Return sample user data."""
return {
"name": "John Doe",
"email": "john@example.com",
"age": 30,
}
def test_user_name(user_data):
assert user_data["name"] == "John Doe"
def test_user_email(user_data):
assert "@" in user_data["email"]
```
### Fixture with Setup and Teardown
```python
@pytest.fixture
def database():
"""Setup database, yield, then cleanup."""
# Setup
db = create_database()
db.connect()
yield db # Test runs here
# Teardown
db.clear()
db.disconnect()
def test_insert(database):
database.insert({"id": 1, "name": "Test"})
assert database.count() == 1
```
### Fixture Returning Factory
```python
@pytest.fixture
def make_user():
"""Return a factory function for creating users."""
created_users = []
def _make_user(name: str, email: str = None):
user = User(name=name, email=email or f"{name}@example.com")
created_users.append(user)
return user
yield _make_user
# Cleanup all created users
for user in created_users:
user.delete()
def test_multiple_users(make_user):
user1 = make_user("Alice")
user2 = make_user("Bob")
assert user1.name != user2.name
```
## Fixture Scopes
```python
# Function scope (default) - runs for each test
@pytest.fixture(scope="function")
def fresh_data():
return {"count": 0}
# Class scope - runs once per test class
@pytest.fixture(scope="class")
def class_resource():
return ExpensiveResource()
# Module scope - runs once per test module
@pytest.fixture(scope="module")
def module_connection():
conn = create_connection()
yield conn
conn.close()
# Session scope - runs once per test session
@pytest.fixture(scope="session")
def session_config():
return load_config()
```
### Scope Hierarchy
```
session (once per test run)
└── package (once per test package)
└── module (once per test file)
└── class (once per test class)
└── function (once per test function)
```
## Fixture Parameters
### Parametrized Fixtures
```python
@pytest.fixture(params=["sqlite", "postgresql", "mysql"])
def database_type(request):
"""Run tests with each database type."""
return request.param
def test_connection(database_type):
# This test runs 3 times, once for each database
db = create_database(database_type)
assert db.connect()
```
### Fixture with IDs
```python
@pytest.fixture(params=[
pytest.param({"admin": True}, id="admin"),
pytest.param({"admin": False}, id="regular"),
])
def user_config(request):
return request.param
# Test output shows: test_permissions[admin], test_permissions[regular]
```
### Indirect Parametrization
```python
@pytest.fixture
def user(request):
"""Create user based on parameter."""
role = request.param
return User(role=role)
@pytest.mark.parametrize("user", ["admin", "editor", "viewer"], indirect=True)
def test_user_access(user):
# user fixture receives each role as request.param
assert user.role in ["admin", "editor", "viewer"]
```
## Conftest.py
### tests/conftest.py
```python
"""Shared fixtures for all tests."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Session-scoped database engine
@pytest.fixture(scope="session")
def engine():
return create_engine("sqlite:///:memory:")
# Function-scoped session with transaction rollback
@pytest.fixture(scope="function")
def db_session(engine):
connection = engine.connect()
transaction = connection.begin()
Session = sessionmaker(bind=connection)
session = Session()
yield session
session.close()
transaction.rollback()
connection.close()
# Shared test data
@pytest.fixture
def sample_products():
return [
{"name": "Widget", "price": 9.99},
{"name": "Gadget", "price": 19.99},
{"name": "Gizmo", "price": 29.99},
]
```
### Nested Conftest Files
```
tests/
├── conftest.py # Available to all tests
├── unit/
│ ├── conftest.py # Available to unit tests only
│ └── test_models.py
└── integration/
├── conftest.py # Available to integration tests only
└── test_api.py
```
## Built-in Fixtures
### tmp_path / tmp_path_factory
```python
def test_create_file(tmp_path):
"""tmp_path provides unique temporary directory."""
file = tmp_path / "test.txt"
file.write_text("Hello, World!")
assert file.read_text() == "Hello, World!"
@pytest.fixture(scope="session")
def session_temp_dir(tmp_path_factory):
"""Create temp dir for entire session."""
return tmp_path_factory.mktemp("session_data")
```
### capsys / capfd
```python
def test_print_output(capsys):
"""Capture stdout/stderr."""
print("Hello")
captured = capsys.readouterr()
assert captured.out == "Hello\n"
def test_file_descriptor_output(capfd):
"""Capture at file descriptor level."""
import os
os.system("echo 'Hello'")
captured = capfd.readouterr()
assert "Hello" in captured.out
```
### monkeypatch
```python
def test_env_variable(monkeypatch):
"""Modify environment temporarily."""
monkeypatch.setenv("API_KEY", "test-key")
assert os.environ["API_KEY"] == "test-key"
def test_module_attribute(monkeypatch):
"""Modify module attribute temporarily."""
monkeypatch.setattr("module.CONFIG", {"debug": True})
assert module.CONFIG["debug"] is True
def test_dict_item(monkeypatch):
"""Modify dictionary temporarily."""
monkeypatch.setitem(app.settings, "DEBUG", True)
```
### request
```python
@pytest.fixture
def resource(request):
"""Access test context."""
print(f"Running: {request.node.name}")
print(f"Module: {request.module.__name__}")
print(f"Function: {request.function.__name__}")
# Access fixture parameters
if hasattr(request, "param"):
return create_resource(request.param)
return create_default_resource()
```
### Autouse Fixtures
```python
@pytest.fixture(autouse=True)
def setup_logging():
"""Automatically runs for every test."""
logging.basicConfig(level=logging.DEBUG)
yield
logging.shutdown()
@pytest.fixture(autouse=True, scope="session")
def global_setup():
"""Run once at session start."""
initialize_system()
yield
cleanup_system()
```

View File

@@ -0,0 +1,354 @@
# Mocking
## Table of Contents
- [Basic Mocking](#basic-mocking)
- [Patching](#patching)
- [Mock Objects](#mock-objects)
- [Async Mocking](#async-mocking)
- [Fixture-Based Mocking](#fixture-based-mocking)
- [Common Patterns](#common-patterns)
## Basic Mocking
### MagicMock Basics
```python
from unittest.mock import MagicMock
def test_with_mock():
# Create a mock object
mock_service = MagicMock()
# Configure return value
mock_service.get_data.return_value = {"id": 1, "name": "Test"}
# Use mock
result = mock_service.get_data()
assert result["name"] == "Test"
# Verify call
mock_service.get_data.assert_called_once()
```
### Return Value Configuration
```python
from unittest.mock import MagicMock
mock = MagicMock()
# Simple return value
mock.method.return_value = 42
# Different returns on consecutive calls
mock.method.side_effect = [1, 2, 3]
assert mock.method() == 1
assert mock.method() == 2
assert mock.method() == 3
# Raise exception
mock.method.side_effect = ValueError("Error!")
# Dynamic return value
mock.method.side_effect = lambda x: x * 2
assert mock.method(5) == 10
```
## Patching
### patch Decorator
```python
from unittest.mock import patch
# Patch a function
@patch("mymodule.external_api")
def test_with_patched_api(mock_api):
mock_api.return_value = {"status": "ok"}
result = mymodule.call_api()
assert result["status"] == "ok"
# Patch multiple
@patch("mymodule.function_a")
@patch("mymodule.function_b")
def test_multiple_patches(mock_b, mock_a):
# Note: decorators apply bottom-up
mock_a.return_value = "a"
mock_b.return_value = "b"
```
### patch Context Manager
```python
from unittest.mock import patch
def test_with_context_manager():
with patch("mymodule.external_api") as mock_api:
mock_api.return_value = {"data": "mocked"}
result = mymodule.process()
assert result == {"data": "mocked"}
# Original function restored after with block
```
### patch.object
```python
from unittest.mock import patch
class MyService:
def fetch_data(self):
return "real data"
def test_patch_instance_method():
service = MyService()
with patch.object(service, "fetch_data", return_value="mocked"):
assert service.fetch_data() == "mocked"
```
### Patching Where Used
```python
# mymodule.py
from requests import get # 'get' is imported here
def fetch_url(url):
return get(url).text
# test_mymodule.py
# Patch where it's USED, not where it's DEFINED
@patch("mymodule.get") # NOT "requests.get"
def test_fetch_url(mock_get):
mock_get.return_value.text = "mocked response"
result = fetch_url("http://example.com")
assert result == "mocked response"
```
## Mock Objects
### Spec and Autospec
```python
from unittest.mock import MagicMock, create_autospec
class UserService:
def get_user(self, user_id: int) -> dict:
pass
def create_user(self, data: dict) -> dict:
pass
# Basic mock (allows any attribute)
mock = MagicMock()
mock.nonexistent_method() # Works, but shouldn't
# Spec mock (restricts to real attributes)
mock = MagicMock(spec=UserService)
# mock.nonexistent_method() # Raises AttributeError
# Autospec (also checks signatures)
mock = create_autospec(UserService)
# mock.get_user() # Raises TypeError (missing user_id)
mock.get_user(123) # Works
```
### Assertion Methods
```python
from unittest.mock import MagicMock, call
mock = MagicMock()
mock.method(1, 2, key="value")
mock.method(3, 4)
# Verify calls
mock.method.assert_called()
mock.method.assert_called_once() # Fails - called twice
mock.method.assert_called_with(3, 4)
mock.method.assert_any_call(1, 2, key="value")
# Check call count
assert mock.method.call_count == 2
# Check all calls
assert mock.method.call_args_list == [
call(1, 2, key="value"),
call(3, 4),
]
# Reset mock
mock.reset_mock()
assert mock.method.call_count == 0
```
## Async Mocking
### AsyncMock
```python
from unittest.mock import AsyncMock, patch
import pytest
@pytest.mark.asyncio
async def test_async_mock():
mock = AsyncMock(return_value={"data": "mocked"})
result = await mock()
assert result["data"] == "mocked"
@pytest.mark.asyncio
@patch("mymodule.async_fetch", new_callable=AsyncMock)
async def test_patched_async(mock_fetch):
mock_fetch.return_value = {"status": "ok"}
result = await mymodule.process()
assert result["status"] == "ok"
```
### Async Side Effects
```python
from unittest.mock import AsyncMock
@pytest.mark.asyncio
async def test_async_side_effect():
mock = AsyncMock()
# Return different values
mock.side_effect = [1, 2, 3]
assert await mock() == 1
assert await mock() == 2
# Async function as side effect
async def async_side_effect(x):
return x * 2
mock.side_effect = async_side_effect
assert await mock(5) == 10
```
## Fixture-Based Mocking
### Mock Fixtures
```python
import pytest
from unittest.mock import MagicMock, patch
@pytest.fixture
def mock_database():
"""Fixture providing mock database."""
mock_db = MagicMock()
mock_db.query.return_value = [{"id": 1}, {"id": 2}]
return mock_db
def test_with_mock_fixture(mock_database):
result = mock_database.query("SELECT * FROM users")
assert len(result) == 2
@pytest.fixture
def mock_external_api():
"""Fixture with patching."""
with patch("mymodule.external_api") as mock:
mock.return_value = {"status": "ok"}
yield mock
```
### Monkeypatch Fixture
```python
def test_with_monkeypatch(monkeypatch):
# Patch function
monkeypatch.setattr("mymodule.get_config", lambda: {"debug": True})
# Patch environment variable
monkeypatch.setenv("API_KEY", "test-key")
# Patch dictionary item
monkeypatch.setitem(mymodule.settings, "DEBUG", True)
# Patch attribute
monkeypatch.setattr(mymodule.client, "timeout", 5)
```
## Common Patterns
### Mock HTTP Responses
```python
from unittest.mock import patch, MagicMock
@patch("requests.get")
def test_http_request(mock_get):
# Configure mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"data": "test"}
mock_get.return_value = mock_response
result = fetch_data("http://api.example.com")
assert result == {"data": "test"}
mock_get.assert_called_once_with("http://api.example.com")
```
### Mock File Operations
```python
from unittest.mock import mock_open, patch
def test_read_file():
mock_file_content = "Hello, World!"
with patch("builtins.open", mock_open(read_data=mock_file_content)):
result = read_file("test.txt")
assert result == "Hello, World!"
def test_write_file():
m = mock_open()
with patch("builtins.open", m):
write_file("test.txt", "content")
m().write.assert_called_once_with("content")
```
### Mock Datetime
```python
from unittest.mock import patch
from datetime import datetime
@patch("mymodule.datetime")
def test_with_frozen_time(mock_datetime):
mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 0, 0)
mock_datetime.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs)
result = mymodule.get_timestamp()
assert result == "2024-01-15 12:00:00"
```
### Mock Class Instance
```python
from unittest.mock import patch, MagicMock
class EmailService:
def send(self, to, subject, body):
# Real implementation
pass
@patch("mymodule.EmailService")
def test_email_sent(MockEmailService):
# Configure mock instance
mock_instance = MagicMock()
MockEmailService.return_value = mock_instance
# Call function that uses EmailService
send_notification("user@example.com", "Hello!")
# Verify email was sent
mock_instance.send.assert_called_once_with(
"user@example.com",
"Notification",
"Hello!",
)
```

View File

@@ -0,0 +1,399 @@
---
name: write-unit-tests
description: Writing unit and integration tests for the tldraw SDK. Use when creating new tests, adding test coverage, or fixing failing tests in packages/editor or packages/tldraw. Covers Vitest patterns, TestEditor usage, and test file organization.
---
# Writing tests
Unit and integration tests use Vitest. Tests run from workspace directories, not the repo root.
## Test file locations
**Unit tests** - alongside source files:
```
packages/editor/src/lib/primitives/Vec.ts
packages/editor/src/lib/primitives/Vec.test.ts # Same directory
```
**Integration tests** - in `src/test/` directory:
```
packages/tldraw/src/test/SelectTool.test.ts
packages/tldraw/src/test/commands/createShape.test.ts
```
**Shape/tool tests** - alongside the implementation:
```
packages/tldraw/src/lib/shapes/arrow/ArrowShapeUtil.test.ts
packages/tldraw/src/lib/shapes/arrow/ArrowShapeTool.test.ts
```
## Which workspace to test in
- **packages/editor**: Core primitives, geometry, managers, base editor functionality
- **packages/tldraw**: Anything needing default shapes/tools (most integration tests)
```bash
cd packages/tldraw && yarn test run
cd packages/tldraw && yarn test run --grep "SelectTool"
```
## TestEditor vs Editor
Use `TestEditor` for integration tests (includes default shapes/tools):
```typescript
import { createShapeId } from '@tldraw/editor'
import { TestEditor } from './TestEditor'
let editor: TestEditor
beforeEach(() => {
editor = new TestEditor()
editor.selectAll().deleteShapes(editor.getSelectedShapeIds())
})
afterEach(() => {
editor?.dispose()
})
```
Use raw `Editor` when testing editor setup or custom configurations:
```typescript
import { Editor, createTLStore } from '@tldraw/editor'
beforeEach(() => {
editor = new Editor({
shapeUtils: [CustomShape],
bindingUtils: [],
tools: [CustomTool],
store: createTLStore({ shapeUtils: [CustomShape], bindingUtils: [] }),
getContainer: () => document.body,
})
})
```
## Common TestEditor methods
```typescript
// Pointer simulation
editor.pointerDown(x, y, options?)
editor.pointerMove(x, y, options?)
editor.pointerUp(x, y, options?)
editor.click(x, y, shapeId?)
editor.doubleClick(x, y, shapeId?)
// Keyboard simulation
editor.keyDown(key, options?)
editor.keyUp(key, options?)
// State assertions
editor.expectToBeIn('select.idle')
editor.expectToBeIn('select.crop.idle')
// Shape assertions
editor.expectShapeToMatch({ id, x, y, props: { ... } })
// Shape operations
editor.createShapes([{ id, type, x, y, props }])
editor.updateShapes([{ id, type, props }])
editor.getShape(id)
editor.select(id1, id2)
editor.selectAll()
editor.selectNone()
editor.getSelectedShapeIds()
editor.getOnlySelectedShape()
// Tool operations
editor.setCurrentTool('arrow')
editor.getCurrentToolId()
// Undo/redo
editor.undo()
editor.redo()
```
## Pointer event options
```typescript
editor.pointerDown(100, 100, {
target: 'shape', // 'canvas' | 'shape' | 'handle' | 'selection'
shape: editor.getShape(id),
})
editor.pointerDown(150, 300, {
target: 'selection',
handle: 'bottom', // 'top' | 'bottom' | 'left' | 'right' | corners
})
editor.doubleClick(550, 550, {
target: 'selection',
handle: 'bottom_right',
})
```
## Setup patterns
### Standard setup with shape IDs
```typescript
const ids = {
box1: createShapeId('box1'),
box2: createShapeId('box2'),
arrow1: createShapeId('arrow1'),
}
vi.useFakeTimers()
beforeEach(() => {
editor = new TestEditor()
editor.selectAll().deleteShapes(editor.getSelectedShapeIds())
editor.createShapes([
{ id: ids.box1, type: 'geo', x: 100, y: 100, props: { w: 100, h: 100 } },
{ id: ids.box2, type: 'geo', x: 300, y: 300, props: { w: 100, h: 100 } },
])
})
afterEach(() => {
editor?.dispose()
})
```
### Reusable props
```typescript
const imageProps = {
assetId: null,
playing: true,
url: '',
w: 1200,
h: 800,
}
editor.createShapes([
{ id: ids.imageA, type: 'image', x: 100, y: 100, props: imageProps },
{ id: ids.imageB, type: 'image', x: 500, y: 500, props: { ...imageProps, w: 600, h: 400 } },
])
```
### Helper functions
```typescript
function arrow(id = ids.arrow1) {
return editor.getShape(id) as TLArrowShape
}
function bindings(id = ids.arrow1) {
return getArrowBindings(editor, arrow(id))
}
```
## Mocking with vi.spyOn
```typescript
// Mock return value
vi.spyOn(editor, 'getIsReadonly').mockReturnValue(true)
// Mock implementation
const isHiddenSpy = vi.spyOn(editor, 'isShapeHidden')
isHiddenSpy.mockImplementation((shape) => shape.id === ids.hiddenShape)
// Verify calls
const spy = vi.spyOn(editor, 'setSelectedShapes')
editor.selectAll()
expect(spy).toHaveBeenCalled()
expect(spy).not.toHaveBeenCalled()
// Always restore
isHiddenSpy.mockRestore()
```
## Fake timers
```typescript
vi.useFakeTimers()
// Mock animation frame
window.requestAnimationFrame = (cb) => setTimeout(cb, 1000 / 60)
window.cancelAnimationFrame = (id) => clearTimeout(id)
it('handles animation', () => {
editor.alignShapes(editor.getSelectedShapeIds(), 'right')
vi.advanceTimersByTime(1000)
// Assert after animation completes
})
```
## Assertions
### Shape matching
```typescript
// Partial matching (most common)
expect(editor.getShape(id)).toMatchObject({
type: 'geo',
x: 100,
props: { w: 100 },
})
editor.expectShapeToMatch({
id: ids.box1,
x: 350,
y: 350,
})
// Floating point matching (custom matcher)
expect(result).toCloselyMatchObject({
props: { normalizedAnchor: { x: 0.5, y: 0.75 } },
})
```
### Array assertions
```typescript
expect(editor.getSelectedShapeIds()).toMatchObject([ids.box1])
expect(Array.from(selectedIds).sort()).toEqual([id1, id2, id3].sort())
expect(shapes).toContain('geo')
expect(shapes).not.toContain(ids.lockedShape)
```
### State assertions
```typescript
editor.expectToBeIn('select.idle')
editor.expectToBeIn('select.brushing')
editor.expectToBeIn('select.crop.idle')
```
## Testing undo/redo
```typescript
it('handles undo/redo', () => {
editor.doubleClick(550, 550, ids.image)
editor.expectToBeIn('select.crop.idle')
editor.updateShape({ id: ids.image, type: 'image', props: { crop: newCrop } })
editor.undo()
editor.expectToBeIn('select.crop.idle')
expect(editor.getShape(ids.image)!.props.crop).toMatchObject(originalCrop)
editor.redo()
expect(editor.getShape(ids.image)!.props.crop).toMatchObject(newCrop)
})
```
## Testing TypeScript types
```typescript
it('Uses typescript generics', () => {
expect(() => {
// @ts-expect-error - wrong props type
editor.createShape({ id, type: 'geo', props: { w: 'OH NO' } })
// @ts-expect-error - unknown prop
editor.createShape({ id, type: 'geo', props: { foo: 'bar' } })
// Valid
editor.createShape<TLGeoShape>({ id, type: 'geo', props: { w: 100 } })
}).toThrow()
})
```
## Testing custom shapes
```typescript
declare module '@tldraw/tlschema' {
export interface TLGlobalShapePropsMap {
'my-custom-shape': { w: number; h: number; text: string | undefined }
}
}
class CustomShape extends ShapeUtil<ICustomShape> {
static override type = 'my-custom-shape'
static override props: RecordProps<ICustomShape> = {
w: T.number,
h: T.number,
text: T.string.optional(),
}
getDefaultProps() {
return { w: 200, h: 200, text: '' }
}
getGeometry(shape) {
return new Rectangle2d({ width: shape.props.w, height: shape.props.h })
}
indicator() {}
component() {}
}
```
## Testing side effects
```typescript
beforeEach(() => {
editor = new TestEditor()
editor.sideEffects.registerAfterChangeHandler('instance_page_state', (prev, next) => {
if (prev.croppingShapeId !== next.croppingShapeId) {
// Handle state change
}
})
})
```
## Testing events
```typescript
it('emits wheel events', () => {
const handler = vi.fn()
editor.on('event', handler)
editor.dispatch({
type: 'wheel',
name: 'wheel',
delta: { x: 0, y: 10, z: 0 },
point: { x: 100, y: 100, z: 1 },
shiftKey: false,
// ... other modifiers
})
editor.emit('tick', 16) // Flush batched events
expect(handler).toHaveBeenCalledWith(expect.objectContaining({ name: 'wheel' }))
})
```
## Method chaining
```typescript
editor
.expectToBeIn('select.idle')
.select(ids.imageA, ids.imageB)
.doubleClick(550, 550, { target: 'selection', handle: 'bottom_right' })
.expectToBeIn('select.idle')
editor.setCurrentTool('arrow').pointerDown(0, 0).pointerMove(100, 100).pointerUp()
```
## Running tests
```bash
cd packages/tldraw && yarn test run
cd packages/tldraw && yarn test run --grep "arrow"
cd packages/editor && yarn test run --grep "Vec"
# Watch mode
cd packages/tldraw && yarn test
```
## Key patterns summary
- Use `createShapeId()` for shape IDs
- Use `vi.useFakeTimers()` for time-dependent behavior
- Clear shapes in `beforeEach`, dispose in `afterEach`
- Test in `packages/tldraw` for shapes/tools
- Use `expectToBeIn()` for state machine assertions
- Use `toMatchObject()` for partial matching
- Use `toCloselyMatchObject()` for floating point values
- Mock with `vi.spyOn()` and always `mockRestore()`

View File

@@ -0,0 +1 @@
../../.agents/skills/api-documentation-generator

1
.claude/skills/pytest Symbolic link
View File

@@ -0,0 +1 @@
../../.agents/skills/pytest

View File

@@ -127,7 +127,7 @@ spec:
rules:
- http:
paths:
- path: /ailbl
- path: /ai
pathType: Prefix
backend:
service:

View File

@@ -1,7 +1,7 @@
name: "K8S Fission Deployment"
on:
push:
branches: [ 'main', 'ai' ]
branches: ["main"]
jobs:
deployment-fission:
name: Deployment fission functions
@@ -12,67 +12,75 @@ jobs:
FISSION_VER: 1.21.0
RAKE_VER: 0.1.7
steps:
- name: ☸️ Setup kubectl
uses: azure/setup-kubectl@v4
- name: 🔄 Cache
id: cache
uses: actions/cache@v4
with:
path: |
/usr/local/bin/rake
/usr/local/bin/fission
key: ${{ runner.os }}-${{ github.event.repository.name }}-${{ hashFiles('.fission/deployment.json') }}
- name: ☘️ Configure Kubeconfig
uses: azure/k8s-set-context@v4
with:
method: kubeconfig
kubeconfig: ${{ secrets[format('{0}_KUBECONFIG', env.FISSION_PROFILE)] }}
- name: 🔄 Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
curl -L "https://${{ secrets.REGISTRY_PASSWORD }}@registry.vegastar.vn/vegacloud/make/releases/download/${RAKE_VER}/rake-${RAKE_VER}-x86_64-unknown-linux-musl.tar.gz" | tar xzv -C /tmp/
curl -L "https://github.com/fission/fission/releases/download/v${FISSION_VER}/fission-v${FISSION_VER}-linux-amd64" --output /tmp/fission
install -o root -g root -m 0755 /tmp/rake-${RAKE_VER}-x86_64-unknown-linux-musl/rake /usr/local/bin/rake
install -o root -g root -m 0755 /tmp/fission /usr/local/bin/fission
fission check
# rake cfg install fission -f
- name: 🕓 Checkout the previous codes
uses: actions/checkout@v4
with:
ref: ${{ github.event.before }}
- name: ♻️ Remove the previous version
# continue-on-error: true
run: |
echo "use profile [$FISSION_PROFILE]"
mkdir -p manifests || true
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
rake sp build -fi && rake sp down -i
- name: 🔎 Checkout repository
uses: actions/checkout@v4
- name: 🐍 Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: 🧪 Run tests
run: |
cd apps
pip install -r requirements-dev.txt
pytest tests/ -v --tb=short
- name: ✨ Deploy the new version
id: deploy
run: |
echo "use profile [$FISSION_PROFILE]"
mkdir -p manifests || true
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
rake sp build -fi && rake sp up -i
- name: 🔔 Send notification
uses: appleboy/telegram-action@master
if: always()
with:
to: ${{ secrets.TELEGRAM_TO }}
token: ${{ secrets.TELEGRAM_TOKEN }}
format: markdown
socks5: ${{ secrets.TELEGRAM_PROXY_URL != '' && secrets.TELEGRAM_PROXY_URL || '' }}
message: |
${{ steps.deploy.outcome == 'success' && '🟢 (=^ ◡ ^=)' || '🔴 (。•́︿•̀。)' }} Install fn ${{ github.event.repository.name }}
*Msg*: `${{ github.event.commits[0].message }}`
- name: ☸️ Setup kubectl
uses: azure/setup-kubectl@v4
- name: 🐍 Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache-dependency-path: "apps/requirements.txt"
- name: 📥 Checkout repository
uses: actions/checkout@v4
- name: 📦 Install dependencies and run tests
run: |
python -m pip install --upgrade pip
pip install -r apps/requirements.txt
pip install pytest pytest-mock
pytest apps/tests/ -v --tb=short
- name: 🔄 Cache
id: cache
uses: actions/cache@v4
with:
path: |
/usr/local/bin/rake
/usr/local/bin/fission
key: ${{ runner.os }}-${{ github.event.repository.name }}-${{ hashFiles('.fission/deployment.json') }}
- name: ☘️ Configure Kubeconfig
uses: azure/k8s-set-context@v4
with:
method: kubeconfig
kubeconfig: ${{ secrets[format('{0}_KUBECONFIG', env.FISSION_PROFILE)] }}
- name: 🔄 Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
curl -L "https://${{ secrets.REGISTRY_PASSWORD }}@registry.vegastar.vn/vegacloud/make/releases/download/${RAKE_VER}/rake-${RAKE_VER}-x86_64-unknown-linux-musl.tar.gz" | tar xzv -C /tmp/
curl -L "https://github.com/fission/fission/releases/download/v${FISSION_VER}/fission-v${FISSION_VER}-linux-amd64" --output /tmp/fission
install -o root -g root -m 0755 /tmp/rake-${RAKE_VER}-x86_64-unknown-linux-musl/rake /usr/local/bin/rake
install -o root -g root -m 0755 /tmp/fission /usr/local/bin/fission
fission check
# rake cfg install fission -f
- name: 🕓 Checkout the previous codes
uses: actions/checkout@v4
with:
ref: ${{ github.event.before }}
- name: ♻️ Remove the previous version
# continue-on-error: true
run: |
echo "use profile [$FISSION_PROFILE]"
mkdir -p manifests || true
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
rake sp build -fi && rake sp down -i
- name: 🔎 Checkout repository
uses: actions/checkout@v4
- name: ✨ Deploy the new version
id: deploy
run: |
echo "use profile [$FISSION_PROFILE]"
mkdir -p manifests || true
rake sec detail && rake cfm detail && rake env detail && rake pkg detail && rake fn detail && rake ht detail
rake sp build -fi && rake sp up -i
- name: 🔔 Send notification
uses: appleboy/telegram-action@master
if: always()
with:
to: ${{ secrets.TELEGRAM_TO }}
token: ${{ secrets.TELEGRAM_TOKEN }}
format: markdown
socks5: ${{ secrets.TELEGRAM_PROXY_URL != '' && secrets.TELEGRAM_PROXY_URL || '' }}
message: |
${{ steps.deploy.outcome == 'success' && '🟢 (=^ ◡ ^=)' || '🔴 (。•́︿•̀。)' }} Install fn ${{ github.event.repository.name }}
*Msg*: `${{ github.event.commits[0].message }}`

View File

@@ -18,6 +18,24 @@ jobs:
run: echo "K8S_PROFILE=`echo ${GITHUB_REF_NAME:-${GITHUB_REF#refs/heads/}} | tr '[:lower:]' '[:upper:]'`" >> $GITHUB_ENV
- name: ☸️ Setup kubectl
uses: azure/setup-kubectl@v4
- name: 🐍 Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache-dependency-path: 'apps/requirements.txt'
- name: 📥 Checkout repository
uses: actions/checkout@v4
- name: 📦 Install dependencies and run tests
run: |
cd apps
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-mock
pytest tests/ -v --tb=short
- name: 🛠️ Configure Kubeconfig
uses: azure/k8s-set-context@v4
with:

File diff suppressed because it is too large Load Diff

View File

@@ -1,79 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
This is a Fission-based serverless Python microservice for AI user administration. It runs on Kubernetes with PostgreSQL as the data store.
**Stack:** Python 3.10, Flask, Fission FaaS, PostgreSQL, Pydantic
## Build & Deploy Commands
```bash
# Build (installs Python dependencies)
cd apps && ./build.sh
# Deploy to Kubernetes (reconciles all Fission specs)
fission spec apply
# Watch and redeploy on changes
fission spec apply --watch
```
## Architecture
Two Fission functions handle all HTTP endpoints:
| Function | Routes | Operations |
|----------|--------|------------|
| `ai-admin-filter-create-user` | `GET/POST /ai/admin/users` | Filter users with pagination, create new user |
| `ai-admin-update-delete-user` | `PUT/DELETE /ai/admin/users/{UserID}` | Update/delete user by ID |
**Source structure (`apps/`):**
- `filter_insert.py` - GET (filter) & POST (create) handler with `main()` entry point
- `update_delete.py` - PUT (update) & DELETE handler with `main()` entry point
- `schemas.py` - Pydantic models (`AiUserCreate`, `AiUserUpdate`) for validation
- `helpers.py` - Database connection, K8s secrets, CORS headers, utilities
- `vault.py` - PyNaCl symmetric encryption (`encrypt_vault`/`decrypt_vault`)
**Deployment configs (`.fission/`):**
- `local-deployment.json`, `dev-deployment.json`, `test-deployment.json`, `staging-deployment.json`, `deployment.json`
**Fission specs (`specs/`):**
- `env-work-py.yaml` - Python 3.10 runtime environment
- `package-ai-work.yaml` - Build configuration
- Function and HTTP trigger definitions
## Key Patterns
**Error handling:**
```python
try:
conn = init_db_connection()
# operations
except ValidationError as e:
return jsonify({"errorCode": "VALIDATION_ERROR", "details": e.errors()}), 400, CORS_HEADERS
except IntegrityError:
return jsonify({"errorCode": "DUPLICATE_TAG", ...}), 409, CORS_HEADERS
finally:
if conn:
conn.close()
```
**Dynamic SQL filtering:** Build conditions list and values dict, join with AND for WHERE clause.
**Fission route params:** Extracted from headers (e.g., `X-Fission-Params-UserID`).
**Concurrent updates:** Uses PostgreSQL row-level locking (`FOR UPDATE`).
## Secrets & Configuration
Secrets are read from K8s mounted volumes via `helpers.get_secret()` and `helpers.get_config()`, not environment variables. PostgreSQL credentials come from the `fission-ai-work-env` secret.
## Function Configuration
- Executor: `newdeploy` (dedicated pod per function)
- Timeout: 300 seconds
- Min/Max Scale: 1
- Concurrency: 500 requests per pod

View File

@@ -24,7 +24,7 @@ def main():
"fntimeout": 300,
"http_triggers": {
"ai-admin-filter-create-user-http": {
"url": "/ailbl/ai/admin/users",
"url": "/ai/admin/users",
"methods": ["POST", "GET"]
}
}
@@ -44,40 +44,34 @@ def main():
def make_insert_request():
r"""make_insert_request() -> tuple[Response, int, dict]
"""
Handle POST request to create a new AI user.
Create a new user from the request JSON body.
Validates the request body using :class:`AiUserCreate` schema, generates
a UUID for the new user, and inserts the record into the database.
Validates the request body using AiUserCreate schema, inserts a new record
into the public.ai_user table, and returns the created user data.
Returns:
tuple: A tuple containing:
- JSON response with created user data or error details
- HTTP status code (201 on success, 400/409/500 on error)
- CORS headers dict
Raises:
ValidationError: If request body fails Pydantic validation (returns 400).
IntegrityError: If email already exists in database (returns 409).
Example::
>>> # POST /ai/admin/users
>>> # Body: {"name": "John Doe", "email": "john@example.com"}
>>> # Response: 201 Created
>>> {
... "id": "550e8400-e29b-41d4-a716-446655440000",
... "name": "John Doe",
... "email": "john@example.com",
... "created": "2024-01-01T10:00:00",
... "modified": "2024-01-01T10:00:00"
... }
tuple: (json_response, status_code, headers)
- 201: User created successfully
- 400: Validation error in request body
- 409: Duplicate entry violation
- 500: Internal server error
"""
try:
body = AiUserCreate(**(request.get_json(silent=True) or {}))
except ValidationError as e:
return jsonify({"errorCode": "VALIDATION_ERROR", "details": e.errors()}), 400, CORS_HEADERS
errors = []
for err in e.errors():
errors.append({
"loc": err.get("loc", []),
"msg": err.get("msg", ""),
"type": err.get("type", ""),
})
return (
jsonify({"errorCode": "VALIDATION_ERROR", "details": errors}),
400,
CORS_HEADERS,
)
sql = """
INSERT INTO public.ai_user (id, name, dob, email, gender)
@@ -89,13 +83,19 @@ def make_insert_request():
conn = init_db_connection()
with conn:
with conn.cursor() as cur:
cur.execute(sql, (str(uuid.uuid4()), body.name,
body.dob, body.email, body.gender))
cur.execute(
sql,
(str(uuid.uuid4()), body.name, body.dob, body.email, body.gender),
)
row = cur.fetchone()
return jsonify(db_row_to_dict(cur, row)), 201, CORS_HEADERS
except IntegrityError as e:
# vi phạm unique(tag,kind,ref)
return jsonify({"errorCode": "DUPLICATE_TAG", "details": str(e)}), 409, CORS_HEADERS
return (
jsonify({"errorCode": "DUPLICATE_TAG", "details": str(e)}),
409,
CORS_HEADERS,
)
except Exception as err:
return jsonify({"error": str(err)}), 500, CORS_HEADERS
finally:
@@ -104,43 +104,16 @@ def make_insert_request():
def make_filter_request():
r"""make_filter_request() -> Response
"""
Handle GET request to filter and list AI users.
Filter and paginate users based on query parameters.
Builds a dynamic SQL query based on filter parameters from the request,
executes the query with pagination, and returns the matching users.
Query Parameters:
page (int): Page number (0-indexed). Default: ``0``
size (int): Number of records per page. Default: ``8``
sortby (str): Field to sort by (``created`` or ``modified``).
asc (bool): Sort ascending if ``True``, descending if ``False``.
filter[ids] (list): Filter by specific user IDs.
filter[keyword] (str): Search in name and email fields.
filter[name] (str): Filter by name (case-insensitive partial match).
filter[email] (str): Filter by email (case-insensitive partial match).
filter[created_from] (str): Filter by creation date (from).
filter[created_to] (str): Filter by creation date (to).
filter[dob_from] (str): Filter by date of birth (from).
filter[dob_to] (str): Filter by date of birth (to).
Retrieves pagination parameters from request queries, executes the filter
query against the database, and returns a list of matching users.
Returns:
Response: JSON array of user records with pagination metadata.
Example::
>>> # GET /ai/admin/users?page=0&size=10&filter[name]=john
>>> # Response: 200 OK
>>> [
... {
... "id": "550e8400-e29b-41d4-a716-446655440000",
... "name": "John Doe",
... "email": "john@example.com",
... "count": 100,
... "total": 5
... }
... ]
tuple: (json_response, status_code, headers)
- 200: Successfully retrieved users list
- 500: Internal server error
"""
paging = UserPage.from_request_queries()
@@ -157,18 +130,18 @@ def make_filter_request():
def __filter_users(cursor, paging: "UserPage"):
r"""Build and execute SQL query for filtering users.
"""
Build and execute SQL query to filter users based on pagination and filter criteria.
Constructs dynamic WHERE clause from UserFilter attributes, applies sorting
and pagination, and returns matching database records.
Args:
cursor: Database cursor object for executing queries.
paging (UserPage): Pagination and filter parameters.
cursor: Database cursor for executing queries.
paging: UserPage object containing pagination and filter parameters.
Returns:
list[dict]: List of user records as dictionaries, including
``count`` (total records) and ``total`` (filtered count).
Note:
This is a private function. Use :func:`make_filter_request` instead.
list: List of user records matching the filter criteria.
"""
conditions = []
values = {}
@@ -245,20 +218,21 @@ def __filter_users(cursor, paging: "UserPage"):
@dataclasses.dataclass
class Page:
r"""Base pagination parameters for list queries.
Attributes:
page (int, optional): Page number (0-indexed). Default: ``0``
size (int, optional): Number of records per page. Default: ``8``
asc (bool, optional): Sort order. ``True`` for ascending, ``False`` for descending.
"""
page: typing.Optional[int] = None
size: typing.Optional[int] = None
asc: typing.Optional[bool] = None
@classmethod
def from_request_queries(cls) -> "Page":
"""
Create Page instance from HTTP request query parameters.
Extracts 'page', 'size', and 'asc' parameters from the request URL.
Defaults: page=0, size=8, asc=None.
Returns:
Page: A new Page instance with values from query parameters.
"""
paging = Page()
paging.page = int(request.args.get("page", 0))
paging.size = int(request.args.get("size", 8))
@@ -268,22 +242,6 @@ class Page:
@dataclasses.dataclass
class UserFilter:
r"""Filter parameters for user queries.
Attributes:
ids (list[str], optional): Filter by specific user IDs.
keyword (str, optional): Search keyword for name and email fields.
name (str, optional): Filter by name (case-insensitive partial match).
email (str, optional): Filter by email (case-insensitive partial match).
gender (str, optional): Filter by gender.
created_from (str, optional): Filter users created on or after this date.
created_to (str, optional): Filter users created on or before this date.
modified_from (str, optional): Filter users modified on or after this date.
modified_to (str, optional): Filter users modified on or before this date.
dob_from (str, optional): Filter users with DOB on or after this date.
dob_to (str, optional): Filter users with DOB on or before this date.
"""
ids: typing.Optional[typing.List[str]] = None
keyword: typing.Optional[str] = None
name: typing.Optional[str] = None
@@ -298,6 +256,15 @@ class UserFilter:
@classmethod
def from_request_queries(cls) -> "UserFilter":
"""
Create UserFilter instance from HTTP request query parameters.
Extracts filter parameters with 'filter[...]' prefix from the request URL.
Supports filtering by ids, keyword, name, email, gender, date ranges.
Returns:
UserFilter: A new UserFilter instance with values from query parameters.
"""
filter = UserFilter()
filter.ids = request.args.getlist("filter[ids]")
filter.keyword = request.args.get("filter[keyword]")
@@ -314,39 +281,12 @@ class UserFilter:
class UserSortField(str, enum.Enum):
r"""Allowed sort fields for user queries.
Attributes:
CREATED: Sort by creation timestamp.
MODIFIED: Sort by last modification timestamp.
"""
CREATED = "created"
MODIFIED = "modified"
@dataclasses.dataclass
class UserPage(Page):
r"""Pagination parameters with user-specific filters and sorting.
Extends :class:`Page` with user filtering and sorting capabilities.
Attributes:
sortby (UserSortField, optional): Field to sort results by.
See :class:`UserSortField` for allowed values.
filter (UserFilter, optional): Filter parameters for the query.
Default: Parsed from request query parameters.
Example::
>>> # Parse from request: GET /users?page=1&size=20&sortby=created&asc=true
>>> paging = UserPage.from_request_queries()
>>> paging.page
1
>>> paging.sortby
<UserSortField.CREATED: 'created'>
"""
sortby: typing.Optional[UserSortField] = None
filter: typing.Optional[UserFilter] = dataclasses.field(
default_factory=UserFilter.from_request_queries
@@ -354,6 +294,15 @@ class UserPage(Page):
@classmethod
def from_request_queries(cls) -> "UserPage":
"""
Create UserPage instance from HTTP request query parameters.
Combines pagination (page, size, asc) and filter parameters from the request URL.
Also parses 'sortby' parameter to UserSortField enum.
Returns:
UserPage: A new UserPage instance with all query parameters.
"""
base = super(UserPage, cls).from_request_queries()
paging = UserPage(**dataclasses.asdict(base))

View File

@@ -19,31 +19,6 @@ logger = logging.getLogger(__name__)
def init_db_connection():
r"""init_db_connection() -> psycopg2.extensions.connection
Initialize and return a PostgreSQL database connection.
Reads connection parameters from Kubernetes secrets and establishes
a connection to the PostgreSQL database with logging enabled.
Returns:
psycopg2.extensions.connection: Active database connection with
:class:`LoggingConnection` factory for query logging.
Raises:
Exception: If database host/port is unreachable.
Note:
Connection parameters are read from K8s secrets:
- ``PG_HOST``: Database host (default: ``127.0.0.1``)
- ``PG_PORT``: Database port (default: ``5432``)
- ``PG_DB``: Database name (default: ``postgres``)
- ``PG_USER``: Database user (default: ``postgres``)
- ``PG_PASS``: Database password (default: ``secret``)
.. warning::
Caller is responsible for closing the connection after use.
"""
db_host = get_secret("PG_HOST", "127.0.0.1")
db_port = int(get_secret("PG_PORT", 5432))
@@ -68,28 +43,6 @@ def init_db_connection():
def db_row_to_dict(cursor, row):
r"""db_row_to_dict(cursor, row) -> dict
Convert a database row tuple to a dictionary.
Uses cursor description to map column names to values. Automatically
converts :class:`datetime.datetime` values to ISO format strings.
Args:
cursor: Database cursor with ``description`` attribute containing
column metadata.
row (tuple): Row tuple from ``cursor.fetchone()`` or similar.
Returns:
dict: Dictionary with column names as keys and row values.
Example::
>>> cursor.execute("SELECT id, name, created FROM users")
>>> row = cursor.fetchone()
>>> db_row_to_dict(cursor, row)
{'id': '123', 'name': 'John', 'created': '2024-01-01T10:00:00'}
"""
record = {}
for i, column in enumerate(cursor.description):
data = row[i]
@@ -100,34 +53,10 @@ def db_row_to_dict(cursor, row):
def db_rows_to_array(cursor, rows):
r"""db_rows_to_array(cursor, rows) -> list[dict]
Convert multiple database rows to a list of dictionaries.
Args:
cursor: Database cursor with ``description`` attribute.
rows (list[tuple]): List of row tuples from ``cursor.fetchall()``.
Returns:
list[dict]: List of dictionaries, one per row.
See Also:
:func:`db_row_to_dict` for single row conversion.
"""
return [db_row_to_dict(cursor, row) for row in rows]
def get_current_namespace() -> str:
r"""get_current_namespace() -> str
Get the current Kubernetes namespace.
Reads namespace from the K8s service account file. Falls back to
``default`` namespace if file is not accessible.
Returns:
str: Current K8s namespace name.
"""
try:
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") as f:
namespace = f.read()
@@ -138,20 +67,6 @@ def get_current_namespace() -> str:
def get_secret(key: str, default=None):
r"""get_secret(key, default=None) -> str | None
Read a secret value from Kubernetes mounted volume.
Args:
key (str): Secret key name (e.g., ``"PG_HOST"``).
default: Value to return if secret is not found. Default: ``None``
Returns:
str | None: Secret value or default if not found.
Note:
Secrets are mounted at ``/secrets/{namespace}/fission-ai-work-env/{key}``
"""
namespace = get_current_namespace()
path = f"/secrets/{namespace}/{SECRET_NAME}/{key}"
try:
@@ -163,20 +78,6 @@ def get_secret(key: str, default=None):
def get_config(key: str, default=None):
r"""get_config(key, default=None) -> str | None
Read a config value from Kubernetes ConfigMap mounted volume.
Args:
key (str): Config key name.
default: Value to return if config is not found. Default: ``None``
Returns:
str | None: Config value or default if not found.
Note:
ConfigMaps are mounted at ``/configs/{namespace}/fission-ai-work-config/{key}``
"""
namespace = get_current_namespace()
path = f"/configs/{namespace}/{CONFIG_NAME}/{key}"
try:
@@ -188,26 +89,6 @@ def get_config(key: str, default=None):
def str_to_bool(input: str | None) -> bool:
r"""str_to_bool(input) -> bool | None
Convert a string to boolean value.
Args:
input (str | None): String to convert. Case-insensitive.
Returns:
bool | None: ``True`` for ``"true"``, ``False`` for ``"false"``,
``None`` for any other value.
Example::
>>> str_to_bool("true")
True
>>> str_to_bool("FALSE")
False
>>> str_to_bool("yes")
None
"""
input = input or ""
# Dictionary to map string values to boolean
BOOL_MAP = {"true": True, "false": False}
@@ -215,25 +96,6 @@ def str_to_bool(input: str | None) -> bool:
def check_port_open(ip: str, port: int, timeout: int = 30):
r"""check_port_open(ip, port, timeout=30) -> bool
Check if a TCP port is open and accepting connections.
Args:
ip (str): IP address or hostname to check.
port (int): Port number to check.
timeout (int): Connection timeout in seconds. Default: ``30``
Returns:
bool: ``True`` if port is open, ``False`` otherwise.
Example::
>>> check_port_open("127.0.0.1", 5432)
True
>>> check_port_open("localhost", 9999, timeout=5)
False
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(timeout)

View File

@@ -1,5 +0,0 @@
pytest==8.3.5
pytest-mock==3.14.0
flask==3.1.0
psycopg2-binary==2.9.10
pydantic==2.10.6

View File

@@ -2,3 +2,5 @@ psycopg2-binary==2.9.10
pydantic==2.11.7
PyNaCl==1.6.0
Flask==3.1.0
pytest==8.3.4
pytest-mock==3.14.0

View File

@@ -6,28 +6,6 @@ from pydantic import BaseModel, Field, field_validator
class AiUserCreate(BaseModel):
r"""Schema for creating a new AI user.
Validates user data for the POST /ai/admin/users endpoint.
Attributes:
id (str, optional): User UUID. Auto-generated if not provided.
name (str): User's full name. Required, 1-128 characters.
email (str): User's email address. Required, max 256 characters.
Must be a valid email format.
dob (date, optional): Date of birth in YYYY-MM-DD format.
gender (str, optional): User's gender, max 10 characters.
Example::
>>> user = AiUserCreate(
... name="John Doe",
... email="john@example.com",
... dob="1990-01-15",
... gender="male"
... )
"""
id: Optional[str] = None
name: str = Field(min_length=1, max_length=128)
email: str = Field(..., max_length=256)
@@ -42,31 +20,6 @@ class AiUserCreate(BaseModel):
class AiUserUpdate(BaseModel):
r"""Schema for updating an existing AI user.
Validates user data for the PUT /ai/admin/users/{UserID} endpoint.
All fields are optional for partial updates.
Attributes:
name (str, optional): User's full name, 1-128 characters.
email (str, optional): User's email address, max 256 characters.
Must be a valid email format if provided.
dob (date, optional): Date of birth in YYYY-MM-DD format.
gender (str, optional): User's gender, max 10 characters.
Example::
>>> # Partial update - only change name
>>> update = AiUserUpdate(name="Jane Doe")
>>> # Full update
>>> update = AiUserUpdate(
... name="Jane Doe",
... email="jane@example.com",
... gender="female"
... )
"""
name: Optional[str] = Field(default=None, min_length=1, max_length=128)
email: Optional[str] = Field(default=None, max_length=256)
dob: Optional[date] = None
@@ -80,36 +33,6 @@ class AiUserUpdate(BaseModel):
class AiUserFilter(BaseModel):
r"""Schema for filtering and paginating AI users.
Used for parsing query parameters in the GET /ai/admin/users endpoint.
Attributes:
q (str, optional): Search keyword for name and email fields.
name (str, optional): Filter by name (partial match).
email (str, optional): Filter by email (partial match).
gender (str, optional): Filter by gender (exact match).
dob_from (date, optional): Filter users with DOB on or after this date.
dob_to (date, optional): Filter users with DOB on or before this date.
created_from (str, optional): Filter users created on or after this datetime.
created_to (str, optional): Filter users created on or before this datetime.
page (int): Page number (0-indexed). Default: ``0``
size (int): Records per page, 1-200. Default: ``20``
sortby (str): Field to sort by. Default: ``"modified"``
asc (bool): Sort ascending if ``True``. Default: ``False``
Example::
>>> # Parse from query string
>>> filter = AiUserFilter(
... q="john",
... page=0,
... size=10,
... sortby="created",
... asc=True
... )
"""
q: Optional[str] = None
name: Optional[str] = None
email: Optional[str] = None

View File

@@ -1 +0,0 @@
# Tests package for AI Admin API

View File

@@ -1,99 +1,5 @@
"""Shared fixtures for API handler tests."""
import sys
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from flask import Flask
import os
# Add apps directory to path for imports
apps_dir = Path(__file__).parent.parent
sys.path.insert(0, str(apps_dir))
@pytest.fixture
def flask_app():
"""Create Flask app context for testing."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def app_context(flask_app):
"""Provide Flask application context."""
with flask_app.app_context():
yield flask_app
@pytest.fixture
def request_context(flask_app):
"""Provide Flask request context."""
with flask_app.test_request_context():
yield flask_app
@pytest.fixture
def mock_db_connection(mocker):
"""Mock database connection with cursor that has description attribute."""
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
# Set name attribute on description items
for i, col_name in enumerate(["id", "name", "dob", "email", "gender", "created", "modified"]):
mock_cursor.description[i].name = col_name
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("helpers.init_db_connection", return_value=mock_conn)
return {"connection": mock_conn, "cursor": mock_cursor}
@pytest.fixture
def mock_secrets(mocker):
"""Mock get_secret to return test values."""
secrets = {
"PG_HOST": "localhost",
"PG_PORT": "5432",
"PG_DB": "test_db",
"PG_USER": "test_user",
"PG_PASS": "test_pass",
}
mocker.patch("helpers.get_secret", side_effect=lambda key, default=None: secrets.get(key, default))
return secrets
@pytest.fixture
def sample_user_data():
"""Sample user data for testing."""
return {
"name": "Test User",
"email": "test@example.com",
"dob": "1990-01-15",
"gender": "male",
}
@pytest.fixture
def sample_db_row():
"""Sample database row tuple."""
return (
"550e8400-e29b-41d4-a716-446655440000", # id
"Test User", # name
"1990-01-15", # dob
"test@example.com", # email
"male", # gender
"2024-01-01T10:00:00", # created
"2024-01-01T10:00:00", # modified
)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@@ -1,108 +1,86 @@
"""Tests for filter_insert.py - GET (filter) & POST (create) handlers."""
from unittest.mock import MagicMock
import pytest
from flask import Flask
from psycopg2 import IntegrityError
from unittest.mock import MagicMock, patch
class TestMain:
"""Tests for main() dispatcher function."""
"""Test cases for the main() function"""
def test_main_get_calls_filter(self, mocker):
"""Test GET request routes to make_filter_request().
def test_main_get_method(self):
"""Test main() with GET method calls make_filter_request()"""
from flask import Flask
from filter_insert import main
Given:
A GET request to the endpoint.
When:
main() is called.
Then:
make_filter_request() is invoked and returns 200.
"""
app = Flask(__name__)
with app.test_request_context(method="GET"):
mock_filter = mocker.patch(
"filter_insert.make_filter_request",
return_value=({"data": []}, 200, {}),
)
mocker.patch("filter_insert.init_db_connection")
import filter_insert
result = filter_insert.main()
with app.test_request_context("/ai/admin/users", method="GET"):
with patch("filter_insert.make_filter_request") as mock_filter:
mock_filter.return_value = ({"data": "test"}, 200, {})
result = main()
mock_filter.assert_called_once()
assert result[1] == 200
assert result == ({"data": "test"}, 200, {})
def test_main_post_calls_insert(self, mocker):
"""Test POST request routes to make_insert_request().
def test_main_post_method(self):
"""Test main() with POST method calls make_insert_request()"""
from flask import Flask
from filter_insert import main
Given:
A POST request with JSON body containing user data.
When:
main() is called.
Then:
make_insert_request() is invoked and returns 201.
"""
app = Flask(__name__)
with app.test_request_context(
method="POST",
json={"name": "Test", "email": "test@example.com"},
):
mock_insert = mocker.patch(
"filter_insert.make_insert_request",
return_value=({"id": "123"}, 201, {}),
)
import filter_insert
result = filter_insert.main()
with app.test_request_context("/ai/admin/users", method="POST", json={"name": "Test", "email": "test@example.com"}):
with patch("filter_insert.make_insert_request") as mock_insert:
mock_insert.return_value = ({"id": "123"}, 201, {})
result = main()
mock_insert.assert_called_once()
assert result[1] == 201
assert result == ({"id": "123"}, 201, {})
def test_main_invalid_method_returns_405(self, mocker):
"""Test unsupported HTTP method returns 405.
def test_main_method_not_allowed(self):
"""Test main() with unsupported HTTP method returns 405"""
from flask import Flask
from filter_insert import main, CORS_HEADERS
Given:
A PATCH request (unsupported method).
When:
main() is called.
Then:
Returns 405 Method Not Allowed with error message.
"""
app = Flask(__name__)
with app.test_request_context("/ai/admin/users", method="PUT"):
result = main()
with app.test_request_context(method="PATCH"):
import filter_insert
expected = ({"error": "Method not allow"}, 405, CORS_HEADERS)
assert result == expected
result = filter_insert.main()
def test_main_exception_handling(self):
"""Test main() catches exceptions and returns 500"""
from flask import Flask
from filter_insert import main
assert result[1] == 405
app = Flask(__name__)
with app.test_request_context("/ai/admin/users", method="GET"):
with patch("filter_insert.make_filter_request") as mock_filter:
mock_filter.side_effect = Exception("Database connection failed")
result = main()
assert result[1] == 500
assert "error" in result[0]
class TestMakeInsertRequest:
"""Tests for make_insert_request() - user creation."""
"""Test cases for make_insert_request() function"""
def test_insert_success(self, mocker, sample_user_data, sample_db_row):
"""Test successful user creation returns 201.
def test_make_insert_request_success(self):
"""Test successful user insertion"""
from flask import Flask
from filter_insert import make_insert_request, CORS_HEADERS
Given:
Valid user data with name, email, dob, and gender.
Database connection is available.
When:
POST request to create user.
Then:
Returns 201 status code.
Executes INSERT query with user data.
"""
app = Flask(__name__)
mock_request = MagicMock()
mock_request.get_json.return_value = {
"name": "John Doe",
"email": "john@example.com",
"dob": "1990-01-01",
"gender": "male"
}
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = ("uuid-123", "John Doe", "1990-01-01", "john@example.com", "male", "2024-01-01", "2024-01-01")
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
@@ -112,367 +90,258 @@ class TestMakeInsertRequest:
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
mock_cursor.fetchone.return_value = sample_db_row
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
with app.test_request_context("/ai/admin/users", method="POST", json=mock_request.get_json.return_value):
with patch("filter_insert.init_db_connection", return_value=mock_conn):
with patch("filter_insert.db_row_to_dict", return_value={"id": "uuid-123", "name": "John Doe"}):
result = make_insert_request()
with app.test_request_context(
method="POST",
json=sample_user_data,
content_type="application/json",
):
import filter_insert
assert result[1] == 201
assert result[2] == CORS_HEADERS
result = filter_insert.make_insert_request()
def test_make_insert_request_validation_error(self):
"""Test make_insert_request() with invalid data returns 400"""
from flask import Flask
from filter_insert import make_insert_request, CORS_HEADERS
assert result[1] == 201
mock_cursor.execute.assert_called_once()
def test_insert_validation_error_missing_name(self, mocker):
"""Test missing required field 'name' returns 400.
Given:
Request body with email but missing required 'name' field.
When:
POST request to create user.
Then:
Returns 400 Bad Request.
Response contains errorCode 'VALIDATION_ERROR'.
"""
app = Flask(__name__)
with app.test_request_context(
method="POST",
json={"email": "test@example.com"}, # missing 'name'
content_type="application/json",
):
import filter_insert
# Invalid email format - pydantic will validate and reject
request_data = {
"name": "John Doe",
"email": "invalid-email"
}
result = filter_insert.make_insert_request()
with app.test_request_context("/ai/admin/users", method="POST", json=request_data):
result = make_insert_request()
assert result[1] == 400
response_data = result[0].get_json()
assert response_data["errorCode"] == "VALIDATION_ERROR"
assert result[1] == 400
assert "errorCode" in result[0].json
assert result[0].json["errorCode"] == "VALIDATION_ERROR"
assert result[2] == CORS_HEADERS
def test_insert_validation_error_invalid_email(self, mocker):
"""Test invalid email format raises serialization error.
def test_make_insert_request_duplicate_error(self):
"""Test make_insert_request() with duplicate email returns 409"""
from flask import Flask
from filter_insert import make_insert_request, CORS_HEADERS
from psycopg2 import IntegrityError
Given:
Request body with invalid email format 'invalid-email'.
When:
POST request to create user.
Then:
Raises TypeError because ValidationError.errors() contains
ValueError which is not JSON serializable.
Note:
This test documents a bug in the source code where e.errors()
is passed directly to jsonify without sanitization.
"""
app = Flask(__name__)
with app.test_request_context(
method="POST",
json={"name": "Test", "email": "invalid-email"},
content_type="application/json",
):
import filter_insert
with pytest.raises(TypeError, match="not JSON serializable"):
filter_insert.make_insert_request()
def test_insert_duplicate_email(self, mocker, sample_user_data):
"""Test duplicate email returns 409 Conflict.
Given:
Valid user data but email already exists in database.
Database raises IntegrityError on INSERT.
When:
POST request to create user.
Then:
Returns 409 Conflict.
Response contains errorCode 'DUPLICATE_TAG'.
"""
app = Flask(__name__)
request_data = {
"name": "John Doe",
"email": "john@example.com"
}
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(
side_effect=IntegrityError("duplicate key value")
)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
with app.test_request_context("/ai/admin/users", method="POST", json=request_data):
with patch("filter_insert.init_db_connection", return_value=mock_conn):
# Simulate IntegrityError
mock_cursor.execute.side_effect = IntegrityError("duplicate key value violates unique constraint")
result = make_insert_request()
with app.test_request_context(
method="POST",
json=sample_user_data,
content_type="application/json",
):
import filter_insert
result = filter_insert.make_insert_request()
assert result[1] == 409
response_data = result[0].get_json()
assert response_data["errorCode"] == "DUPLICATE_TAG"
assert result[1] == 409
assert result[0].json["errorCode"] == "DUPLICATE_TAG"
assert result[2] == CORS_HEADERS
class TestMakeFilterRequest:
"""Tests for make_filter_request() - user filtering."""
"""Test cases for make_filter_request() function"""
def test_filter_empty_result(self, mocker):
"""Test filter with no matching results returns empty array.
def test_make_filter_request_success(self):
"""Test successful filter request"""
from flask import Flask
from filter_insert import make_filter_request
Given:
Database has no users matching the filter criteria.
When:
GET request to filter users.
Then:
Returns empty JSON array [].
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
MagicMock(name="count"),
MagicMock(name="total"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
):
mock_cursor.description[i].name = col_name
mock_cursor.fetchall.return_value = []
with app.test_request_context("/ai/admin/users?page=0&size=8"):
with patch("filter_insert.init_db_connection") as mock_init_db:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = []
mock_conn.cursor.return_value = mock_cursor
mock_init_db.return_value = mock_conn
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
with patch("filter_insert.db_rows_to_array", return_value=[]):
result = make_filter_request()
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
# make_filter_request returns Flask Response object (jsonify)
assert result.status_code == 200
with app.test_request_context(method="GET"):
import filter_insert
response = filter_insert.make_filter_request()
class TestUserFilter:
"""Test cases for UserFilter.from_request_queries()"""
# make_filter_request returns Response object directly (not tuple)
response_data = response.get_json()
assert response_data == []
def test_user_filter_from_queries(self):
"""Test UserFilter parses query parameters correctly"""
from flask import Flask
from filter_insert import UserFilter
def test_filter_with_pagination(self, mocker, sample_db_row):
"""Test filter with page and size parameters.
Given:
Request with query params page=1 and size=5.
When:
GET request to filter users.
Then:
SQL query uses LIMIT 5 and OFFSET 5 (page * size).
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
MagicMock(name="count"),
MagicMock(name="total"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
):
mock_cursor.description[i].name = col_name
def mock_get(key, default=None):
values = {
"filter[ids]": None,
"filter[keyword]": "test",
"filter[name]": "John",
"filter[email]": "john@example.com",
"filter[gender]": "male",
"filter[created_from]": "2024-01-01",
"filter[created_to]": "2024-12-31",
"filter[modified_from]": None,
"filter[modified_to]": None,
"filter[dob_from]": None,
"filter[dob_to]": None,
}
return values.get(key, default)
# Add count and total to sample row
row_with_counts = sample_db_row + (10, 10)
mock_cursor.fetchall.return_value = [row_with_counts]
with app.test_request_context():
with patch("filter_insert.request") as mock_request:
mock_request.args.get = mock_get
mock_request.args.getlist.return_value = []
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
result = UserFilter.from_request_queries()
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
assert result.keyword == "test"
assert result.name == "John"
assert result.email == "john@example.com"
assert result.gender == "male"
assert result.created_from == "2024-01-01"
assert result.created_to == "2024-12-31"
with app.test_request_context(method="GET", query_string={"page": "1", "size": "5"}):
import filter_insert
filter_insert.make_filter_request()
class TestPage:
"""Test cases for Page.from_request_queries()"""
# Check that execute was called with correct offset
# cursor.execute(sql, values) - values is second positional arg
call_args = mock_cursor.execute.call_args
values = call_args[0][1] # Second positional argument
assert values["limit"] == 5
assert values["offset"] == 5 # page 1 * size 5
def test_page_default_values(self):
"""Test Page uses default values when params not provided"""
from flask import Flask
from filter_insert import Page
def test_filter_with_keyword(self, mocker, sample_db_row):
"""Test filter with keyword search across name and email.
Given:
Request with query param filter[keyword]='test'.
When:
GET request to filter users.
Then:
SQL query contains ILIKE clause for keyword matching.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
MagicMock(name="count"),
MagicMock(name="total"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
):
mock_cursor.description[i].name = col_name
def mock_get(key, default=None, type=None):
values = {
"page": 0,
"size": 8,
"asc": "false",
}
return values.get(key, default)
row_with_counts = sample_db_row + (1, 1)
mock_cursor.fetchall.return_value = [row_with_counts]
with app.test_request_context():
with patch("filter_insert.request") as mock_request:
mock_request.args.get = mock_get
result = Page.from_request_queries()
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
assert result.page == 0
assert result.size == 8
assert result.asc == False or result.asc == "false"
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="GET", query_string={"filter[keyword]": "test"}
):
import filter_insert
class TestUserPage:
"""Test cases for UserPage.from_request_queries()"""
result = filter_insert.make_filter_request()
def test_user_page_with_sortby(self):
"""Test UserPage parses sortby parameter correctly"""
from flask import Flask
from filter_insert import UserPage, UserSortField
# Check SQL contains keyword filter
call_args = mock_cursor.execute.call_args
sql = call_args[0][0]
assert "ILIKE" in sql
def test_filter_with_name(self, mocker, sample_db_row):
"""Test filter by name with case-insensitive partial match.
Given:
Request with query param filter[name]='John'.
When:
GET request to filter users.
Then:
SQL query contains 'LOWER(name) LIKE %john%'.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
MagicMock(name="count"),
MagicMock(name="total"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
):
mock_cursor.description[i].name = col_name
def mock_get(key, default=None, type=None):
values = {
"page": 0,
"size": 8,
"asc": "true",
"sortby": "created",
}
return values.get(key, default)
row_with_counts = sample_db_row + (1, 1)
mock_cursor.fetchall.return_value = [row_with_counts]
with app.test_request_context():
with patch("filter_insert.request") as mock_request:
mock_request.args.get = mock_get
mock_request.args.getlist.return_value = []
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
result = UserPage.from_request_queries()
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
assert result.sortby == UserSortField.CREATED
assert result.asc == True or result.asc == "true"
with app.test_request_context(
method="GET", query_string={"filter[name]": "John"}
):
import filter_insert
def test_user_page_invalid_sortby(self):
"""Test UserPage handles invalid sortby gracefully"""
from flask import Flask
from filter_insert import UserPage
filter_insert.make_filter_request()
call_args = mock_cursor.execute.call_args
sql = call_args[0][0]
values = call_args[0][1] # Second positional argument
assert "LOWER(name) LIKE" in sql
assert values["name"] == "%john%"
def test_filter_with_sortby(self, mocker, sample_db_row):
"""Test filter with sortby and asc parameters.
Given:
Request with query params sortby='created' and asc='true'.
When:
GET request to filter users.
Then:
SQL query contains 'ORDER BY created ASC'.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
MagicMock(name="count"),
MagicMock(name="total"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified", "count", "total"]
):
mock_cursor.description[i].name = col_name
def mock_get(key, default=None, type=None):
values = {
"page": 0,
"size": 8,
"asc": "false",
"sortby": "invalid_field",
}
return values.get(key, default)
row_with_counts = sample_db_row + (1, 1)
mock_cursor.fetchall.return_value = [row_with_counts]
with app.test_request_context():
with patch("filter_insert.request") as mock_request:
mock_request.args.get = mock_get
mock_request.args.getlist.return_value = []
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
result = UserPage.from_request_queries()
mocker.patch("filter_insert.init_db_connection", return_value=mock_conn)
assert result.sortby is None
with app.test_request_context(
method="GET", query_string={"sortby": "created", "asc": "true"}
):
import filter_insert
result = filter_insert.make_filter_request()
class TestUserSortField:
"""Test cases for UserSortField enum"""
call_args = mock_cursor.execute.call_args
sql = call_args[0][0]
assert "ORDER BY created ASC" in sql
def test_user_sort_field_values(self):
"""Test UserSortField has correct values"""
from filter_insert import UserSortField
assert UserSortField.CREATED.value == "created"
assert UserSortField.MODIFIED.value == "modified"
class TestHelpers:
"""Test cases for helper functions"""
def test_str_to_bool_true(self):
"""Test str_to_bool with true values"""
from helpers import str_to_bool
assert str_to_bool("true") is True
assert str_to_bool("True") is True
assert str_to_bool("TRUE") is True
def test_str_to_bool_false(self):
"""Test str_to_bool with false values"""
from helpers import str_to_bool
assert str_to_bool("false") is False
assert str_to_bool("False") is False
assert str_to_bool("FALSE") is False
def test_str_to_bool_none(self):
"""Test str_to_bool with invalid values returns None"""
from helpers import str_to_bool
assert str_to_bool("invalid") is None
assert str_to_bool("") is None
assert str_to_bool(None) is None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,514 +0,0 @@
"""Tests for update_delete.py - PUT (update) & DELETE handlers."""
from unittest.mock import MagicMock
import pytest
from flask import Flask
from psycopg2 import IntegrityError
class TestMain:
"""Tests for main() dispatcher function."""
def test_main_put_calls_update(self, mocker):
"""Test PUT request routes to make_update_request().
Given:
A PUT request to the endpoint.
When:
main() is called.
Then:
make_update_request() is invoked and returns 200.
"""
app = Flask(__name__)
with app.test_request_context(method="PUT"):
mock_update = mocker.patch(
"update_delete.make_update_request",
return_value=({"id": "123"}, 200, {}),
)
import update_delete
result = update_delete.main()
mock_update.assert_called_once()
assert result[1] == 200
def test_main_delete_calls_delete(self, mocker):
"""Test DELETE request routes to make_delete_request().
Given:
A DELETE request to the endpoint.
When:
main() is called.
Then:
make_delete_request() is invoked and returns 200.
"""
app = Flask(__name__)
with app.test_request_context(method="DELETE"):
mock_delete = mocker.patch(
"update_delete.make_delete_request",
return_value=({"id": "123"}, 200, {}),
)
import update_delete
result = update_delete.main()
mock_delete.assert_called_once()
assert result[1] == 200
def test_main_invalid_method_returns_405(self, mocker):
"""Test unsupported HTTP method returns 405.
Given:
A POST request (unsupported method for this endpoint).
When:
main() is called.
Then:
Returns 405 Method Not Allowed with error message.
"""
app = Flask(__name__)
with app.test_request_context(method="POST"):
import update_delete
result = update_delete.main()
assert result[1] == 405
assert "error" in result[0]
class TestMakeUpdateRequest:
"""Tests for make_update_request() - user update."""
def test_update_missing_user_id(self, mocker):
"""Test missing X-Fission-Params-UserID header returns 400.
Given:
PUT request without X-Fission-Params-UserID header.
When:
make_update_request() is called.
Then:
Returns 400 Bad Request.
Response contains errorCode 'MISSING_USER_ID'.
"""
app = Flask(__name__)
with app.test_request_context(
method="PUT",
json={"name": "Updated Name"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 400
response_data = result[0].get_json()
assert response_data["errorCode"] == "MISSING_USER_ID"
def test_update_user_not_found(self, mocker):
"""Test update non-existent user returns 404.
Given:
Valid X-Fission-Params-UserID header.
User does not exist in database (SELECT returns None).
When:
PUT request to update user.
Then:
Returns 404 Not Found.
Response contains errorCode 'USER_NOT_FOUND'.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = None # User not found
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "nonexistent-id"},
json={"name": "Updated Name"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 404
response_data = result[0].get_json()
assert response_data["errorCode"] == "USER_NOT_FOUND"
def test_update_success(self, mocker, sample_db_row):
"""Test successful user update returns 200.
Given:
Valid X-Fission-Params-UserID header.
User exists in database.
Valid update data in request body.
When:
PUT request to update user.
Then:
Returns 200 OK.
User data is updated in database.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
# First fetchone returns existing user, second returns updated user
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
json={"name": "Updated Name"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 200
def test_update_validation_error_invalid_email(self, mocker):
"""Test invalid email format raises serialization error.
Given:
Valid X-Fission-Params-UserID header.
Request body with invalid email format 'invalid-email'.
When:
PUT request to update user.
Then:
Raises TypeError because ValidationError.errors() contains
ValueError which is not JSON serializable.
Note:
This test documents a bug in the source code where e.errors()
is passed directly to jsonify without sanitization.
"""
app = Flask(__name__)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "some-id"},
json={"email": "invalid-email"},
content_type="application/json",
):
import update_delete
with pytest.raises(TypeError, match="not JSON serializable"):
update_delete.make_update_request()
def test_update_duplicate_email(self, mocker, sample_db_row):
"""Test duplicate email on update returns 409 Conflict.
Given:
Valid X-Fission-Params-UserID header.
User exists in database.
New email already exists for another user.
Database raises IntegrityError on UPDATE.
When:
PUT request to update user email.
Then:
Returns 409 Conflict.
Response contains errorCode 'DUPLICATE_USER'.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
# First fetchone returns existing user
mock_cursor.fetchone.return_value = sample_db_row
# Second execute raises IntegrityError
mock_cursor.execute.side_effect = [None, IntegrityError("duplicate key")]
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
json={"email": "duplicate@example.com"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 409
response_data = result[0].get_json()
assert response_data["errorCode"] == "DUPLICATE_USER"
class TestMakeDeleteRequest:
"""Tests for make_delete_request() - user deletion."""
def test_delete_missing_user_id(self, mocker):
"""Test missing X-Fission-Params-UserID header returns 400.
Given:
DELETE request without X-Fission-Params-UserID header.
When:
make_delete_request() is called.
Then:
Returns 400 Bad Request.
Response contains errorCode 'MISSING_USER_ID'.
"""
app = Flask(__name__)
with app.test_request_context(method="DELETE"):
import update_delete
result = update_delete.make_delete_request()
assert result[1] == 400
response_data = result[0].get_json()
assert response_data["errorCode"] == "MISSING_USER_ID"
def test_delete_user_not_found(self, mocker):
"""Test delete non-existent user returns 404.
Given:
Valid X-Fission-Params-UserID header.
User does not exist in database (SELECT returns None).
When:
DELETE request to delete user.
Then:
Returns 404 Not Found.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = None # User not found
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="DELETE",
headers={"X-Fission-Params-UserID": "nonexistent-id"},
):
import update_delete
result = update_delete.make_delete_request()
assert result[1] == 404
def test_delete_success(self, mocker, sample_db_row):
"""Test successful user deletion returns 200.
Given:
Valid X-Fission-Params-UserID header.
User exists in database.
When:
DELETE request to delete user.
Then:
Returns 200 OK.
User is deleted from database.
Response contains deleted user data.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
# First fetchone checks existence, second returns deleted row
mock_cursor.fetchone.side_effect = [(1,), sample_db_row]
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="DELETE",
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
):
import update_delete
result = update_delete.make_delete_request()
assert result[1] == 200
class TestUpdatePartialFields:
"""Tests for partial field updates."""
def test_update_only_name(self, mocker, sample_db_row):
"""Test update only name field.
Given:
Valid X-Fission-Params-UserID header.
User exists in database.
Request body contains only 'name' field.
When:
PUT request to update user.
Then:
Returns 200 OK.
UPDATE SQL only includes name field (plus modified timestamp).
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
json={"name": "New Name Only"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 200
# Check that UPDATE SQL contains name field
call_args = mock_cursor.execute.call_args_list[-1]
sql = call_args[0][0]
assert "name=" in sql
def test_update_multiple_fields(self, mocker, sample_db_row):
"""Test update multiple fields at once.
Given:
Valid X-Fission-Params-UserID header.
User exists in database.
Request body contains name, email, and gender fields.
When:
PUT request to update user.
Then:
Returns 200 OK.
UPDATE SQL includes all three fields.
"""
app = Flask(__name__)
mock_cursor = MagicMock()
mock_cursor.description = [
MagicMock(name="id"),
MagicMock(name="name"),
MagicMock(name="dob"),
MagicMock(name="email"),
MagicMock(name="gender"),
MagicMock(name="created"),
MagicMock(name="modified"),
]
for i, col_name in enumerate(
["id", "name", "dob", "email", "gender", "created", "modified"]
):
mock_cursor.description[i].name = col_name
mock_cursor.fetchone.side_effect = [sample_db_row, sample_db_row]
mock_conn = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mocker.patch("update_delete.init_db_connection", return_value=mock_conn)
with app.test_request_context(
method="PUT",
headers={"X-Fission-Params-UserID": "550e8400-e29b-41d4-a716-446655440000"},
json={"name": "New Name", "email": "new@example.com", "gender": "female"},
content_type="application/json",
):
import update_delete
result = update_delete.make_update_request()
assert result[1] == 200
call_args = mock_cursor.execute.call_args_list[-1]
sql = call_args[0][0]
assert "name=" in sql
assert "email=" in sql
assert "gender=" in sql

View File

@@ -1,3 +1,5 @@
from flask import current_app, jsonify, request
from helpers import CORS_HEADERS, db_row_to_dict, init_db_connection
from psycopg2 import IntegrityError
@@ -13,7 +15,7 @@ def main():
"fntimeout": 300,
"http_triggers": {
"ai-admin-update-delete-user-http": {
"url": "/ailbl/ai/admin/users/{UserID}",
"url": "/ai/admin/users/{UserID}",
"methods": ["DELETE", "PUT"]
}
}
@@ -33,40 +35,6 @@ def main():
def make_update_request():
r"""make_update_request() -> tuple[Response, int, dict]
Update an existing user by ID.
Retrieves the user ID from ``X-Fission-Params-UserID`` header, validates
the request body using :class:`AiUserUpdate` schema, and performs a
partial update on the user record.
Uses row-level locking (``SELECT ... FOR UPDATE``) to prevent concurrent
modification conflicts.
Returns:
tuple: A tuple containing:
- JSON response with updated user data or error details
- HTTP status code (200 on success, 400/404/409 on error)
- CORS headers dict
Raises:
ValidationError: If request body fails Pydantic validation (returns 400).
IntegrityError: If email conflicts with another user (returns 409).
Example::
>>> # PUT /ai/admin/users/550e8400-e29b-41d4-a716-446655440000
>>> # Header: X-Fission-Params-UserID: 550e8400-e29b-41d4-a716-446655440000
>>> # Body: {"name": "Jane Doe"}
>>> # Response: 200 OK
>>> {
... "id": "550e8400-e29b-41d4-a716-446655440000",
... "name": "Jane Doe",
... "email": "john@example.com",
... "modified": "2024-01-02T10:00:00"
... }
"""
user_id = request.headers.get("X-Fission-Params-UserID")
if not user_id:
return jsonify({"errorCode": "MISSING_USER_ID"}), 400, CORS_HEADERS
@@ -129,19 +97,6 @@ def make_update_request():
def __delete_user(cursor, id: str):
r"""Delete a user from the database by ID.
Args:
cursor: Database cursor object for executing queries.
id (str): UUID of the user to delete.
Returns:
dict | str: User data dict if deleted successfully,
or ``"USER_NOT_FOUND"`` string if user doesn't exist.
Note:
This is a private function. Use :func:`make_delete_request` instead.
"""
cursor.execute("SELECT 1 FROM ai_user WHERE id = %(id)s", {"id": id})
if not cursor.fetchone():
return "USER_NOT_FOUND"
@@ -151,30 +106,7 @@ def __delete_user(cursor, id: str):
return db_row_to_dict(cursor, row)
def make_delete_request():
r"""make_delete_request() -> tuple[Response, int, dict]
Delete a user by ID.
Retrieves the user ID from ``X-Fission-Params-UserID`` header and
deletes the user from the database if found.
Returns:
tuple: A tuple containing:
- JSON response with deleted user data or error details
- HTTP status code (200 on success, 400/404/500 on error)
- CORS headers dict (may be omitted on some error responses)
Example::
>>> # DELETE /ai/admin/users/550e8400-e29b-41d4-a716-446655440000
>>> # Header: X-Fission-Params-UserID: 550e8400-e29b-41d4-a716-446655440000
>>> # Response: 200 OK
>>> {
... "id": "550e8400-e29b-41d4-a716-446655440000",
... "name": "John Doe",
... "email": "john@example.com"
... }
"""
user_id = request.headers.get("X-Fission-Params-UserID")
if not user_id:
return jsonify({"errorCode": "MISSING_USER_ID"}), 400, CORS_HEADERS