File size: 5,983 Bytes
cc5b602
6f619d7
ae90620
6386510
677d853
51a7d9e
652620b
6386510
51a7d9e
652620b
e6367a7
da0337e
51a7d9e
6386510
bd34f0b
da0337e
bd34f0b
 
51a7d9e
6386510
51a7d9e
 
bd34f0b
 
 
 
 
 
 
51a7d9e
 
da59244
652620b
 
7cb9567
 
 
 
 
 
652620b
0486bff
 
b179e70
6b67af9
677d853
f77fb99
0486bff
4ed884e
 
3d7390f
 
4ed884e
 
 
 
652620b
4ed884e
 
 
652620b
3d7390f
 
 
652620b
 
 
 
 
 
 
 
ce84a62
652620b
 
 
 
 
c4592e6
4ed884e
c4592e6
 
 
f77fb99
652620b
 
27dc368
652620b
 
 
 
 
 
 
 
51a7d9e
652620b
6386510
51a7d9e
fed0852
51a7d9e
 
 
 
 
0486bff
51a7d9e
3d7390f
da0337e
3d7390f
 
 
51a7d9e
 
 
 
 
 
 
 
 
 
4ed884e
51a7d9e
 
652620b
51a7d9e
 
bd34f0b
 
 
 
4ed884e
bd34f0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed884e
bd34f0b
 
 
51a7d9e
 
040bf4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a7d9e
 
 
 
 
652620b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread

# Runtime configuration comes from the environment; each lookup yields None
# when the corresponding variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL = os.getenv("MODEL_ID")

# Heading markup for the page.
TITLE = (
    "<h1><center>"
    "KissanAI - Dhenu2 India - Climate Resilient and Sustainable Agriculture Experimental Model"
    "</center></h1>"
)

# Greeting markup; the leading and trailing newlines are part of the value.
PLACEHOLDER = (
    "\n"
    "<center>\n"
    "<p>Hi, I'm Dhenu. Ask me anything about Climate Resilient and Sustainable Agriculture in India.</p>\n"
    "</center>\n"
)


# Stylesheet text, one rule line per entry; the empty first and last entries
# produce the leading and trailing newline of the final string.
CSS = "\n".join((
    "",
    ".duplicate-button {",
    "    margin: auto !important;",
    "    color: white !important;",
    "    background: black !important;",
    "    border-radius: 100vh !important;",
    "}",
    "h3 {",
    "    text-align: center;",
    "}",
    "",
))

# NOTE(review): `device` is not referenced anywhere else in this file; actual
# placement is decided by device_map="auto" below. Kept for compatibility.
device = "cuda" # for GPU usage or "cpu" for CPU usage

# Fail fast with a clear message instead of handing None to from_pretrained.
if not MODEL:
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to a model id or path."
    )

# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# HF_TOKEN is forwarded explicitly so gated/private repositories can be read;
# it is None when the variable is unset.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)

@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Generate a reply to ``message`` and stream it back incrementally.

    Args:
        message: The newest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        system_prompt: Text placed in the leading ``system`` message.
        temperature: Sampling temperature; a value of 0 (or below) selects
            greedy decoding instead of sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k sampling cut-off (only used when sampling).
        penalty: Repetition penalty forwarded to ``generate``; non-positive
            values disable it.

    Yields:
        The response text accumulated so far, growing with every chunk the
        streamer produces.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        # NOTE(review): hard-coded stop ids; presumably the Llama 3.x
        # end-of-text / end-of-message / end-of-turn tokens -- confirm they
        # match the tokenizer of the configured MODEL.
        eos_token_id=[128001, 128008, 128009],
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are only meaningful (and only validated cleanly)
        # when sampling is enabled, so they are omitted for greedy decoding.
        generate_kwargs.update(
            top_p=top_p,
            top_k=int(top_k),
            temperature=temperature,
        )
    if penalty is not None and penalty > 0:
        # The UI slider allows 0.0, which is not a usable penalty; treat any
        # non-positive value as "no repetition penalty".
        generate_kwargs["repetition_penalty"] = penalty

    worker_errors = []

    def _run_generation():
        # Grad mode is thread-local, so no_grad has to be entered inside the
        # worker thread rather than around Thread.start().
        try:
            with torch.no_grad():
                model.generate(**generate_kwargs)
        except Exception as exc:
            # Hand the failure to the consumer and close the stream so the
            # loop below does not sit waiting for the streamer timeout.
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    thread.join()
    if worker_errors:
        raise worker_errors[0]

            
def _hidden_slider(label, minimum, maximum, step, value):
    """Build a slider that is not rendered where it is created."""
    return gr.Slider(
        label=label,
        minimum=minimum,
        maximum=maximum,
        step=step,
        value=value,
        render=False,
    )


# Example prompts offered by the chat interface, one question per entry.
_EXAMPLE_QUESTIONS = [
    "What are the best drought-resistant crops for farmers in Rajasthan?",
    "How can I implement rainwater harvesting on my farm?",
    "What are the most effective soil conservation techniques for terraced fields?",
    "Which crop rotation practices can improve soil health in Punjab?",
    "How can I manage pest outbreaks using sustainable methods?",
    "What are the benefits of using biofertilizers in paddy cultivation?",
    "How can I optimize water usage for irrigation during the dry season?",
    "What are the recommended practices for organic farming in Karnataka?",
    "How can I protect my crops from unpredictable monsoon patterns?",
    "What are the best practices for integrating livestock with crop farming?",
    "How can agroforestry enhance the resilience of my farm?",
    "What sustainable techniques can reduce the impact of flooding on my crops?",
    "How can I use weather forecasting to plan my planting schedule?",
    "What are the advantages of using drip irrigation over traditional methods?",
    "How can I improve soil fertility without relying on chemical fertilizers?",
    "What are the key indicators of climate resilience in agriculture?",
    "How can I access government schemes for sustainable farming practices?",
    "What are the best methods for conserving biodiversity on my farm?",
    "How can I reduce greenhouse gas emissions from my agricultural activities?",
    "What technologies are available for monitoring crop health in real-time?",
]

# Chat display component, handed to the ChatInterface below.
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, height=600)

with gr.Blocks(theme="gradio/soft", css=CSS) as demo:
    gr.HTML(TITLE)

    # Collapsed accordion that holds the extra generation controls.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)

    # Extra inputs, listed in the positional order stream_chat receives them
    # after (message, history).
    extra_inputs = [
        gr.Textbox(
            label="System Prompt",
            value="You are an climate resilient and sustainable agriculture assistant in the context of India. Provide precise and actionable response in proper markdown format.",
            render=False,
        ),
        _hidden_slider("Temperature", 0, 1, 0.1, 0.8),
        _hidden_slider("Max new tokens", 128, 8192, 1, 1024),
        _hidden_slider("top_p", 0.0, 1.0, 0.1, 1.0),
        _hidden_slider("top_k", 1, 20, 1, 20),
        _hidden_slider("Repetition penalty", 0.0, 2.0, 0.1, 1.2),
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=extra_inputs,
        # Each example is a single-element row holding just the message.
        examples=[[question] for question in _EXAMPLE_QUESTIONS],
        cache_examples=False,
    )


# Launch the Gradio app only when this file is run as a script, so that
# importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()