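# NOTE(review): the captured file began with the stray text "Spaces: Sleeping Sleeping";
# it appears to be page-capture residue rather than part of the program — confirm and remove.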
import os
import re

import streamlit as st
from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosHttpResponseError, CosmosResourceNotFoundError
# Cosmos DB configuration
# ENDPOINT is the account URL handed to CosmosClient; SUBSCRIPTION_ID is only
# displayed in the sidebar.
ENDPOINT = "https://acae-afd.documents.azure.com:443/"
SUBSCRIPTION_ID = "003fba60-5b3f-48f4-ab36-3ed11bc40816"

# You'll need to set these environment variables or use Azure Key Vault.
# Each value is None when its variable is not set.
DATABASE_NAME = os.environ.get("COSMOS_DATABASE_NAME")
CONTAINER_NAME = os.environ.get("COSMOS_CONTAINER_NAME")
def create_stored_procedure(container):
    """Register the ``processQTPrompts`` stored procedure on *container*.

    The procedure body is server-side JavaScript. It splits the submitted
    text into lines, keeps the trimmed lines that start with ``"QT "``, and
    for each one either increments ``occurrenceCount`` on the first document
    whose ``prompt`` matches, or creates a new document with
    ``occurrenceCount`` 1 and an empty ``evaluation``. The response body is
    the list of replaced/created documents, in input order.

    Args:
        container: Object exposing
            ``scripts.create_stored_procedure(body=...)`` (in this app, the
            Cosmos DB container client).
    """
    # The query/replace/create calls report their result through callbacks.
    # The script therefore handles one prompt at a time inside its own
    # function scope (so every callback sees its own ``prompt``) and sets
    # the response body only after the last callback has fired, instead of
    # right after scheduling the queries.
    stored_procedure_definition = {
        'id': 'processQTPrompts',
        'body': '''
        function processQTPrompts(promptText) {
            var context = getContext();
            var container = context.getCollection();
            var response = context.getResponse();

            var lines = promptText.split('\\n');
            var prompts = [];
            for (var i = 0; i < lines.length; i++) {
                var line = lines[i].trim();
                if (line.startsWith('QT ')) {
                    prompts.push(line);
                }
            }

            var results = [];
            processNext(0);

            // Handle prompts[index], then continue with the next one from
            // inside the completion callback.
            function processNext(index) {
                if (index >= prompts.length) {
                    // Every callback has completed; results is final.
                    response.setBody(results);
                    return;
                }

                var prompt = prompts[index];
                var querySpec = {
                    query: "SELECT * FROM c WHERE c.prompt = @prompt",
                    parameters: [{ name: "@prompt", value: prompt }]
                };

                var isAccepted = container.queryDocuments(
                    container.getSelfLink(),
                    querySpec,
                    function (err, items, responseOptions) {
                        if (err) throw err;

                        var writeAccepted;
                        if (items.length > 0) {
                            // Update existing record
                            var item = items[0];
                            item.occurrenceCount++;
                            writeAccepted = container.replaceDocument(
                                item._self,
                                item,
                                function (err, replacedItem) {
                                    if (err) throw err;
                                    results.push(replacedItem);
                                    processNext(index + 1);
                                }
                            );
                        } else {
                            // Create new record
                            var newItem = {
                                prompt: prompt,
                                occurrenceCount: 1,
                                evaluation: ""
                            };
                            writeAccepted = container.createDocument(
                                container.getSelfLink(),
                                newItem,
                                function (err, createdItem) {
                                    if (err) throw err;
                                    results.push(createdItem);
                                    processNext(index + 1);
                                }
                            );
                        }
                        if (!writeAccepted) throw new Error("The write was not accepted by the server.");
                    }
                );
                if (!isAccepted) throw new Error("The query was not accepted by the server.");
            }
        }
        '''
    }
    container.scripts.create_stored_procedure(body=stored_procedure_definition)
def ensure_stored_procedure_exists(container):
    """Create the ``processQTPrompts`` stored procedure if it is missing.

    Args:
        container: Container client whose ``scripts`` attribute is used to
            look up the stored procedure.

    Raises:
        Any error other than "not found" raised by the lookup (for example
        an authentication or connectivity failure) is propagated, rather
        than being treated as "procedure missing".
    """
    try:
        container.scripts.get_stored_procedure('processQTPrompts')
    except CosmosResourceNotFoundError:
        # Only a missing procedure should trigger creation.
        create_stored_procedure(container)
def process_qt_prompts(container, prompt_text, partition_key=None):
    """Run the ``processQTPrompts`` stored procedure on *prompt_text*.

    Args:
        container: Container client whose ``scripts`` attribute executes
            the stored procedure.
        prompt_text: Newline-separated prompt text, passed to the procedure
            as its single parameter.
        partition_key: Partition key value forwarded to the execution call.
            Defaults to ``None``, which is what this helper always passed
            before the parameter existed.

    Returns:
        Whatever ``execute_stored_procedure`` returns (the procedure's
        response body).
    """
    return container.scripts.execute_stored_procedure(
        sproc='processQTPrompts',
        params=[prompt_text],
        partition_key=partition_key,
    )
# Streamlit app
st.title("π QT Prompt Processor")


def _rerun():
    """Restart the script run using whichever rerun API this Streamlit has."""
    # st.rerun replaces the deprecated st.experimental_rerun; fall back to the
    # old name so older installations keep working.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# Login section
if 'logged_in' not in st.session_state:
    st.session_state.logged_in = False

if not st.session_state.logged_in:
    st.subheader("π Login")
    # Show (once) the connection failure that sent the user back here.
    login_error = st.session_state.pop("login_error", None)
    if login_error:
        st.error(login_error)
    input_key = st.text_input("Enter your Cosmos DB Primary Key", type="password")
    if st.button("π Login"):
        if input_key:
            st.session_state.primary_key = input_key
            st.session_state.logged_in = True
            _rerun()
        else:
            st.error("Please enter a valid key")
else:
    if not DATABASE_NAME or not CONTAINER_NAME:
        # Without both names the client calls below cannot succeed.
        st.error(
            "Set the COSMOS_DATABASE_NAME and COSMOS_CONTAINER_NAME "
            "environment variables and restart the app."
        )
    else:
        try:
            # Initialize Cosmos DB client
            client = CosmosClient(ENDPOINT, credential=st.session_state.primary_key)
            database = client.get_database_client(DATABASE_NAME)
            container = database.get_container_client(CONTAINER_NAME)
            # Ensure the stored procedure exists
            ensure_stored_procedure_exists(container)
        except (CosmosHttpResponseError, ValueError) as exc:
            # A rejected or malformed key (or an unreachable account) should
            # return the user to the login form, not show a traceback.
            st.session_state.pop("primary_key", None)
            st.session_state.logged_in = False
            st.session_state.login_error = f"Could not connect to Cosmos DB: {exc}"
            _rerun()

        # Input field for QT prompts
        st.subheader("π Enter QT Prompts")
        default_text = "QT Crystal finders: Bioluminescent crystal caverns, quantum-powered explorers, prismatic hues, alien planet\nQT robot art: Cybernetic metropolis, sentient androids, rogue AI, neon-infused palette\nQT the Lava girl: Volcanic exoplanet, liquid metal rivers, heat-immune heroine, molten metallic palette"
        qt_prompts = st.text_area("QT Prompts", value=default_text, height=300)
        # Submit button
        if st.button("π Process QT Prompts"):
            try:
                results = process_qt_prompts(container, qt_prompts)
            except CosmosHttpResponseError as exc:
                st.error(f"Processing failed: {exc}")
            else:
                # Display results in a dataframe
                df_data = [
                    {"Prompt": item['prompt'], "Occurrence Count": item['occurrenceCount']}
                    for item in results or []
                ]
                st.dataframe(df_data)
    # Logout button
    if st.button("πͺ Logout"):
        # Drop the stored key as well, so it does not outlive the session.
        st.session_state.pop("primary_key", None)
        st.session_state.logged_in = False
        _rerun()

# Display connection info
# NOTE(review): the original indentation was lost; this assumes the sidebar is
# rendered on every run (top level), not only after login — confirm.
st.sidebar.subheader("π Connection Information")
st.sidebar.text(f"Endpoint: {ENDPOINT}")
st.sidebar.text(f"Subscription ID: {SUBSCRIPTION_ID}")
st.sidebar.text(f"Database: {DATABASE_NAME}")
st.sidebar.text(f"Container: {CONTAINER_NAME}")