Implement recurring transactions

This commit is contained in:
William Brawner 2021-09-15 14:03:50 -06:00
parent 1ab9af9d17
commit a9611eee23
13 changed files with 201 additions and 143 deletions

View file

@ -10,15 +10,18 @@ import io.ktor.auth.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.response.* import io.ktor.response.*
import io.ktor.util.pipeline.* import io.ktor.util.pipeline.*
import java.time.Instant
suspend inline fun PipelineContext<Unit, ApplicationCall>.requireBudgetWithPermission( suspend inline fun PipelineContext<Unit, ApplicationCall>.requireBudgetWithPermission(
permissionRepository: PermissionRepository, permissionRepository: PermissionRepository,
userId: String, userId: String,
budgetId: String, budgetId: String?,
permission: Permission, permission: Permission,
otherwise: () -> Unit otherwise: () -> Unit
) { ) {
if (budgetId.isNullOrBlank()) {
errorResponse(HttpStatusCode.BadRequest, "budgetId is required")
return
}
permissionRepository.findAll( permissionRepository.findAll(
userId = userId, userId = userId,
budgetIds = listOf(budgetId) budgetIds = listOf(budgetId)
@ -57,5 +60,3 @@ suspend inline fun PipelineContext<Unit, ApplicationCall>.errorResponse(
call.respond(httpStatusCode, ErrorResponse(message)) call.respond(httpStatusCode, ErrorResponse(message))
}?: call.respond(httpStatusCode) }?: call.respond(httpStatusCode)
} }
fun String.toInstant(): Instant = Instant.parse(this)

View file

@ -1,82 +1,80 @@
package com.wbrawner.twigs package com.wbrawner.twigs
import com.wbrawner.twigs.model.Permission import com.wbrawner.twigs.model.Permission
import com.wbrawner.twigs.model.RecurringTransaction
import com.wbrawner.twigs.model.Session import com.wbrawner.twigs.model.Session
import com.wbrawner.twigs.model.Transaction
import com.wbrawner.twigs.storage.PermissionRepository import com.wbrawner.twigs.storage.PermissionRepository
import com.wbrawner.twigs.storage.TransactionRepository import com.wbrawner.twigs.storage.RecurringTransactionRepository
import io.ktor.application.* import io.ktor.application.*
import io.ktor.auth.* import io.ktor.auth.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.request.* import io.ktor.request.*
import io.ktor.response.* import io.ktor.response.*
import io.ktor.routing.* import io.ktor.routing.*
import io.ktor.util.pipeline.*
import java.time.Instant import java.time.Instant
fun Application.recurringTransactionRoutes( fun Application.recurringTransactionRoutes(
transactionRepository: TransactionRepository, recurringTransactionRepository: RecurringTransactionRepository,
permissionRepository: PermissionRepository permissionRepository: PermissionRepository
) { ) {
suspend fun PipelineContext<Unit, ApplicationCall>.recurringTransactionAfterPermissionCheck(
id: String?,
userId: String,
success: suspend (RecurringTransaction) -> Unit
) {
if (id.isNullOrBlank()) {
errorResponse(HttpStatusCode.BadRequest, "id is required")
return
}
val recurringTransaction = recurringTransactionRepository.findAll(ids = listOf(id)).firstOrNull()
?: run {
errorResponse()
return
}
requireBudgetWithPermission(
permissionRepository,
userId,
recurringTransaction.budgetId,
Permission.WRITE
) {
application.log.info("No permissions on budget ${recurringTransaction.budgetId}.")
return
}
success(recurringTransaction)
}
routing { routing {
route("/api/recurringtransactions") { route("/api/recurringtransactions") {
authenticate(optional = false) { authenticate(optional = false) {
get { get {
val session = call.principal<Session>()!! val session = call.principal<Session>()!!
val budgetId = call.request.queryParameters["budgetId"]
requireBudgetWithPermission(
permissionRepository,
session.userId,
budgetId,
Permission.WRITE
) {
return@get
}
call.respond( call.respond(
transactionRepository.findAll( recurringTransactionRepository.findAll(
budgetIds = permissionRepository.findAll( budgetId = budgetId!!
budgetIds = call.request.queryParameters.getAll("budgetIds"), ).map { it.asResponse() }
userId = session.userId )
).map { it.budgetId },
categoryIds = call.request.queryParameters.getAll("categoryIds"),
from = call.request.queryParameters["from"]?.let { Instant.parse(it) },
to = call.request.queryParameters["to"]?.let { Instant.parse(it) },
expense = call.request.queryParameters["expense"]?.toBoolean(),
).map { it.asResponse() })
} }
get("/{id}") { get("/{id}") {
val session = call.principal<Session>()!! val session = call.principal<Session>()!!
val transaction = transactionRepository.findAll( recurringTransactionAfterPermissionCheck(call.parameters["id"]!!, session.userId) {
ids = call.parameters.getAll("id"), call.respond(it.asResponse())
budgetIds = permissionRepository.findAll(
userId = session.userId
)
.map { it.budgetId }
)
.map { it.asResponse() }
.firstOrNull()
transaction?.let {
call.respond(it)
} ?: errorResponse()
}
get("/sum") {
val categoryId = call.request.queryParameters["categoryId"]
val budgetId = call.request.queryParameters["budgetId"]
val from = call.request.queryParameters["from"]?.toInstant() ?: firstOfMonth
val to = call.request.queryParameters["to"]?.toInstant() ?: endOfMonth
val balance = if (!categoryId.isNullOrBlank()) {
if (!budgetId.isNullOrBlank()) {
errorResponse(
HttpStatusCode.BadRequest,
"budgetId and categoryId cannot be provided together"
)
return@get
}
transactionRepository.sumByCategory(categoryId, from, to)
} else if (!budgetId.isNullOrBlank()) {
transactionRepository.sumByBudget(budgetId, from, to)
} else {
errorResponse(HttpStatusCode.BadRequest, "budgetId or categoryId must be provided to sum")
return@get
} }
call.respond(BalanceResponse(balance))
} }
post { post {
val session = call.principal<Session>()!! val session = call.principal<Session>()!!
val request = call.receive<TransactionRequest>() val request = call.receive<RecurringTransactionRequest>()
if (request.title.isNullOrBlank()) { if (request.title.isNullOrBlank()) {
errorResponse(HttpStatusCode.BadRequest, "Title cannot be null or empty") errorResponse(HttpStatusCode.BadRequest, "Title cannot be null or empty")
return@post return@post
@ -94,8 +92,8 @@ fun Application.recurringTransactionRoutes(
return@post return@post
} }
call.respond( call.respond(
transactionRepository.save( recurringTransactionRepository.save(
Transaction( RecurringTransaction(
title = request.title, title = request.title,
description = request.description, description = request.description,
amount = request.amount ?: 0L, amount = request.amount ?: 0L,
@ -103,7 +101,9 @@ fun Application.recurringTransactionRoutes(
budgetId = request.budgetId, budgetId = request.budgetId,
categoryId = request.categoryId, categoryId = request.categoryId,
createdBy = session.userId, createdBy = session.userId,
date = request.date?.let { Instant.parse(it) } ?: Instant.now() start = request.start?.toInstant() ?: Instant.now(),
finish = request.finish?.toInstant(),
frequency = request.frequency.asFrequency()
) )
).asResponse() ).asResponse()
) )
@ -111,59 +111,49 @@ fun Application.recurringTransactionRoutes(
put("/{id}") { put("/{id}") {
val session = call.principal<Session>()!! val session = call.principal<Session>()!!
val request = call.receive<TransactionRequest>() val request = call.receive<RecurringTransactionRequest>()
val transaction = transactionRepository.findAll(ids = call.parameters.getAll("id")) recurringTransactionAfterPermissionCheck(
.firstOrNull() call.parameters["id"]!!,
?: run { session.userId
errorResponse() ) { recurringTransaction ->
return@put if (request.budgetId != recurringTransaction.budgetId) {
requireBudgetWithPermission(
permissionRepository,
session.userId,
request.budgetId,
Permission.WRITE
) {
return@recurringTransactionAfterPermissionCheck
}
} }
requireBudgetWithPermission( call.respond(
permissionRepository, recurringTransactionRepository.save(
session.userId, recurringTransaction.copy(
transaction.budgetId, title = request.title ?: recurringTransaction.title,
Permission.WRITE description = request.description ?: recurringTransaction.description,
) { amount = request.amount ?: recurringTransaction.amount,
return@put expense = request.expense ?: recurringTransaction.expense,
categoryId = request.categoryId ?: recurringTransaction.categoryId,
budgetId = request.budgetId ?: recurringTransaction.budgetId,
start = request.start?.toInstant() ?: recurringTransaction.start,
finish = request.finish?.toInstant() ?: recurringTransaction.finish,
frequency = request.frequency.asFrequency()
)
).asResponse()
)
} }
call.respond(
transactionRepository.save(
transaction.copy(
title = request.title ?: transaction.title,
description = request.description ?: transaction.description,
amount = request.amount ?: transaction.amount,
expense = request.expense ?: transaction.expense,
date = request.date?.let { Instant.parse(it) } ?: transaction.date,
categoryId = request.categoryId ?: transaction.categoryId,
budgetId = request.budgetId ?: transaction.budgetId,
createdBy = transaction.createdBy,
)
).asResponse()
)
} }
delete("/{id}") { delete("/{id}") {
val session = call.principal<Session>()!! val session = call.principal<Session>()!!
val transaction = transactionRepository.findAll(ids = call.parameters.getAll("id")) recurringTransactionAfterPermissionCheck(call.parameters["id"]!!, session.userId) {
.firstOrNull() val response = if (recurringTransactionRepository.delete(it)) {
?: run { HttpStatusCode.NoContent
errorResponse() } else {
return@delete HttpStatusCode.InternalServerError
} }
requireBudgetWithPermission( call.respond(response)
permissionRepository,
session.userId,
transaction.budgetId,
Permission.WRITE
) {
return@delete
} }
val response = if (transactionRepository.delete(transaction)) {
HttpStatusCode.NoContent
} else {
HttpStatusCode.InternalServerError
}
call.respond(response)
} }
} }
} }

View file

@ -11,6 +11,9 @@ data class RecurringTransactionRequest(
val categoryId: String? = null, val categoryId: String? = null,
val expense: Boolean? = null, val expense: Boolean? = null,
val budgetId: String? = null, val budgetId: String? = null,
val frequency: String,
val start: String? = null,
val finish: String? = null,
) )
@Serializable @Serializable
@ -18,7 +21,9 @@ data class RecurringTransactionResponse(
val id: String, val id: String,
val title: String?, val title: String?,
val description: String?, val description: String?,
// val frequency: FrequencyResponse, val frequency: String,
val start: String,
val finish: String? = null,
val amount: Long?, val amount: Long?,
val expense: Boolean?, val expense: Boolean?,
val budgetId: String, val budgetId: String,
@ -30,7 +35,9 @@ fun RecurringTransaction.asResponse(): RecurringTransactionResponse = RecurringT
id = id, id = id,
title = title, title = title,
description = description, description = description,
// frequency = date.toString(), frequency = frequency.toString(),
start = start.toString(),
finish = finish.toString(),
amount = amount, amount = amount,
expense = expense, expense = expense,
budgetId = budgetId, budgetId = budgetId,

View file

@ -91,7 +91,7 @@ tasks.register("package") {
tasks.register("publish") { tasks.register("publish") {
dependsOn(":app:package") dependsOn(":app:package")
doLast { doLast {
var command = listOf("caprover", "deploy", "-t", "build/${tarFile.name}", "-n", "wbrawner", "-a", "twigs") var command = listOf("caprover", "deploy", "-t", "build/${tarFile.name}", "-n", "wbrawner", "-a", "twigs-dev")
command = if (System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("windows")) { command = if (System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("windows")) {
listOf("powershell", "-Command") + command listOf("powershell", "-Command") + command
} else { } else {

View file

@ -24,7 +24,7 @@ import kotlin.time.ExperimentalTime
fun main(args: Array<String>): Unit = io.ktor.server.cio.EngineMain.main(args) fun main(args: Array<String>): Unit = io.ktor.server.cio.EngineMain.main(args)
private const val DATABASE_VERSION = 1 private const val DATABASE_VERSION = 2
@ExperimentalTime @ExperimentalTime
fun Application.module() { fun Application.module() {
@ -44,7 +44,7 @@ fun Application.module() {
budgetRepository = JdbcBudgetRepository(it), budgetRepository = JdbcBudgetRepository(it),
categoryRepository = JdbcCategoryRepository(it), categoryRepository = JdbcCategoryRepository(it),
permissionRepository = JdbcPermissionRepository(it), permissionRepository = JdbcPermissionRepository(it),
// recurringTransactionRepository = Fa, recurringTransactionRepository = JdbcRecurringTransactionRepository(it),
sessionRepository = JdbcSessionRepository(it), sessionRepository = JdbcSessionRepository(it),
transactionRepository = JdbcTransactionRepository(it), transactionRepository = JdbcTransactionRepository(it),
userRepository = JdbcUserRepository(it) userRepository = JdbcUserRepository(it)
@ -58,7 +58,7 @@ fun Application.moduleWithDependencies(
budgetRepository: BudgetRepository, budgetRepository: BudgetRepository,
categoryRepository: CategoryRepository, categoryRepository: CategoryRepository,
permissionRepository: PermissionRepository, permissionRepository: PermissionRepository,
// recurringTransactionRepository: RecurringTransactionRepository, recurringTransactionRepository: RecurringTransactionRepository,
sessionRepository: SessionRepository, sessionRepository: SessionRepository,
transactionRepository: TransactionRepository, transactionRepository: TransactionRepository,
userRepository: UserRepository userRepository: UserRepository
@ -112,6 +112,7 @@ fun Application.moduleWithDependencies(
} }
budgetRoutes(budgetRepository, permissionRepository) budgetRoutes(budgetRepository, permissionRepository)
categoryRoutes(categoryRepository, permissionRepository) categoryRoutes(categoryRepository, permissionRepository)
recurringTransactionRoutes(recurringTransactionRepository, permissionRepository)
transactionRoutes(transactionRepository, permissionRepository) transactionRoutes(transactionRepository, permissionRepository)
userRoutes(permissionRepository, sessionRepository, userRepository) userRoutes(permissionRepository, sessionRepository, userRepository)
webRoutes() webRoutes()
@ -134,7 +135,7 @@ fun Application.moduleWithDependencies(
} }
val jobs = listOf( val jobs = listOf(
SessionCleanupJob(sessionRepository), SessionCleanupJob(sessionRepository),
// RecurringTransactionProcessingJob(recurringTransactionRepository, transactionRepository) RecurringTransactionProcessingJob(recurringTransactionRepository, transactionRepository)
) )
while (currentCoroutineContext().isActive) { while (currentCoroutineContext().isActive) {
delay(Duration.hours(24)) delay(Duration.hours(24))

View file

@ -1,15 +1,16 @@
package com.wbrawner.twigs package com.wbrawner.twigs
import at.favre.lib.crypto.bcrypt.BCrypt import at.favre.lib.crypto.bcrypt.BCrypt
import com.wbrawner.twigs.model.Frequency
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
private val CALENDAR_FIELDS = intArrayOf( private val CALENDAR_FIELDS = intArrayOf(
Calendar.MILLISECOND, Calendar.MILLISECOND,
Calendar.SECOND, Calendar.SECOND,
Calendar.MINUTE, Calendar.MINUTE,
Calendar.HOUR_OF_DAY, Calendar.HOUR_OF_DAY,
Calendar.DATE Calendar.DATE
) )
val firstOfMonth: Instant val firstOfMonth: Instant
@ -46,3 +47,7 @@ fun randomString(length: Int = 32): String {
lateinit var salt: String lateinit var salt: String
fun String.hash(): String = String(BCrypt.withDefaults().hash(10, salt.toByteArray(), this.toByteArray())) fun String.hash(): String = String(BCrypt.withDefaults().hash(10, salt.toByteArray(), this.toByteArray()))
fun String.toInstant(): Instant = Instant.parse(this)
fun String.asFrequency(): Frequency = Frequency.parse(this)

View file

@ -12,7 +12,7 @@ data class RecurringTransaction(
val description: String? = null, val description: String? = null,
val frequency: Frequency, val frequency: Frequency,
val start: Instant, val start: Instant,
val end: Instant? = null, val finish: Instant? = null,
val amount: Long, val amount: Long,
val expense: Boolean, val expense: Boolean,
val createdBy: String, val createdBy: String,
@ -37,6 +37,8 @@ sealed class Frequency {
abstract val time: Time abstract val time: Time
data class Daily(override val count: Int, override val time: Time) : Frequency() { data class Daily(override val count: Int, override val time: Time) : Frequency() {
override fun toString(): String = "D;$count;$time"
companion object { companion object {
fun parse(s: String): Daily { fun parse(s: String): Daily {
require(s[0] == 'D') { "Invalid format for Daily: $s" } require(s[0] == 'D') { "Invalid format for Daily: $s" }
@ -51,6 +53,7 @@ sealed class Frequency {
} }
data class Weekly(override val count: Int, val daysOfWeek: Set<DayOfWeek>, override val time: Time) : Frequency() { data class Weekly(override val count: Int, val daysOfWeek: Set<DayOfWeek>, override val time: Time) : Frequency() {
override fun toString(): String = "W;$count;${daysOfWeek.joinToString(",")};$time"
companion object { companion object {
fun parse(s: String): Weekly { fun parse(s: String): Weekly {
require(s[0] == 'W') { "Invalid format for Weekly: $s" } require(s[0] == 'W') { "Invalid format for Weekly: $s" }
@ -70,6 +73,7 @@ sealed class Frequency {
val dayOfMonth: DayOfMonth, val dayOfMonth: DayOfMonth,
override val time: Time override val time: Time
) : Frequency() { ) : Frequency() {
override fun toString(): String = "M;$count;$dayOfMonth;$time"
companion object { companion object {
fun parse(s: String): Monthly { fun parse(s: String): Monthly {
require(s[0] == 'M') { "Invalid format for Monthly: $s" } require(s[0] == 'M') { "Invalid format for Monthly: $s" }
@ -85,13 +89,16 @@ sealed class Frequency {
} }
data class Yearly(override val count: Int, val dayOfYear: MonthDay, override val time: Time) : Frequency() { data class Yearly(override val count: Int, val dayOfYear: MonthDay, override val time: Time) : Frequency() {
override fun toString(): String = "Y;$count;%02d-%02d;$time".format(dayOfYear.monthValue, dayOfYear.dayOfMonth)
companion object { companion object {
fun parse(s: String): Yearly { fun parse(s: String): Yearly {
require(s[0] == 'Y') { "Invalid format for Yearly: $s" } require(s[0] == 'Y') { "Invalid format for Yearly: $s" }
return with(s.split(';')) { return with(s.split(';')) {
Yearly( Yearly(
get(1).toInt(), get(1).toInt(),
MonthDay.parse(get(2)), with(get(2).split("-")) {
MonthDay.of(get(0).toInt(), get(1).toInt())
},
Time.parse(get(3)) Time.parse(get(3))
) )
} }
@ -101,13 +108,6 @@ sealed class Frequency {
fun instant(now: Instant): Instant = Instant.parse(now.toString().split("T")[0] + "T" + time.toString() + "Z") fun instant(now: Instant): Instant = Instant.parse(now.toString().split("T")[0] + "T" + time.toString() + "Z")
override fun toString(): String = when (this) {
is Daily -> "D;$count;$time"
is Weekly -> "W;$count;${daysOfWeek.joinToString(",")};$time"
is Monthly -> "M;$count;$dayOfMonth;$time"
is Yearly -> "Y;$count;$dayOfYear;$time"
}
companion object { companion object {
fun parse(s: String): Frequency = when (s[0]) { fun parse(s: String): Frequency = when (s[0]) {
'D' -> Daily.parse(s) 'D' -> Daily.parse(s)
@ -119,7 +119,7 @@ sealed class Frequency {
} }
} }
data class Time(val hours: Int, val minutes: Int, val seconds: Int, val milliseconds: Int) { data class Time(val hours: Int, val minutes: Int, val seconds: Int) {
override fun toString(): String { override fun toString(): String {
val s = StringBuilder() val s = StringBuilder()
if (hours < 10) { if (hours < 10) {
@ -136,28 +136,18 @@ data class Time(val hours: Int, val minutes: Int, val seconds: Int, val millisec
s.append("0") s.append("0")
} }
s.append(seconds) s.append(seconds)
s.append(".")
if (milliseconds < 100) {
s.append("0")
}
if (milliseconds < 10) {
s.append("0")
}
s.append(milliseconds)
return s.toString() return s.toString()
} }
companion object { companion object {
fun parse(s: String): Time { fun parse(s: String): Time {
require(s.length < 12) { "Invalid time format: $s. Time should be formatted as HH:mm:ss.SSS" } require(s.length < 9) { "Invalid time format: $s. Time should be formatted as HH:mm:ss" }
require(s[3] == ':') { "Invalid time format: $s. Time should be formatted as HH:mm:ss.SSS" } require(s[2] == ':') { "Invalid time format: $s. Time should be formatted as HH:mm:ss" }
require(s[6] == ':') { "Invalid time format: $s. Time should be formatted as HH:mm:ss.SSS" } require(s[5] == ':') { "Invalid time format: $s. Time should be formatted as HH:mm:ss" }
require(s[9] == '.') { "Invalid time format: $s. Time should be formatted as HH:mm:ss.SSS" }
return Time( return Time(
s.substring(0, 3).toInt(), s.substring(0, 2).toInt(),
s.substring(4, 6).toInt(), s.substring(3, 5).toInt(),
s.substring(7, 9).toInt(), s.substring(7).toInt(),
s.substring(10).toInt()
) )
} }
} }

View file

@ -0,0 +1,58 @@
package com.wbrawner.twigs.db
import com.wbrawner.twigs.asFrequency
import com.wbrawner.twigs.model.RecurringTransaction
import com.wbrawner.twigs.storage.RecurringTransactionRepository
import java.sql.ResultSet
import java.time.Instant
import javax.sql.DataSource
class JdbcRecurringTransactionRepository(dataSource: DataSource) :
JdbcRepository<RecurringTransaction, JdbcRecurringTransactionRepository.Fields>(dataSource),
RecurringTransactionRepository {
override val tableName: String = TABLE_RECURRING_TRANSACTION
override val fields: Map<Fields, (RecurringTransaction) -> Any?> = Fields.values().associateWith { it.entityField }
override val conflictFields: Collection<String> = listOf(ID)
override suspend fun findAll(now: Instant): List<RecurringTransaction> = dataSource.connection.use { conn ->
conn.executeQuery("SELECT * FROM $tableName WHERE ${Fields.START.name.lowercase()} < ?", listOf(now))
}
override suspend fun findAll(budgetId: String): List<RecurringTransaction> = dataSource.connection.use { conn ->
if (budgetId.isBlank()) throw IllegalArgumentException("budgetId cannot be null")
conn.executeQuery("SELECT * FROM $tableName WHERE ${Fields.BUDGET_ID.name.lowercase()} = ?", listOf(budgetId))
}
override fun ResultSet.toEntity(): RecurringTransaction = RecurringTransaction(
id = getString(ID),
title = getString(Fields.TITLE.name.lowercase()),
description = getString(Fields.DESCRIPTION.name.lowercase()),
frequency = getString(Fields.FREQUENCY.name.lowercase()).asFrequency(),
start = getInstant(Fields.START.name.lowercase())!!,
finish = getInstant(Fields.FINISH.name.lowercase()),
amount = getLong(Fields.AMOUNT.name.lowercase()),
expense = getBoolean(Fields.EXPENSE.name.lowercase()),
createdBy = getString(Fields.CREATED_BY.name.lowercase()),
categoryId = getString(Fields.CATEGORY_ID.name.lowercase()),
budgetId = getString(Fields.BUDGET_ID.name.lowercase()),
)
enum class Fields(val entityField: (RecurringTransaction) -> Any?) {
TITLE({ it.title }),
DESCRIPTION({ it.description }),
FREQUENCY({ it.frequency }),
START({ it.start }),
FINISH({ it.finish }),
LAST_RUN({ it.lastRun }),
AMOUNT({ it.amount }),
EXPENSE({ it.expense }),
CREATED_BY({ it.createdBy }),
CATEGORY_ID({ it.categoryId }),
BUDGET_ID({ it.budgetId }),
}
companion object {
const val TABLE_RECURRING_TRANSACTION = "recurring_transactions"
}
}

View file

@ -1,6 +1,7 @@
package com.wbrawner.twigs.db package com.wbrawner.twigs.db
import com.wbrawner.twigs.Identifiable import com.wbrawner.twigs.Identifiable
import com.wbrawner.twigs.model.Frequency
import com.wbrawner.twigs.storage.Repository import com.wbrawner.twigs.storage.Repository
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.sql.Connection import java.sql.Connection
@ -101,6 +102,7 @@ abstract class JdbcRepository<Entity, Fields : Enum<Fields>>(protected val dataS
is Long -> setLong(index + 1, param) is Long -> setLong(index + 1, param)
is String -> setString(index + 1, param) is String -> setString(index + 1, param)
is Enum<*> -> setString(index + 1, param.name) is Enum<*> -> setString(index + 1, param.name)
is Frequency -> setString(index + 1, param.toString())
null -> setNull(index + 1, NULL) null -> setNull(index + 1, NULL)
else -> throw Error("Unhandled parameter type: ${param.javaClass.name}") else -> throw Error("Unhandled parameter type: ${param.javaClass.name}")
} }
@ -117,4 +119,4 @@ private val dateFormatter = DateTimeFormatterBuilder()
.toFormatter() .toFormatter()
.withZone(ZoneId.of("UTC")) .withZone(ZoneId.of("UTC"))
fun ResultSet.getInstant(column: String): Instant = dateFormatter.parse(getString(column), Instant::from) fun ResultSet.getInstant(column: String): Instant? = getString(column)?.let { dateFormatter.parse(it, Instant::from) }

View file

@ -30,7 +30,7 @@ class JdbcSessionRepository(dataSource: DataSource) : JdbcRepository<Session, Jd
id = getString(ID), id = getString(ID),
userId = getString(Fields.USER_ID.name.lowercase()), userId = getString(Fields.USER_ID.name.lowercase()),
token = getString(Fields.TOKEN.name.lowercase()), token = getString(Fields.TOKEN.name.lowercase()),
expiration = getInstant(Fields.EXPIRATION.name.lowercase()), expiration = getInstant(Fields.EXPIRATION.name.lowercase())!!,
) )
enum class Fields(val entityField: (Session) -> Any?) { enum class Fields(val entityField: (Session) -> Any?) {

View file

@ -86,7 +86,7 @@ class JdbcTransactionRepository(dataSource: DataSource) :
id = getString(ID), id = getString(ID),
title = getString(Fields.TITLE.name.lowercase()), title = getString(Fields.TITLE.name.lowercase()),
description = getString(Fields.DESCRIPTION.name.lowercase()), description = getString(Fields.DESCRIPTION.name.lowercase()),
date = getInstant(Fields.DATE.name.lowercase()), date = getInstant(Fields.DATE.name.lowercase())!!,
amount = getLong(Fields.AMOUNT.name.lowercase()), amount = getLong(Fields.AMOUNT.name.lowercase()),
expense = getBoolean(Fields.EXPENSE.name.lowercase()), expense = getBoolean(Fields.EXPENSE.name.lowercase()),
createdBy = getString(Fields.CREATED_BY.name.lowercase()), createdBy = getString(Fields.CREATED_BY.name.lowercase()),
@ -109,4 +109,3 @@ class JdbcTransactionRepository(dataSource: DataSource) :
const val TABLE_TRANSACTION = "transactions" const val TABLE_TRANSACTION = "transactions"
} }
} }

View file

@ -5,4 +5,5 @@ import java.time.Instant
interface RecurringTransactionRepository : Repository<RecurringTransaction> { interface RecurringTransactionRepository : Repository<RecurringTransaction> {
suspend fun findAll(now: Instant): List<RecurringTransaction> suspend fun findAll(now: Instant): List<RecurringTransaction>
suspend fun findAll(budgetId: String): List<RecurringTransaction>
} }

View file

@ -6,6 +6,10 @@ import java.time.Instant
class FakeRecurringTransactionsRepository : FakeRepository<RecurringTransaction>(), RecurringTransactionRepository { class FakeRecurringTransactionsRepository : FakeRepository<RecurringTransaction>(), RecurringTransactionRepository {
override suspend fun findAll(now: Instant): List<RecurringTransaction> = entities.filter { override suspend fun findAll(now: Instant): List<RecurringTransaction> = entities.filter {
(it.start == now || it.start.isBefore(now)) && it.end?.isAfter(now) ?: true (it.start == now || it.start.isBefore(now)) && it.finish?.isAfter(now) ?: true
}
override suspend fun findAll(budgetId: String): List<RecurringTransaction> = entities.filter {
it.budgetId == budgetId
} }
} }