Spaces:
Runtime error
Runtime error
import subprocess | |
#subprocess.run('pip install -r requirements.txt', shell = True) | |
import gradio as gr | |
import os | |
from PIL import Image | |
import numpy as np | |
from transformers import pipeline | |
import transformers | |
from langchain_community.document_loaders import TextLoader | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_fireworks import ChatFireworks | |
from langchain_core.prompts import ChatPromptTemplate | |
from transformers import AutoModelForImageClassification, AutoImageProcessor | |
from langchain import HuggingFacePipeline | |
def image_to_query(image): | |
""" | |
input: Image | |
function: Performs image classification using fine-tuned model | |
output: Query for the LLM | |
""" | |
image = Image.fromarray(image) | |
model = AutoModelForImageClassification.from_pretrained("nprasad24/bean_classifier", from_tf=True) | |
image_processor = AutoImageProcessor.from_pretrained("nprasad24/bean_classifier") | |
classifier = pipeline("image-classification", model=model, image_processor=image_processor) | |
scores = classifier(image) | |
# Get the dictionary with the maximum score | |
max_score_dict = max(scores, key=lambda x: x['score']) | |
# Extract the label with the maximum score | |
label_with_max_score = max_score_dict['label'] | |
# script to check if the image uploaded is indeed a leaf or not | |
counter = 0 | |
for ele in scores: | |
if 0.2 <= ele['score'] <= 0.4: | |
counter += 1 | |
if label_with_max_score == 'healthy' and counter != 3: | |
query = "The plant is healthy. Give tips on maintaining the plant" | |
elif label_with_max_score == 'bean_rust' and counter != 3: | |
query = "The detected disease is bean rust. Explain the disease" | |
elif label_with_max_score == 'angular_leaf_spot' and counter != 3: | |
query = "The detected disease is angular leaf spot. Explain the disease" | |
else: | |
query = "Given image is not of a plant." | |
return query | |
def ragChain(): | |
""" | |
function: creates a rag chain | |
output: rag chain | |
""" | |
loader = TextLoader("knowledgeBase.txt") | |
docs = loader.load() | |
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | |
docs = text_splitter.split_documents(docs) | |
vectorstore = vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True) | |
retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5}) | |
api_key = os.getenv("APIKEY") | |
llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key) | |
prompt = ChatPromptTemplate.from_messages( | |
[ | |
( | |
"system", | |
"""You are a knowledgeable agricultural assistant. If a disease is detected, you have to give information on the disease. | |
If the image is not of a plant, ask human to upload image of a plant and stop generating any response. | |
If the plant is healthy, just give maintenance tips. """ | |
), | |
( | |
"human", | |
"""Provide information about the leaf disease in question in bullet points. | |
Start your answer by mentioning the disease (if any) or healthy in this format: 'Condition: disease name'. | |
""", | |
), | |
("human", "{context}, {question}"), | |
] | |
) | |
rag_chain = ( | |
{ | |
"context": retriever, | |
"question": RunnablePassthrough() | |
} | |
| prompt | |
| llm | |
| StrOutputParser() | |
) | |
return rag_chain | |
def generate_response(rag_chain, query): | |
""" | |
input: rag chain, query | |
function: generates response using llm and knowledge base | |
output: generated response by the llm | |
""" | |
return rag_chain.invoke(f"{query}") | |
def main(image): | |
query = image_to_query(image) | |
chain = ragChain() | |
output = generate_response(chain, query) | |
return output | |
title = "Professor Bean: The Bean Disease Expert" | |
description = "Professor Bean is an agricultural expert. He will guide you on how to protect your plants from bean diseases" | |
app = gr.Interface(fn=main, | |
inputs="image", | |
outputs="text", | |
title=title, | |
description=description, | |
examples=[["sampleImages/sample1.jpg"], ["sampleImages/sample2.jpg"],["sampleImages/sample3.jpg"], ["sampleImages/sample4.jpeg"]] | |
) | |
app.launch(share=True) |