diff --git a/src/alembic/env.py b/src/alembic/env.py index 7605c32..3693ae0 100644 --- a/src/alembic/env.py +++ b/src/alembic/env.py @@ -12,6 +12,7 @@ import nominees.tables import officers.tables import candidates.tables +import event.tables from alembic import context # this is the Alembic Config object, which provides diff --git a/src/alembic/versions/4928dc3f0b07_create_event_table.py b/src/alembic/versions/4928dc3f0b07_create_event_table.py new file mode 100644 index 0000000..e2ffe3f --- /dev/null +++ b/src/alembic/versions/4928dc3f0b07_create_event_table.py @@ -0,0 +1,43 @@ +"""create_event_table + +Revision ID: 4928dc3f0b07 +Revises: 0a2c458d1ddd +Create Date: 2026-05-24 17:39:22.538239 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4928dc3f0b07' +down_revision: Union[str, None] = '0a2c458d1ddd' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('event_info', + sa.Column('eid', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('start_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('end_time', sa.DateTime(timezone=True), nullable=False), + sa.Column('frequency', sa.String(length=64), server_default=sa.text("'NONE'"), nullable=True), + sa.Column('repeat_start_date', sa.Date(), nullable=True), + sa.Column('repeat_end_date', sa.Date(), nullable=True), + sa.CheckConstraint("frequency IN ('NONE', 'DAILY', 'WEEKLY', 'MONTHLY', 'SEMESTERLY', 'YEARLY')", name=op.f('ck_event_info_valid_frequency_value')), + sa.CheckConstraint('repeat_start_date < repeat_end_date', name=op.f('ck_event_info_check_repeat_start_date_before_repeat_end_date')), + sa.CheckConstraint('start_time < end_time', name=op.f('ck_event_info_check_start_time_before_end_time')), + sa.PrimaryKeyConstraint('eid', name=op.f('pk_event_info')) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('event_info') + # ### end Alembic commands ### diff --git a/src/event/constants.py b/src/event/constants.py new file mode 100644 index 0000000..8764031 --- /dev/null +++ b/src/event/constants.py @@ -0,0 +1,10 @@ +from enum import StrEnum + + +class EventFrequencyEnum(StrEnum): + NONE = "NONE" + DAILY = "DAILY" + WEEKLY = "WEEKLY" + MONTHLY = "MONTHLY" + SEMESTERLY = "SEMESTERLY" + YEARLY = "YEARLY" diff --git a/src/event/crud.py b/src/event/crud.py new file mode 100644 index 0000000..ece1124 --- /dev/null +++ b/src/event/crud.py @@ -0,0 +1,58 @@ +from collections.abc import Sequence +from datetime import date, datetime + +from sqlalchemy import and_, delete, extract, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from event.tables import EventDB + + +async def get_all_events(db_session: AsyncSession) -> Sequence[EventDB]: + events = (await db_session.scalars(select(EventDB))).all() + return events + + +async def get_events_for_this_year( + db_session: AsyncSession, + year: int, +) -> Sequence[EventDB]: + events = ( + await db_session.scalars( + select(EventDB).where( + or_(extract("year", EventDB.start_time) == year, extract("year", EventDB.end_time) == year) + ) + ) + ).all() + return events + + +async def get_events_for_this_year_month( + db_session: AsyncSession, + year: int, + month: int, +) -> Sequence[EventDB]: + events = ( + await db_session.scalars( + select(EventDB).where( + or_( + and_(extract("year", EventDB.start_time) == year, extract("month", EventDB.start_time) == month), + and_(extract("year", EventDB.end_time) == year, extract("month", EventDB.end_time) == month), + ) + ) + ) + ).all() + return events + + +async def get_event_by_eid(db_session: AsyncSession, eid: int) -> EventDB | None: + return (await db_session.execute(select(EventDB).where(EventDB.eid == eid))).scalar_one_or_none() + + +async def create_event(db_session: AsyncSession, info: EventDB): + db_session.add(info) + + +async def delete_event(db_session: AsyncSession, eid: int): + result = await db_session.execute(delete(EventDB).where(EventDB.eid == eid)) + # Return the number of rows affected + return result.rowcount diff --git a/src/event/models.py b/src/event/models.py new file mode 100644 index 0000000..40c8977 --- /dev/null +++ b/src/event/models.py @@ -0,0 +1,54 @@ +import datetime + +from pydantic import BaseModel, ConfigDict, model_validator + +from event.constants import EventFrequencyEnum + + +class BaseEvent(BaseModel): + name: str + start_time: datetime.datetime + end_time: datetime.datetime + description: str | None = None + frequency: EventFrequencyEnum | None = None + repeat_start_date: datetime.date | None = None + repeat_end_date: datetime.date | None = None + + @model_validator(mode="after") + def validate_time_range(self) -> "BaseEvent": + if self.start_time >= self.end_time: + raise ValueError("The event start must be before the event end") + + if self.repeat_start_date and self.repeat_end_date: + if self.repeat_start_date > self.repeat_end_date: + raise ValueError("The event repeat start date must be before the end date") + + if (self.repeat_start_date is None) != (self.repeat_end_date is None): + raise ValueError("The event must have both repeat start and repeat end or have neither.") + + return self + + +class Event(BaseEvent): + model_config = ConfigDict(from_attributes=True) + eid: int + + +class EventCreate(BaseEvent): + pass + + +class EventUpdate(BaseModel): + model_config = ConfigDict(extra="forbid") + name: str | None = None + start_time: datetime.datetime | None = None + end_time: datetime.datetime | None = None + description: str | None = None + frequency: EventFrequencyEnum | None = None + repeat_start_date: datetime.date | None = None + repeat_end_date: datetime.date | None = None + + +class EventDelete(BaseModel): + result: bool + eid: int diff --git a/src/event/tables.py b/src/event/tables.py new file mode 100644 index 0000000..2981cb2 --- /dev/null +++ b/src/event/tables.py @@ -0,0 +1,26 @@ +from datetime import date, datetime + +from sqlalchemy import CheckConstraint, Date, DateTime, Integer, String, Text, text +from sqlalchemy.orm import Mapped, mapped_column + +from database import Base +from event.constants import EventFrequencyEnum + + +class EventDB(Base): + __tablename__ = "event_info" + + eid: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + description: Mapped[str] = mapped_column(Text, nullable=True) + name: Mapped[str] = mapped_column(String(64)) + start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + frequency: Mapped[EventFrequencyEnum] = mapped_column(String(64), server_default=text("'NONE'"), nullable=True) + repeat_start_date: Mapped[date] = mapped_column(Date, nullable=True) + repeat_end_date: Mapped[date] = mapped_column(Date, nullable=True) + + __table_args__ = ( + CheckConstraint("start_time < end_time", name="check_start_time_before_end_time"), + CheckConstraint("repeat_start_date < repeat_end_date", name="check_repeat_start_date_before_repeat_end_date"), + CheckConstraint(frequency.in_([e.value for e in EventFrequencyEnum]), name="valid_frequency_value") + ) diff --git a/src/event/urls.py b/src/event/urls.py new file mode 100644 index 0000000..76968d6 --- /dev/null +++ b/src/event/urls.py @@ -0,0 +1,134 @@ +from datetime import date, datetime + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse +from pydantic import ValidationError + +import database +import event.crud +from dependencies import perm_admin +from event.models import Event, EventCreate, EventDelete, EventUpdate +from event.tables import EventDB +from utils.shared_models import DetailModel, SuccessResponse + +router = APIRouter( + prefix="/event", + tags=["event"], +) + + +@router.get( + "", + description="Get all events", + response_model=list[Event], + operation_id="get_all_events", +) +async def get_all_events( + db_session: database.DBSession, +): + events_list = await event.crud.get_all_events(db_session) + + return events_list + + +@router.get( + "/{year}", + description="Get events that start OR end in this year", + response_model=list[Event], + operation_id="get_events_for_this_year", +) +async def get_events_for_this_year( + db_session: database.DBSession, + year: int, +): + events_list = await event.crud.get_events_for_this_year(db_session, year) + + return events_list + + +@router.get( + "/{year}/{month}", + description="Get events that start OR end in the given year and month", + response_model=list[Event], + operation_id="get_events_for_this_year_month", +) +async def get_events_for_this_year_month(db_session: database.DBSession, year: int, month: int): + events_list = await event.crud.get_events_for_this_year_month(db_session, year, month) + + return events_list + + +@router.post( + "", + description="Create a new event", + response_model=Event, + status_code=status.HTTP_201_CREATED, + responses={ + 500: {"description": "failed to fetch new event", "model": DetailModel}, + }, + operation_id="create_event", + dependencies=[Depends(perm_admin)], +) +async def create_event(db_session: database.DBSession, body: EventCreate): + new_event = EventDB(**body.model_dump()) + await event.crud.create_event( + db_session, + new_event, + ) + + await db_session.commit() + await db_session.refresh(new_event) + + return new_event + + +@router.patch( + "/{eid}", + description="Update an Event detail", + response_model=Event, + responses={404: {"description": "Event doesn't exist."}}, + operation_id="update_event", + dependencies=[Depends(perm_admin)], +) +async def update_event(db_session: database.DBSession, eid: int, body: EventUpdate): + db_event = await event.crud.get_event_by_eid(db_session, eid) + if db_event is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Event doesn't exist.") + + db_data = Event.model_validate(db_event).model_dump() + patch_data = body.model_dump(exclude_unset=True) + + merged_data = {**db_data, **patch_data} + try: + Event.model_validate(merged_data) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=jsonable_encoder(e.errors()) + ) from e + + for key, value in patch_data.items(): + setattr(db_event, key, value) + + await db_session.commit() + await db_session.refresh(db_event) + + return db_event + + +@router.delete( + "/{eid}", + description="Delete an event", + response_model=EventDelete, + responses={404: {"description": "Event doesn't exist."}}, + operation_id="delete_event", + dependencies=[Depends(perm_admin)], +) +async def delete_event(db_session: database.DBSession, eid: int): + rows_deleted = await event.crud.delete_event(db_session, eid) + + if rows_deleted == 0: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Event doesn't exist.") + + await db_session.commit() + return EventDelete(result=True, eid=eid) diff --git a/src/main.py b/src/main.py index 66734c0..60bb5a7 100755 --- a/src/main.py +++ b/src/main.py @@ -10,6 +10,7 @@ import candidates.urls import database import elections.urls +import event.urls import nominees.urls import officers.urls import permission.urls @@ -58,6 +59,7 @@ app.include_router(nominees.urls.router) app.include_router(officers.urls.router) app.include_router(permission.urls.router) +app.include_router(event.urls.router) @app.get("/") diff --git a/tests/wip/test_github.py b/tests/wip/test_github.py index 671aefc..054ca91 100644 --- a/tests/wip/test_github.py +++ b/tests/wip/test_github.py @@ -4,11 +4,13 @@ # NOTE: must export API key to use github api (mostly...) + @pytest.mark.asyncio async def test__list_users(): member_list = await github.internals.list_members() print(member_list) + @pytest.mark.asyncio async def test__get_user_by_name(): user = await github.internals.get_user_by_username("EarthenSky")