Persist data in PostgreSQL database

This commit is contained in:
William Brawner 2023-01-25 21:49:24 -07:00
parent 544b97f31c
commit d82dcc979d
Signed by: wbrawner
GPG key ID: 8FF12381C6C90D35
10 changed files with 339 additions and 181 deletions

12
.idea/dataSources.xml Normal file
View file

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="postgres@localhost" uuid="30a58456-ae1f-4ed0-95e1-e83dd7234d83">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://localhost:5432/postgres</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

View file

@ -6,12 +6,14 @@ plugins {
application
id("com.github.johnrengelman.shadow") version "7.1.2"
kotlin("jvm") version "1.8.0"
id("app.cash.sqldelight") version "2.0.0-alpha05"
}
group = "com.wbrawner"
version = "1.0-SNAPSHOT"
repositories {
google()
mavenCentral()
}
@ -20,6 +22,10 @@ dependencies {
testImplementation("org.junit.jupiter:junit-jupiter-api:5.8.1")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.8.1")
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
implementation("app.cash.sqldelight:jdbc-driver:2.0.0-alpha05")
implementation("org.postgresql:postgresql:42.5.1")
implementation("com.zaxxer:HikariCP:5.0.1")
implementation("ch.qos.logback:logback-classic:1.3.5")
}
kotlin {
@ -30,7 +36,7 @@ tasks.getByName<Test>("test") {
useJUnitPlatform()
}
val main = "com.wbrawner.MainKt"
val main = "com.wbrawner.civicsquizbot.Main"
application {
mainClass.set(main)
@ -53,4 +59,13 @@ tasks.register("updateQuestions") {
it.readText()
}
File(projectDir, "src/main/resources/questions.txt").writeText(text)
}
sqldelight {
databases {
create("Database") {
packageName.set("com.wbrawner.civicsquizbot")
dialect("app.cash.sqldelight:postgresql-dialect:2.0.0-alpha05")
}
}
}

View file

@ -1,165 +0,0 @@
package com.wbrawner
import org.telegram.telegrambots.bots.TelegramLongPollingBot
import org.telegram.telegrambots.meta.api.methods.send.SendMessage
import org.telegram.telegrambots.meta.api.objects.Update
import org.telegram.telegrambots.meta.api.objects.replykeyboard.ReplyKeyboardMarkup
import org.telegram.telegrambots.meta.api.objects.replykeyboard.buttons.KeyboardButton
import org.telegram.telegrambots.meta.api.objects.replykeyboard.buttons.KeyboardRow
import org.telegram.telegrambots.meta.exceptions.TelegramApiException
import java.security.SecureRandom
import java.util.concurrent.ConcurrentHashMap
class CivicsQuizHandler : TelegramLongPollingBot() {
private val acknowledgementPhrases = listOf(
"Got it",
"Understood",
"Done",
"Roger roger",
"👍",
""
)
private val random = SecureRandom()
private lateinit var questions: Map<Int, Question>
// TODO: Persist this in DB
private val buckets = ConcurrentHashMap<Long, List<MutableList<Int>>>()
// TODO: Persist this in DB
private val lastChatQuestions = ConcurrentHashMap<Long, Question>()
override fun getBotUsername(): String = "CivicsQuizBot"
override fun getBotToken(): String = System.getenv("TELEGRAM_TOKEN")
override fun onUpdateReceived(update: Update) {
println(update)
val message = update.message
if (message == null) {
println("No message, returning early")
return
}
// TODO: Add option to enable smart reminders
when (val command = message.text.asCommand()) {
Command.SHOW_ANSWER -> sendAnswer(message.chatId)
Command.NEW_QUESTION -> sendQuestion(message.chatId)
Command.NEED_PRACTICE, Command.TOO_EASY -> handleFeedback(message.chatId, command)
else -> sendOptions(message.chatId)
}
}
private fun sendQuestion(chatId: Long) {
val question = randomQuestion(chatId)
lastChatQuestions[chatId] = question
sendMessage(chatId, question.prompt, Command.SHOW_ANSWER)
}
private fun handleFeedback(chatId: Long, command: Command) {
val userBuckets = buckets[chatId] ?: run {
val userBuckets = listOf(
questions.values.map { it.number }.toMutableList(),
mutableListOf(),
mutableListOf(),
mutableListOf()
)
buckets[chatId] = userBuckets
userBuckets
}
val lastQuestion = requireNotNull(lastChatQuestions[chatId])
for (i in userBuckets.indices) {
if (userBuckets[i].contains(lastQuestion.number)) {
if (command == Command.TOO_EASY && i < userBuckets.lastIndex) {
userBuckets[i].remove(lastQuestion.number)
userBuckets[i + 1].add(lastQuestion.number)
} else if (command == Command.NEED_PRACTICE && i > 0) {
userBuckets[i].remove(lastQuestion.number)
userBuckets[i - 1].add(lastQuestion.number)
}
break
}
}
println(userBuckets)
sendMessage(chatId, acknowledgementPhrases.random())
sendQuestion(chatId)
}
private fun sendAnswer(chatId: Long) {
lastChatQuestions[chatId]?.let {
sendMessage(chatId, it.answer, Command.NEED_PRACTICE, Command.TOO_EASY)
} ?: run {
sendMessage(chatId, "I can't answer a question that hasn't been asked", Command.NEW_QUESTION)
}
}
private fun sendOptions(chatId: Long) = sendMessage(
chatId,
"I'm not sure how to respond to that. Here are a few actions I can help you with:",
Command.NEW_QUESTION
)
init {
try {
javaClass.getResourceAsStream("/questions.txt")
?.bufferedReader()
?.let { questionsResource ->
val parser = QuestionTextParser()
questions = parser.parseQuestions(questionsResource.readText()).associateBy { it.number }
} ?: throw RuntimeException("Questions resource was null")
} catch (e: Exception) {
throw RuntimeException(e)
}
}
private fun sendMessage(chatId: Long, text: String, vararg keyboardButtons: Command) {
val keyboard = keyboardButtons.map {
KeyboardRow(listOf(KeyboardButton(it.text.first())))
}
val response = SendMessage.builder()
.text(text)
.chatId(chatId)
.replyMarkup(ReplyKeyboardMarkup(keyboard, true, true, false, null, true))
.build()
try {
execute(response)
} catch (e: TelegramApiException) {
e.printStackTrace()
}
}
private fun randomQuestion(chatId: Long): Question {
val userBuckets = buckets[chatId] ?: run {
val userBuckets = listOf(
questions.values.map { it.number }.toMutableList(),
mutableListOf(),
mutableListOf(),
mutableListOf()
)
buckets[chatId] = userBuckets
userBuckets
}
if (userBuckets.first().isNotEmpty()) {
return requireNotNull(questions[userBuckets.first().removeFirst()]) { "Failed to retrieve a question" }
}
var bucket = when (random.nextInt(10)) {
in 0..5 -> 1
in 6..8 -> 2
else -> 3
}
var question = userBuckets[bucket].shuffled().firstOrNull()
while (question == null) {
if (--bucket == -1) {
bucket = 2
}
question = userBuckets[bucket].shuffled().firstOrNull()
}
return requireNotNull(questions[question]) { "Failed to retrieve a question" }
}
}
fun String.asCommand() = Command.values().firstOrNull { it.text.contains(this) }
enum class Command(vararg val text: String) {
NEW_QUESTION("New question", "/start"),
SHOW_ANSWER("Show answer"),
NEED_PRACTICE("I need to practice this"),
TOO_EASY("This was easy"),
}

View file

@ -1,14 +0,0 @@
package com.wbrawner
import org.telegram.telegrambots.meta.TelegramBotsApi
import org.telegram.telegrambots.meta.exceptions.TelegramApiException
import org.telegram.telegrambots.updatesreceivers.DefaultBotSession
fun main() {
try {
val telegramBotsApi = TelegramBotsApi(DefaultBotSession::class.java)
telegramBotsApi.registerBot(CivicsQuizHandler())
} catch (e: TelegramApiException) {
e.printStackTrace()
}
}

View file

@ -0,0 +1,101 @@
package com.wbrawner.civicsquizbot
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.telegram.telegrambots.bots.TelegramLongPollingBot
import org.telegram.telegrambots.meta.api.methods.send.SendMessage
import org.telegram.telegrambots.meta.api.objects.Update
import org.telegram.telegrambots.meta.api.objects.replykeyboard.ReplyKeyboardMarkup
import org.telegram.telegrambots.meta.api.objects.replykeyboard.buttons.KeyboardButton
import org.telegram.telegrambots.meta.api.objects.replykeyboard.buttons.KeyboardRow
import org.telegram.telegrambots.meta.exceptions.TelegramApiException
class CivicsQuizHandler(
private val questionService: QuestionService
) : TelegramLongPollingBot() {
private val acknowledgementPhrases = listOf(
"Got it",
"Understood",
"Done",
"Roger roger",
"👍",
""
)
override fun getBotUsername(): String = "CivicsQuizBot"
override fun getBotToken(): String = System.getenv("TELEGRAM_TOKEN")
override fun onUpdateReceived(update: Update) {
logger.info("Update: $update")
val message = update.message
if (message == null) {
logger.warn("No message, returning early")
return
}
// TODO: Add option to enable smart reminders
when (val command = message.text.asCommand()) {
Command.SHOW_ANSWER -> sendAnswer(message.chatId)
Command.NEW_QUESTION -> sendQuestion(message.chatId)
Command.NEED_PRACTICE, Command.TOO_EASY -> handleFeedback(message.chatId, command)
else -> sendOptions(message.chatId)
}
}
private fun sendQuestion(chatId: Long) {
val question = questionService.randomQuestionForUser(chatId)
sendMessage(chatId, question.prompt, Command.SHOW_ANSWER)
}
private fun handleFeedback(chatId: Long, command: Command) {
if (command == Command.NEED_PRACTICE) {
questionService.increaseLastQuestionFrequency(chatId)
} else if (command == Command.TOO_EASY) {
questionService.decreaseLastQuestionFrequency(chatId)
}
sendMessage(chatId, acknowledgementPhrases.random())
sendQuestion(chatId)
}
private fun sendAnswer(chatId: Long) {
questionService.answerLastQuestion(chatId)?.let { answer ->
sendMessage(chatId, answer, Command.NEED_PRACTICE, Command.TOO_EASY)
} ?: run {
sendMessage(chatId, "I can't answer a question that hasn't been asked", Command.NEW_QUESTION)
}
}
private fun sendOptions(chatId: Long) = sendMessage(
chatId,
"I'm not sure how to respond to that. Here are a few actions I can help you with:",
Command.NEW_QUESTION
)
private fun sendMessage(chatId: Long, text: String, vararg keyboardButtons: Command) {
val keyboard = keyboardButtons.map {
KeyboardRow(listOf(KeyboardButton(it.text.first())))
}
val response = SendMessage.builder()
.text(text)
.chatId(chatId)
.replyMarkup(ReplyKeyboardMarkup(keyboard, true, true, false, null, true))
.build()
try {
execute(response)
} catch (e: TelegramApiException) {
e.printStackTrace()
}
}
}
fun String.asCommand() = Command.values().firstOrNull { it.text.contains(this) }
enum class Command(vararg val text: String) {
NEW_QUESTION("New question", "/start"),
SHOW_ANSWER("Show answer"),
NEED_PRACTICE("I need to practice this"),
TOO_EASY("This was easy"),
}
val Any.logger: Logger
get() = LoggerFactory.getLogger(this.javaClass)

View file

@ -0,0 +1,42 @@
package com.wbrawner.civicsquizbot
import app.cash.sqldelight.driver.jdbc.asJdbcDriver
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import org.telegram.telegrambots.meta.TelegramBotsApi
import org.telegram.telegrambots.meta.exceptions.TelegramApiException
import org.telegram.telegrambots.updatesreceivers.DefaultBotSession
object Main {
@JvmStatic
fun main(args: Array<String>) {
val questions = try {
javaClass.getResourceAsStream("/questions.txt")
?.bufferedReader()
?.let { questionsResource ->
val parser = QuestionTextParser()
parser.parseQuestions(questionsResource.readText()).associateBy { it.number }
} ?: throw RuntimeException("Questions resource was null")
} catch (e: Exception) {
throw RuntimeException(e)
}
val dataSource = HikariDataSource(HikariConfig().apply {
val host = System.getenv("CIVICS_DB_HOST") ?: "localhost"
val port = System.getenv("CIVICS_DB_PORT") ?: 5432
val name = System.getenv("CIVICS_DB_NAME") ?: "postgres"
jdbcUrl = "jdbc:postgresql://$host:$port/$name"
username = System.getenv("CIVICS_DB_USER") ?: "postgres"
password = System.getenv("CIVICS_DB_PASSWORD") ?: "postgres"
})
val driver = dataSource.asJdbcDriver()
val database = Database(driver)
val questionService = DatabaseQuestionService(database, questions)
try {
val telegramBotsApi = TelegramBotsApi(DefaultBotSession::class.java)
telegramBotsApi.registerBot(CivicsQuizHandler(questionService))
} catch (e: TelegramApiException) {
e.printStackTrace()
}
}
}

View file

@ -0,0 +1,104 @@
package com.wbrawner.civicsquizbot
import java.security.SecureRandom
import kotlin.math.max
interface QuestionService {
fun randomQuestionForUser(userId: Long): Question
fun increaseLastQuestionFrequency(userId: Long)
fun decreaseLastQuestionFrequency(userId: Long)
fun answerLastQuestion(userId: Long): String?
}
class DatabaseQuestionService(
private val database: Database,
private val questions: Map<Int, Question>
) : QuestionService {
private val random = SecureRandom()
init {
database.lastQuestionQueries.create()
database.repetitionQueries.create()
}
override fun randomQuestionForUser(userId: Long): Question {
if (database.repetitionQueries.countByUserId(userId).executeAsOne() == 0L) {
questions.values.forEach {
database.repetitionQueries.insertRepetition(Repetition(it.number, userId, 0))
}
}
var question = database.repetitionQueries.selectRandomByUserIdAndBucket(0, userId)
.executeAsOneOrNull()
?.question_id
if (question != null) {
database.lastQuestionQueries.upsertLastQuestion(userId, question)
return requireNotNull(questions[question]) { "Failed to retrieve random question from bucket 0" }
}
var bucket = when (random.nextInt(10)) {
in 0..5 -> 1
in 6..8 -> 2
else -> 3
}
val initialBucket = bucket
while (question == null) {
if (--bucket == 0) {
bucket = 3
} else if (bucket == initialBucket) {
throw IllegalStateException("Failed to find questions in any bucket")
}
question = database.repetitionQueries.selectRandomByUserIdAndBucket(0, userId)
.executeAsOneOrNull()
?.question_id
}
database.lastQuestionQueries.upsertLastQuestion(userId, question)
return requireNotNull(questions[question]) { "Failed to retrieve random question from bucket 0" }
}
override fun increaseLastQuestionFrequency(userId: Long) {
val lastQuestion = database.lastQuestionQueries.selectByUserId(userId)
.executeAsOneOrNull()
?.question_id
?: run {
logger.info("Ignoring feedback for user $userId since they don't have a previously asked question")
return
}
val bucket = database.repetitionQueries.selectByQuestionIdAndUserId(lastQuestion, userId)
.executeAsOneOrNull()
?.bucket
?: 0
val newBucket = if (bucket == 0) {
1
} else {
max(bucket - 1, 1)
}
database.repetitionQueries.updateQuestionBucket(newBucket, lastQuestion, userId)
}
override fun decreaseLastQuestionFrequency(userId: Long) {
val lastQuestion = database.lastQuestionQueries.selectByUserId(userId)
.executeAsOneOrNull()
?.question_id
?: run {
logger.info("Ignoring feedback for user $userId since they don't have a previously asked question")
return
}
val bucket = database.repetitionQueries.selectByQuestionIdAndUserId(lastQuestion, userId)
.executeAsOneOrNull()
?.bucket
?: 0
val newBucket = if (bucket == 0) {
2
} else {
max(bucket - 1, 1)
}
database.repetitionQueries.updateQuestionBucket(newBucket, lastQuestion, userId)
}
override fun answerLastQuestion(userId: Long): String? = database.lastQuestionQueries.selectByUserId(userId)
.executeAsOneOrNull()
?.question_id
?.let {
questions[it]?.answer
}
}

View file

@ -1,4 +1,4 @@
package com.wbrawner
package com.wbrawner.civicsquizbot
class QuestionTextParser {
private val promptRegex = Regex("^(\\d+)\\.\\s+(.*)$")

View file

@ -0,0 +1,22 @@
create:
CREATE TABLE IF NOT EXISTS last_question (
user_id BIGINT NOT NULL PRIMARY KEY,
question_id INTEGER NOT NULL
);
selectByUserId:
SELECT *
FROM last_question
WHERE user_id = ?;
upsertLastQuestion:
INSERT INTO last_question (user_id, question_id)
VALUES (:user_id, :question_id)
ON CONFLICT(user_id) DO
UPDATE SET question_id = :question_id;
updateQuestionBucket:
UPDATE repetition
SET bucket = ?
WHERE question_id = ?
AND user_id = ?;

View file

@ -0,0 +1,41 @@
create:
CREATE TABLE IF NOT EXISTS repetition (
question_id INTEGER NOT NULL,
user_id BIGINT NOT NULL,
bucket INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (question_id, user_id)
);
countByUserId:
SELECT COUNT(*)
FROM repetition
WHERE user_id = ?;
selectByUserId:
SELECT *
FROM repetition
WHERE user_id = ?;
selectByQuestionIdAndUserId:
SELECT *
FROM repetition
WHERE question_id = ?
AND user_id = ?;
selectRandomByUserIdAndBucket:
SELECT *
FROM repetition
WHERE bucket = ?
AND user_id = ?
ORDER BY RANDOM()
LIMIT 1;
insertRepetition:
INSERT INTO repetition
VALUES ?;
updateQuestionBucket:
UPDATE repetition
SET bucket = ?
WHERE question_id = ?
AND user_id = ?;