diff --git a/models.py b/models.py index 65f5895..67bb900 100644 --- a/models.py +++ b/models.py @@ -1,6 +1,5 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, Interval, String from sqlalchemy.orm import relationship - from database import Base diff --git a/schemas.py b/schemas.py index 7341ac2..024f7af 100644 --- a/schemas.py +++ b/schemas.py @@ -1,6 +1,7 @@ from datetime import timedelta -from typing import List +from typing import List, Optional +from fastapi import UploadFile from pydantic import BaseModel diff --git a/settings.py b/settings.py index db5f47f..4787712 100644 --- a/settings.py +++ b/settings.py @@ -1,12 +1,3 @@ -from os import environ -from pathlib import Path - -from starlette.config import Config - -__config = Config(environ.get("SIMPLEPODCAST_CONFIG", ".env")) - -UPLOAD_DIR: Path = __config.get("UPLOAD_DIR", Path, "data/uploads") -SQLALCHEMY_DATABASE_URL: str = __config.get( - "UPLOAD_DIR", str, "sqlite:///./data/test.sqlite3" -) -PUBLIC_URL: str = __config.get("PUBLIC_URL", str, "http://localhost:8000") +UPLOAD_DIR = "data/uploads" +SQLALCHEMY_DATABASE_URL = "sqlite:///./data/test.sqlite3" +BASE_URL = "http://localhost:8000" \ No newline at end of file diff --git a/simplepodcast.py b/simplepodcast.py index 8ce4068..13bcd8a 100644 --- a/simplepodcast.py +++ b/simplepodcast.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Dict import podgen + from fastapi import FastAPI, Depends, HTTPException, UploadFile, Form, File from sqlalchemy.orm import Session from starlette.requests import Request @@ -13,8 +14,8 @@ import models import schemas import utils from database import SessionLocal, engine -from schemas import Podcast, PodcastBase, Episode -from settings import UPLOAD_DIR, PUBLIC_URL +from schemas import Podcast, PodcastBase, EpisodeCreate, Episode +from settings import UPLOAD_DIR, BASE_URL Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) @@ -41,13 +42,9 @@ def get_db(request: Request): @app.get("/podcast") -def list_podcasts( - offset: int = 0, limit: int = 100, db: Session = Depends(get_db) -) -> Dict[str, str]: +def list_podcasts(offset: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> Dict[str, str]: db_podcasts = utils.get_all_podcasts(db, offset, limit) - return { - podcast.name: f"{PUBLIC_URL}/podcast/{podcast.id}" for podcast in db_podcasts - } + return {podcast.name: f"{BASE_URL}/podcast/{podcast.id}" for podcast in db_podcasts} @app.get("/podcast/{podcast_id}", response_model=schemas.Podcast) @@ -64,22 +61,11 @@ def create_podcast(podcast: PodcastBase, db: Session = Depends(get_db)) -> Podca @app.post("/podcast/{podcast_id}/episode", response_model=schemas.Episode) -def create_episode( - podcast_id: int, - upload_file: UploadFile = File(...), - summary: str = Form(...), - long_summary: str = Form(...), - title: str = Form(...), - subtitle: str = Form(...), - duration: datetime.timedelta = Form(...), - db: Session = Depends(get_db), -) -> Episode: +def create_episode(podcast_id: int, upload_file: UploadFile = File(...), summary: str = Form(...), long_summary: str = Form(...), title: str = Form(...), subtitle: str = Form(...), duration: datetime.timedelta = Form(...), db: Session = Depends(get_db)) -> Episode: db_podcast = utils.get_podcast(db, podcast_id) if db_podcast is None: raise HTTPException(status_code=404, detail="Podcast not found") - return utils.create_episode( - db, podcast_id, summary, long_summary, title, subtitle, duration, upload_file - ) + return utils.create_episode(db, podcast_id, summary, long_summary, title, subtitle, duration, upload_file) app.mount("/download", StaticFiles(directory=str(UPLOAD_DIR)), name="download") @@ -93,7 +79,7 @@ def read_podcast_feed(podcast_id: int, db: Session = Depends(get_db)): p = podgen.Podcast( name=db_podcast.name, - website=PUBLIC_URL, + website=BASE_URL, description=db_podcast.description, explicit=db_podcast.explicit, ) @@ -104,7 +90,9 @@ def read_podcast_feed(podcast_id: int, db: Session = Depends(get_db)): title=db_episode.title, subtitle=db_episode.subtitle, media=podgen.Media( - url=db_episode.url, size=db_episode.size, duration=db_episode.duration + url=db_episode.url, + size=db_episode.size, + duration=datetime.timedelta(db_episode.duration), ), ) for db_episode in db_podcast.episodes diff --git a/utils.py b/utils.py index 989fd47..fd734ef 100644 --- a/utils.py +++ b/utils.py @@ -1,16 +1,16 @@ -import logging import os import shutil from datetime import timedelta from pathlib import Path from typing import List +import aiofiles from fastapi import HTTPException, UploadFile from sqlalchemy.orm import Session import models import schemas -from settings import UPLOAD_DIR, PUBLIC_URL +from settings import UPLOAD_DIR, BASE_URL def get_all_podcasts( @@ -34,14 +34,7 @@ def create_podcast(db: Session, podcast: schemas.PodcastBase) -> models.Podcast: def create_episode( - db: Session, - podcast_id: int, - summary: str, - long_summary: str, - title: str, - subtitle: str, - duration: timedelta, - upload_file: UploadFile, + db: Session, podcast_id: int, summary: str, long_summary: str, title: str, subtitle: str, duration: timedelta, upload_file: UploadFile ) -> models.Episode: db_episode = models.Episode( summary=summary, @@ -57,31 +50,18 @@ def create_episode( db.commit() db.refresh(db_episode) - filename = upload_file.filename - upload_dir = Path(UPLOAD_DIR) / f"podcast_{podcast_id}" / f"episode_{db_episode.id}" - upload_path = upload_dir / filename + upload_path = upload_dir / episode.filename if not upload_dir == upload_path.parent: raise HTTPException(status_code=404, detail="Invalid filename") try: - upload_dir.mkdir(parents=True, exist_ok=True) with upload_path.open("wb") as buffer: shutil.copyfileobj(upload_file.file, buffer) - except OSError as e: - db.delete(db_episode) - upload_file.file.close() - logging.exception("Unable to store upload to disk. %s", e) - raise HTTPException(500, "Unable to store upload to disk. " + str(e)) finally: upload_file.file.close() - db_episode.url = "%s/download/podcast_%d/episode_%d/%s" % ( - PUBLIC_URL, - podcast_id, - db_episode.id, - filename, - ) + db_episode.url = f"{BASE_URL}/download/podcast_{podcast_id}/episode_{db_episode.id}/{episode.filename}" db_episode.size = os.path.getsize(str(upload_path)) db.commit() db.refresh(db_episode)