File size: 4,727 Bytes
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b608750
bfe7e73
 
 
 
 
 
 
5cfc086
bfe7e73
 
 
 
 
 
 
2014617
 
bfe7e73
5cfc086
 
 
 
 
 
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cfc086
 
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c008b
bfe7e73
 
c4c008b
bfe7e73
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import subprocess
import sys

# Install dependencies into the interpreter that is running this app.
# - Passing an argument list (no shell=True) avoids shell parsing of the command.
# - `sys.executable -m pip` targets this interpreter's environment rather than
#   whichever `pip` happens to be first on PATH.
# Failures are deliberately not fatal (no check=True), matching the original
# best-effort behavior.
subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])

import functools
import os

import gradio as gr
from PIL import Image
import numpy as np
from transformers import pipeline
import transformers
transformers.logging.set_verbosity_error()
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor
from rich.console import Console
from rich.markdown import Markdown

# NOTE(review): repeats the set_verbosity_error() call made right after
# `import transformers`; redundant but harmless.
transformers.logging.set_verbosity_error()



@functools.lru_cache(maxsize=1)
def _get_bean_classifier():
    """Build the bean-leaf image-classification pipeline once and reuse it.

    Loading the model weights and processor is slow, so the pipeline is cached
    instead of being rebuilt for every uploaded image.
    """
    model = AutoModelForImageClassification.from_pretrained("nprasad24/bean_classifier", from_tf=True)
    image_processor = AutoImageProcessor.from_pretrained("nprasad24/bean_classifier")
    return pipeline("image-classification", model=model, image_processor=image_processor)


# Query sent to the LLM for each confidently predicted label.
_LABEL_QUERIES = {
    'healthy': "The plant is healthy. Give tips on maintaining the plant",
    'bean_rust': "The detected disease is bean rust. Explain the disease",
    'angular_leaf_spot': "The detected disease is angular leaf spot. Explain the disease",
}
# Query used when the image does not look like a (known) bean leaf.
_NOT_A_PLANT_QUERY = "Given image is not of a plant."


def image_to_query(image):
    """
    input: Image as a numpy array (as supplied by the Gradio image input)

    function: Performs image classification using fine-tuned model and maps
    the prediction to a natural-language query

    output: Query for the LLM
    """
    # Let PIL infer the mode from the array shape, then normalise to RGB so that
    # grayscale or RGBA uploads are converted instead of being mis-read as RGB.
    image = Image.fromarray(image.astype('uint8')).convert('RGB')

    scores = _get_bean_classifier()(image)

    # Label of the highest-scoring class.
    label_with_max_score = max(scores, key=lambda x: x['score'])['label']

    # Heuristic leaf check: when exactly three class scores fall in the
    # 0.2-0.4 band, the model has no confident prediction (presumably spread
    # across the three bean classes), so the image is treated as not a plant.
    uncertain_count = sum(1 for ele in scores if 0.2 <= ele['score'] <= 0.4)
    if uncertain_count == 3:
        return _NOT_A_PLANT_QUERY

    return _LABEL_QUERIES.get(label_with_max_score, _NOT_A_PLANT_QUERY)

@functools.lru_cache(maxsize=1)
def ragChain():
    """
    function: creates a rag chain over knowledgeBase.txt

    The knowledge base is split into chunks and embedded into a FAISS index.
    For each question, the 5 most similar chunks are retrieved and passed with
    the question to a Fireworks-hosted Mixtral model. Building the index is
    expensive, so the chain is built once and the same chain object is
    returned on later calls.

    The Fireworks API key is read from the FIREWORKS_API_KEY environment
    variable.

    output: rag chain (invoke it with the question string; it returns the
    answer string)

    raises: RuntimeError if FIREWORKS_API_KEY is not set
    """
    # Credentials must not live in source code; supply the key via the
    # environment (e.g. a deployment secret). ChatFireworks reads it from there.
    # The check runs first so a misconfiguration fails before the index is built.
    if not os.environ.get("FIREWORKS_API_KEY"):
        raise RuntimeError(
            "FIREWORKS_API_KEY is not set; configure it as an environment variable or secret."
        )

    loader = TextLoader("knowledgeBase.txt")
    docs = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)

    vectorstore = FAISS.from_documents(documents=docs, embedding=HuggingFaceEmbeddings())
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable agricultural assistant. If a disease is detected, you have to give information on the disease.
                If the image is not of a plant, ask human to upload image of a plant and stop generating any response.
                If the plant is healthy, just give maintenance tips. """
            ),
            (
                "human",
                """Provide information about the leaf disease in question in bullet points. 
                Start your answer by mentioning the disease (if any) or healthy in this format: 'Condition: disease name'.
                """,
            ),

            ("human", "{context}, {question}"),
        ]
    )

    def format_docs(documents):
        """Join retrieved chunks as plain text (instead of a list-of-Document repr)."""
        return "\n\n".join(doc.page_content for doc in documents)

    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain

def generate_response(rag_chain, query):
    """
    input: rag chain, query

    function: runs the query through the rag chain (retrieval + llm)

    output: the llm's answer wrapped in a rich Markdown object
    """
    answer = rag_chain.invoke(f"{query}")
    return Markdown(answer)

def main(image):
    """
    input: uploaded image as a numpy array, or None if nothing was uploaded

    function: Gradio handler that classifies the image, builds the rag chain
    and generates the answer

    output: answer text for the Gradio text output
    """
    # Gradio passes None when the user submits without an image; the
    # classifier would otherwise fail with an AttributeError on `.astype`.
    if image is None:
        return "Please upload an image of a bean leaf."

    query = image_to_query(image)
    chain = ragChain()
    output = generate_response(chain, query)

    # generate_response returns a rich Markdown object, which defines no
    # __str__. Gradio's text output would therefore display the object repr
    # instead of the answer, so hand back the underlying markup string.
    if isinstance(output, Markdown):
        return output.markup
    return output

title = "Bean Classifier and Instructor"
description = "Professor Bean is an agricultural expert. He will guide you on how to protect your plants from bean diseases"

# Build the web UI: one image input, one text output, handled by `main`.
# share=True asks Gradio for a publicly shareable link in addition to the local server.
app = gr.Interface(
    fn=main,
    inputs="image",
    outputs="text",
    title=title,
    description=description,
)
app.launch(share=True)