# vqa-guessing-game / response_db.py
# Uploaded by sedrickkeh ("Upload 13 files", commit 016285f, 2.56 kB).
import sqlite3
from sqlite3 import Connection
import datetime
class ResponseDb:
    """SQLite-backed log of question/response turns in guessing-game dialogues.

    All rows live in a single ``responses`` table with the columns
    ``(dialogue_id, task_id, turn, question, response, datetime)``. The query
    methods return rows as plain tuples in that column order.
    """

    # Default database file (resolved relative to the current working directory).
    DB_PATH = "response.db"

    def __init__(self, db_path=None):
        """Open the database and create the ``responses`` table if it is missing.

        Args:
            db_path: Optional database path overriding the class-level
                ``DB_PATH`` for this instance (e.g. ``":memory:"`` for a
                throwaway database). ``None`` keeps the class default.
        """
        if db_path is not None:
            # The instance attribute shadows the class default, so
            # get_connection() -- including subclass overrides that read
            # self.DB_PATH -- connects to the requested location.
            self.DB_PATH = db_path
        self.con = self.get_connection()
        self._run(
            "CREATE TABLE IF NOT EXISTS responses (dialogue_id text, task_id text, turn integer, question text, response text, datetime date)",
            commit=True,
        )

    def get_connection(self):
        """Return a new SQLite connection to ``self.DB_PATH``.

        Nothing is cached here: every call opens a fresh connection.
        ``check_same_thread=False`` allows the connection to be used from threads
        other than the one that created it (a Streamlit app may need this across
        reruns). Per the sqlite3 docs, write operations may then need to be
        serialized by the caller.
        NB : https://stackoverflow.com/questions/48218065/programmingerror-sqlite-objects-created-in-a-thread-can-only-be-used-in-that-sa
        """
        return sqlite3.connect(self.DB_PATH, check_same_thread=False)

    def close(self):
        """Close the underlying connection; the instance must not be used afterwards."""
        self.con.close()

    def _run(self, sql, params=(), *, commit=False):
        """Execute one statement on a short-lived cursor and return all its rows.

        The cursor is closed even if execution fails. When ``commit`` is true,
        the connection's pending transaction is committed after a successful run.
        Statements that produce no rows (DDL, INSERT) return an empty list.
        """
        cur = self.con.cursor()
        try:
            cur.execute(sql, params)
            rows = cur.fetchall()
        finally:
            cur.close()
        if commit:
            self.con.commit()
        return rows

    def add(self, dialogue_id, task_id, turn, question, response):
        """Insert one turn, timestamped with the current local time, and commit.

        Args:
            dialogue_id: Identifier of the dialogue the turn belongs to.
            task_id: Task identifier; must be convertible with ``int()``.
            turn: Turn number within the dialogue.
            question: Question text.
            response: Response text.

        Raises:
            ValueError, TypeError: if ``int(task_id)`` fails; nothing is written.
        """
        # Stored as "YYYY-MM-DD HH:MM:SS[.ffffff]" text -- the same value the
        # default sqlite3 datetime adapter produced, written explicitly because
        # that default adapter is deprecated since Python 3.12.
        timestamp = datetime.datetime.now().isoformat(" ")
        # int() validates/normalizes task_id; the column's TEXT affinity makes
        # SQLite store the resulting number back as its decimal string.
        self._run(
            "insert into responses values (?, ?, ?, ?, ?, ?)",
            (dialogue_id, int(task_id), turn, question, response, timestamp),
            commit=True,
        )

    def get_id(self, dialogue_id):
        """Return every row whose ``dialogue_id`` column equals ``dialogue_id``."""
        return self._run(
            "select * from responses where dialogue_id=:id",
            {"id": dialogue_id},
        )

    def get_id_turn(self, dialogue_id, turn):
        """Return the rows for ``dialogue_id`` whose ``turn`` column equals ``turn``."""
        return self._run(
            "select * from responses where dialogue_id=:id and turn=:turn",
            {
                "id": dialogue_id,
                "turn": turn,
            },
        )

    def get_all(self):
        """Return every row of the ``responses`` table."""
        return self._run(
            "select * from responses",
        )
class StResponseDb(ResponseDb):
    """ResponseDb variant intended for use from a Streamlit app.

    It currently connects exactly like ResponseDb.
    NOTE(review): the original docstring described caching the connection across
    Streamlit reruns, but no caching is implemented -- presumably a Streamlit
    cache decorator was meant to wrap get_connection here; confirm before relying
    on connection reuse.
    """

    def get_connection(self):
        """Return a new SQLite connection to ``self.DB_PATH`` (no caching).

        Delegates to the parent implementation, which opens the connection with
        ``check_same_thread=False`` so it can be used from threads other than
        the creating one.
        NB : https://stackoverflow.com/questions/48218065/programmingerror-sqlite-objects-created-in-a-thread-can-only-be-used-in-that-sa
        """
        return super().get_connection()
if __name__ == "__main__":
    # Script entry point: dump every stored response row to stdout.
    response_db = ResponseDb()
    all_rows = response_db.get_all()
    print(all_rows)