"""Stocks API endpoints."""

from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.models.stock import Stock
from app.schemas.stock import StockCreate, StockResponse, StockWithPrice

router = APIRouter()

# Escape character for LIKE/ILIKE patterns built from user input. "/" is used
# (as SQLAlchemy's own ``autoescape`` does) to avoid backslash-quoting
# differences between database backends.
_LIKE_ESCAPE = "/"


def _escape_like(value: str) -> str:
    """Return *value* with LIKE wildcards escaped so it matches literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. The result must be used together with
    ``escape=_LIKE_ESCAPE`` on the ``like``/``ilike`` call.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


@router.get("/", response_model=List[StockResponse])
async def list_stocks(
    db: AsyncSession = Depends(get_db),
    sector: Optional[str] = Query(None, description="Filter by sector"),
    industry: Optional[str] = Query(None, description="Filter by industry"),
    search: Optional[str] = Query(None, description="Search by symbol or name"),
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=100),
):
    """List all tracked stocks with optional filters.

    Only active (not soft-deleted) stocks are returned, ordered by symbol
    and paginated with ``skip``/``limit``.

    Args:
        sector: Exact-match sector filter.
        industry: Exact-match industry filter.
        search: Case-insensitive substring matched against symbol or name.
            ``%`` and ``_`` in the text are matched literally, not as
            wildcards.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return (1-100).
    """
    # `.is_(True)` rather than `== True`: same rows selected (NULL and false
    # are both excluded) without the E712 comparison-to-True smell.
    query = select(Stock).where(Stock.is_active.is_(True))
    if sector:
        query = query.where(Stock.sector == sector)
    if industry:
        query = query.where(Stock.industry == industry)
    if search:
        pattern = f"%{_escape_like(search)}%"
        query = query.where(
            Stock.symbol.ilike(pattern, escape=_LIKE_ESCAPE)
            | Stock.name.ilike(pattern, escape=_LIKE_ESCAPE)
        )
    query = query.order_by(Stock.symbol).offset(skip).limit(limit)
    result = await db.execute(query)
    return result.scalars().all()


@router.get("/sectors")
async def list_sectors(db: AsyncSession = Depends(get_db)):
    """Return the distinct, non-empty sectors of active stocks, sorted.

    Response shape: ``{"sectors": [...]}``.
    """
    query = select(Stock.sector).distinct().where(Stock.is_active.is_(True))
    result = await db.execute(query)
    # Drop NULL / empty-string sectors before sorting.
    sectors = [row[0] for row in result.fetchall() if row[0]]
    return {"sectors": sorted(sectors)}


@router.get("/industries")
async def list_industries(
    db: AsyncSession = Depends(get_db),
    sector: Optional[str] = Query(None, description="Filter by sector"),
):
    """Return the distinct, non-empty industries of active stocks, sorted.

    Args:
        sector: If given, only industries of stocks in this sector.

    Response shape: ``{"industries": [...]}``.
    """
    query = select(Stock.industry).distinct().where(Stock.is_active.is_(True))
    if sector:
        query = query.where(Stock.sector == sector)
    result = await db.execute(query)
    # Drop NULL / empty-string industries before sorting.
    industries = [row[0] for row in result.fetchall() if row[0]]
    return {"industries": sorted(industries)}


@router.get("/{symbol}", response_model=StockWithPrice)
async def get_stock(
    symbol: str,
    db: AsyncSession = Depends(get_db),
):
    """Get detailed stock information including latest price.

    The symbol lookup is case-insensitive on the request side (it is
    upper-cased before querying). Soft-deleted stocks are not filtered out.

    Raises:
        HTTPException: 404 if no stock with this symbol exists.
    """
    query = select(Stock).where(Stock.symbol == symbol.upper())
    result = await db.execute(query)
    stock = result.scalar_one_or_none()
    if not stock:
        raise HTTPException(status_code=404, detail=f"Stock {symbol} not found")
    # TODO: Add latest price from stock_prices table
    return stock


@router.post("/", response_model=StockResponse)
async def add_stock(
    stock: StockCreate,
    db: AsyncSession = Depends(get_db),
):
    """Add a new stock to track, or re-activate a soft-deleted one.

    The symbol is stored upper-cased. If a stock with the same symbol was
    previously removed (``is_active`` is false), it is re-activated and its
    metadata is replaced with the values from this request instead of
    failing, so a removed stock can be tracked again.

    Raises:
        HTTPException: 400 if an active stock with this symbol already
            exists, including when a concurrent request inserts it first.
    """
    symbol = stock.symbol.upper()
    existing_result = await db.execute(select(Stock).where(Stock.symbol == symbol))
    existing = existing_result.scalar_one_or_none()

    if existing is not None:
        if existing.is_active:
            raise HTTPException(
                status_code=400, detail=f"Stock {stock.symbol} already exists"
            )
        # Previously soft-deleted: bring it back with the new metadata.
        existing.name = stock.name
        existing.sector = stock.sector
        existing.industry = stock.industry
        existing.exchange = stock.exchange
        existing.country = stock.country
        existing.is_active = True
        db_stock = existing
    else:
        db_stock = Stock(
            symbol=symbol,
            name=stock.name,
            sector=stock.sector,
            industry=stock.industry,
            exchange=stock.exchange,
            country=stock.country,
        )
        db.add(db_stock)

    try:
        await db.commit()
    except IntegrityError:
        # Another request may have inserted the same symbol between our
        # existence check and this commit; report it as a duplicate rather
        # than letting the constraint violation surface as a 500.
        await db.rollback()
        raise HTTPException(
            status_code=400, detail=f"Stock {stock.symbol} already exists"
        ) from None
    await db.refresh(db_stock)
    return db_stock


@router.delete("/{symbol}")
async def remove_stock(
    symbol: str,
    db: AsyncSession = Depends(get_db),
):
    """Remove a stock from tracking (soft delete).

    Sets ``is_active`` to false; the row is kept. Removing an already
    inactive stock succeeds again with the same message.

    Raises:
        HTTPException: 404 if no stock with this symbol exists.
    """
    query = select(Stock).where(Stock.symbol == symbol.upper())
    result = await db.execute(query)
    stock = result.scalar_one_or_none()
    if not stock:
        raise HTTPException(status_code=404, detail=f"Stock {symbol} not found")
    stock.is_active = False
    await db.commit()
    return {"message": f"Stock {symbol} removed from tracking"}