Spaces:
Sleeping
Sleeping
import os | |
import time | |
import requests | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS | |
import openai | |
import langchain | |
import random | |
from langchain_openai import ChatOpenAI | |
from langchain.cache import InMemoryCache | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import LLMChain | |
# Module-level service configuration: OpenAI credentials, LangChain LLM
# response caching, and the Flask application with CORS enabled.
# NOTE(review): `api_keys` is just an alias of os.environ (not a copy); check
# whether anything outside this file relies on it before removing it.
api_keys = os.environ
# Indexing rather than .get(): a missing OPENAI_API_KEY raises KeyError at
# import time instead of failing later on the first request.
openai.api_key = api_keys['OPENAI_API_KEY']
# In-memory LangChain LLM cache (intended to avoid re-sending identical prompts).
# NOTE(review): newer LangChain releases prefer
# `langchain.globals.set_llm_cache(...)`; confirm this assignment still takes
# effect with the installed version.
langchain.llm_cache = InMemoryCache()

app = Flask(__name__)
CORS(app)  # no options are passed, so flask_cors' defaults apply
# Sends the final LLM output to the fine-tuned tagging service so it can be
# annotated with UneeQ emotions and expressions.
def process_emotions(query, timeout=30):
    """Return *query* annotated with UneeQ emotion/expression tags.

    Posts ``{"prompt": query}`` to the fine-tuned inline-tags endpoint (ECS)
    and returns the ``"answer"`` field of its JSON reply, or ``None`` if the
    reply has no such field.

    Args:
        query: Text produced by the LLM that should be tagged.
        timeout: Seconds to wait for the tagging service. Without a timeout
            ``requests`` may block the calling worker indefinitely.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
        ValueError: (or a subclass) if the response body is not valid JSON.
    """
    url = "https://api-ft-inline-tags.caipinnovation.com/query"
    response = requests.post(url, json={"prompt": query}, timeout=timeout)
    # Turn 4xx/5xx replies into exceptions so callers never tag an error page.
    response.raise_for_status()
    return response.json().get("answer")
# Labels the fine-tuned model sometimes prepends to its reply.
_LLM_LABEL_PREFIXES = ("Answer:", "System:")

# Prompt used for general cataract Q&A (see process_query).
_QA_PROMPT_TEMPLATE = (
    "System: {systemMessage}.\n"
    "User: The user is inquiring about cataracts or cataract surgery. "
    "Answer their question: {query}"
)

# Prompt used for the HIMSS chart feature (see process_chart).
_CHART_PROMPT_TEMPLATE = "System: {systemMessage}.\nUser: {query}"


def _strip_llm_prefixes(text, prefixes=_LLM_LABEL_PREFIXES):
    """Return *text* without leading whitespace and leading label *prefixes*.

    Labels are removed repeatedly while they appear at the very start, so
    ``"System: Answer: hi"`` becomes ``"hi"``. The same words appearing later
    in the answer are left untouched. Empty prefixes are ignored.
    """
    cleaned = text.lstrip()
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            # The `prefix` truthiness check prevents an infinite loop on "".
            if prefix and cleaned.startswith(prefix):
                cleaned = cleaned[len(prefix):].lstrip()
                changed = True
    return cleaned


def _run_fine_tuned_model(prompt_template, system_message, query):
    """Run one prompt through the fine-tuned chat model and return its raw text.

    The model name is read from the ``OPENAI_MODEL_NAME`` environment variable
    on every call. *prompt_template* must use the ``{systemMessage}`` and
    ``{query}`` placeholders.
    """
    ft_model_name = os.environ.get("OPENAI_MODEL_NAME")
    fine_tuned_model = ChatOpenAI(temperature=0, model_name=ft_model_name)
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["systemMessage", "query"],
    )
    chain = LLMChain(llm=fine_tuned_model, prompt=prompt, verbose=False)
    generated = chain.apply([{"systemMessage": system_message, "query": query}])
    return generated[0]["text"]


# This handles general Q&A to the LLM.
def process_query(query, chat_history, systemMessage, emotions):
    """Answer a cataract-related question with the fine-tuned model.

    Args:
        query: The user's question.
        chat_history: Accepted for API compatibility; currently unused.
        systemMessage: System prompt placed before the question.
        emotions: When truthy, the answer is also sent through
            ``process_emotions`` to be tagged with UneeQ emotions.

    Returns:
        ``{"answer": ..., "source_documents": ""}`` on success;
        ``{"error": "Error processing emotions", "query": <untagged answer>}``
        if tagging fails; ``{"error": "Error processing query"}`` for any
        other failure. Errors are printed and returned, never raised.
    """
    try:
        print("calling fine_tuned_model")
        raw_text = _run_fine_tuned_model(_QA_PROMPT_TEMPLATE, systemMessage, query)
        # Only strip labels the model prepends; keep them inside the answer.
        llm_response = _strip_llm_prefixes(raw_text)
        if not emotions:
            return {"answer": llm_response, "source_documents": ""}
        # Send the final answer to get tagged with emotions and expressions.
        try:
            tagged = process_emotions(llm_response)
        except Exception as e:
            print(f"Error processing emotions for query: {llm_response}. Error: {str(e)}")
            return {"error": "Error processing emotions", "query": llm_response}
        return {"answer": tagged, "source_documents": ""}
    except Exception as e:
        print(f"Error processing query: {query}. Error: {str(e)}")
        return {"error": "Error processing query"}


# This handles the chart functionality in HIMSS.
def process_chart(query, s1, s2):
    """Answer a chart question and prefix the answer with the two issues.

    Args:
        query: Prompt sent to the model.
        s1: First issue, echoed back in the answer's opening sentence.
        s2: Second issue, echoed back in the answer's opening sentence.

    Returns:
        ``{"answer": "I see you have <s1> and <s2>. <answer>",
        "source_documents": ""}`` on success, or
        ``{"error": "Error processing query"}`` on any failure (printed, not
        raised).
    """
    try:
        print("calling fine_tuned_model")
        # NOTE(review): if SYSTEM_MESSAGE is unset this is None and the prompt
        # reads "System: None." — confirm the variable is always configured.
        system_message = os.environ.get("SYSTEM_MESSAGE")
        raw_text = _run_fine_tuned_model(_CHART_PROMPT_TEMPLATE, system_message, query)
        print("after ", raw_text)
        answer = _strip_llm_prefixes(raw_text)
        return {
            "answer": f"I see you have {s1} and {s2}. {answer}",
            "source_documents": "",
        }
    except Exception as e:
        print(f"Error processing query: {query}. Error: {str(e)}")
        return {"error": "Error processing query"}
# POST request to this service.
# NOTE(review): no route decorator is visible above this handler here;
# confirm it is registered with `app`.
def handle_query():
    """Answer the question in the JSON request body and return it as JSON.

    Expects a JSON object with a ``"prompt"`` string and an optional
    ``"chatHistory"``. Responds with ``{"query", "answer", "source_documents"}``
    on success, HTTP 400 when no prompt is supplied, and HTTP 500 with the
    error dict from ``process_query`` when answering fails.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it falls through to the 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    query = data.get('prompt')
    if not query:
        return jsonify({"error": "No query provided"}), 400
    chat_history = data.get('chatHistory', [])
    system_message = 'You are a helpful medical assistant.'
    # Emotion tagging is disabled for this endpoint.
    emotions = False
    result = process_query(query, chat_history, system_message, emotions)
    # process_query reports failures as {"error": ...} with no "answer" key.
    if "error" in result:
        return jsonify(result), 500
    return jsonify({
        "query": query,
        "answer": result["answer"],
        "source_documents": "",
    })
# Helper Functions
def pick_random_issues(issues):
    """Return a new list holding two distinct entries drawn at random from *issues*.

    Delegates to ``random.sample``, which raises ``ValueError`` when *issues*
    has fewer than two elements.
    """
    return random.sample(issues, k=2)
def generate_description():
    """Build a chart prompt about two randomly chosen issues.

    The issues come from a fixed list of conditions, medications and
    exposures.

    Returns:
        A tuple ``(description, first_issue, second_issue)``.
    """
    candidate_issues = [
        "Anterior Uveitis",
        "Corneal Guttata",
        "Diabetes",
        "Diabetes Mellitus",
        "Glaucoma",
        "Retinal Detachment",
        "Corticosteroids",
        "Phenothiazine",
        "Chlorpromazine",
        "Ultraviolet Radiation Exposure",
        "Smoking",
        "High Alcohol Consumption",
        "Poor Nutrition",
    ]
    first_issue, second_issue = pick_random_issues(candidate_issues)
    description = (
        f"Describe any issues I may encounter due to {first_issue} and {second_issue} "
        "relative to my upcoming cataract surgery?"
    )
    return description, first_issue, second_issue
# GET request to chart feature.
# NOTE(review): no route decorator is visible above this handler here;
# confirm it is registered with `app`.
def handle_chart():
    """Ask the model about two randomly chosen issues and return JSON.

    Responds with ``{"query", "answer", "source_documents"}`` on success, or
    with the error dict from ``process_chart`` and HTTP 500 on failure.
    """
    query, first_issue, second_issue = generate_description()
    # Defensive guard: generate_description() always builds a non-empty prompt.
    if not query:
        return jsonify({"error": "No query provided"}), 400

    result = process_chart(query, first_issue, second_issue)
    if "error" not in result:
        return jsonify({
            "query": query,
            "answer": result["answer"],
            "source_documents": "",
        })
    return jsonify(result), 500
def hello():
    """Health-check handler reporting service status and deployed code version.

    ``version`` comes from the ``CODE_VERSION`` environment variable and is
    ``null`` in the JSON when that variable is unset.
    """
    payload = {
        "status": "Healthy",
        "version": os.environ.get("CODE_VERSION"),
    }
    return jsonify(payload), 200
if __name__ == '__main__':
    # Host and port can be overridden with the same variables `flask run`
    # reads. The defaults reproduce the original behavior (127.0.0.1:7860).
    # NOTE(review): inside a container the host likely needs to be 0.0.0.0 to
    # accept outside connections — confirm for the deployment target.
    app.run(
        host=os.environ.get("FLASK_RUN_HOST", "127.0.0.1"),
        port=int(os.environ.get("FLASK_RUN_PORT", "7860")),
    )