# nexusai-v2/scripts/text_gen/story_gen.py
# Author: suneeln-duke — last commit 85300c0 ("push"); file size 2.47 kB.
import langchain.document_loaders
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.chroma import Chroma
import os
import shutil
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
def get_chunks(file_path, *, chunk_size=300, chunk_overlap=100):
    """Load a PDF and split its pages into overlapping text chunks.

    Args:
        file_path: Path of the PDF to read with ``PyPDFLoader``.
        chunk_size: Maximum chunk length, counted in characters
            (``length_function=len``). Defaults to 300.
        chunk_overlap: Number of characters consecutive chunks may share,
            so context is not lost at chunk boundaries. Defaults to 100.

    Returns:
        The chunk documents produced by
        ``RecursiveCharacterTextSplitter.split_documents``. The splitter is
        created with ``add_start_index=True``, so each chunk's metadata records
        where it starts within its source page.
    """
    documents = PyPDFLoader(file_path).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        add_start_index=True,
    )
    return text_splitter.split_documents(documents)
def get_vectordb(chunks, CHROMA_PATH):
    """Return the Chroma vector store named ``CHROMA_PATH``, building it on first use.

    The store lives in ``../../chroma/<CHROMA_PATH>``. That relative path is
    resolved against the current working directory, not this file's location.

    Args:
        chunks: Documents to embed with ``OpenAIEmbeddings`` when the store's
            directory does not exist yet. Ignored when it already exists.
        CHROMA_PATH: Name of the store directory under ``../../chroma``. The
            upper-case parameter name is kept so keyword callers keep working.

    Returns:
        A ``Chroma`` store backed by that directory.
    """
    persist_directory = f"../../chroma/{CHROMA_PATH}"

    if os.path.exists(persist_directory):
        # Reuse the persisted store as-is: `chunks` is NOT re-embedded, so the
        # store reflects whatever documents it was originally built from.
        db = Chroma(
            persist_directory=persist_directory,
            embedding_function=OpenAIEmbeddings(),
        )
        print(f"Loaded existing store from {persist_directory}.")
        return db

    db = Chroma.from_documents(
        chunks, OpenAIEmbeddings(), persist_directory=persist_directory
    )
    db.persist()
    # Only report a save when chunks were actually embedded and written; the
    # previous version printed this even when it merely loaded an existing store.
    print(f"Saved {len(chunks)} chunks to {persist_directory}.")
    return db
def _parse_model_output(response_text):
    """Return the model reply as text, unwrapping it if it is a quoted string literal."""
    import ast

    stripped = response_text.strip()
    try:
        # literal_eval only accepts Python literals, so unlike eval() it cannot
        # execute code smuggled into the reply (e.g. via prompt injection from
        # the retrieved PDF text).
        parsed = ast.literal_eval(stripped)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Ordinary prose is not a Python literal; return it unchanged.
        return stripped
    return parsed if isinstance(parsed, str) else stripped


def gen_sample(text, decision, db, *, k=5):
    """Generate the next passage of a choose-your-own-adventure story.

    Builds an instruction from ``text`` and ``decision``. That instruction is
    used both as the similarity-search query against ``db`` and as the
    question that ``ChatOpenAI`` must answer using only the retrieved chunks.

    Args:
        text: The story passage the reader's decision relates to.
        decision: The choice the reader made.
        db: Vector store providing ``similarity_search_with_relevance_scores``
            (e.g. the store returned by ``get_vectordb``).
        k: Number of context chunks to retrieve. Defaults to 5.

    Returns:
        The model's reply as a string. A reply wrapped in a single Python
        string literal is unwrapped; anything else is returned as stripped text.
    """
    PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
    query_text = f"""
Act as the author of a Choose Your Own Adventure Book. This book is special as it is based on existing material.
Now, as with any choose your own adventure book, there are inifinite paths based on the choices a user makes.
Given some relevant text and the decision taken with respect to the relevant text, generate the next part of the story.
It should be within 6-8 sentences and be coherent as it were actually part of the story.
Relevant: {text}
Decision: {decision}
"""
    results = db.similarity_search_with_relevance_scores(query_text, k=k)
    # Relevance scores are discarded; only the chunk text goes into the prompt.
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)

    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)

    model = ChatOpenAI()
    response_text = model.predict(prompt)
    # Never eval() model output: it is untrusted text, and plain prose would
    # raise SyntaxError anyway.
    return _parse_model_output(response_text)