Professor-Bean / app.py
nprasad24's picture
Update app.py
faf18c9 verified
import subprocess
#subprocess.run('pip install -r requirements.txt', shell = True)
import gradio as gr
import os
from PIL import Image
import numpy as np
from transformers import pipeline
import transformers
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor
from langchain import HuggingFacePipeline
def image_to_query(image):
"""
input: Image
function: Performs image classification using fine-tuned model
output: Query for the LLM
"""
image = Image.fromarray(image)
model = AutoModelForImageClassification.from_pretrained("nprasad24/bean_classifier", from_tf=True)
image_processor = AutoImageProcessor.from_pretrained("nprasad24/bean_classifier")
classifier = pipeline("image-classification", model=model, image_processor=image_processor)
scores = classifier(image)
# Get the dictionary with the maximum score
max_score_dict = max(scores, key=lambda x: x['score'])
# Extract the label with the maximum score
label_with_max_score = max_score_dict['label']
# script to check if the image uploaded is indeed a leaf or not
counter = 0
for ele in scores:
if 0.2 <= ele['score'] <= 0.4:
counter += 1
if label_with_max_score == 'healthy' and counter != 3:
query = "The plant is healthy. Give tips on maintaining the plant"
elif label_with_max_score == 'bean_rust' and counter != 3:
query = "The detected disease is bean rust. Explain the disease"
elif label_with_max_score == 'angular_leaf_spot' and counter != 3:
query = "The detected disease is angular leaf spot. Explain the disease"
else:
query = "Given image is not of a plant."
return query
def ragChain():
"""
function: creates a rag chain
output: rag chain
"""
loader = TextLoader("knowledgeBase.txt")
docs = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(docs)
vectorstore = vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
api_key = os.getenv("APIKEY")
llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a knowledgeable agricultural assistant. If a disease is detected, you have to give information on the disease.
If the image is not of a plant, ask human to upload image of a plant and stop generating any response.
If the plant is healthy, just give maintenance tips. """
),
(
"human",
"""Provide information about the leaf disease in question in bullet points.
Start your answer by mentioning the disease (if any) or healthy in this format: 'Condition: disease name'.
""",
),
("human", "{context}, {question}"),
]
)
rag_chain = (
{
"context": retriever,
"question": RunnablePassthrough()
}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
def generate_response(rag_chain, query):
"""
input: rag chain, query
function: generates response using llm and knowledge base
output: generated response by the llm
"""
return rag_chain.invoke(f"{query}")
def main(image):
query = image_to_query(image)
chain = ragChain()
output = generate_response(chain, query)
return output
title = "Professor Bean: The Bean Disease Expert"
description = "Professor Bean is an agricultural expert. He will guide you on how to protect your plants from bean diseases"
app = gr.Interface(fn=main,
inputs="image",
outputs="text",
title=title,
description=description,
examples=[["sampleImages/sample1.jpg"], ["sampleImages/sample2.jpg"],["sampleImages/sample3.jpg"], ["sampleImages/sample4.jpeg"]]
)
app.launch(share=True)