import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, pipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
import gradio as gr

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Earlier loader that quantized the model to 4 bits with BitsAndBytesConfig,
# kept for reference; the ZeroGPU-friendly loader below replaces it.
# def initialize_model():
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16
#     )
#     tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         token=os.environ.get("HF_TOKEN"),
#         quantization_config=bnb_config if torch.cuda.is_available() else None,
#         device_map="auto" if torch.cuda.is_available() else "cpu",
#         torch_dtype=torch.float32 if not torch.cuda.is_available() else None
#     )
#     return model, tokenizer


def initialize_model():
    """Load the Llama 3.2 3B Instruct model and its tokenizer."""
    token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        device_map="auto",  # This works better with ZeroGPU
    )
    return model, tokenizer


@spaces.GPU
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message by running a RetrievalQA chain over the vector store."""
    try:
        model, tokenizer = initialize_model()

        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.15,
            streamer=streamer,
        )

        llm = HuggingFacePipeline(pipeline=text_pipeline)

        # `db` (Chroma vector store) and `prompt_template` (PromptTemplate) are
        # expected to be defined at module level before the interface is built.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": 2}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt_template},
        )

        response = qa_chain.invoke({"query": message})
        return response["result"]
    except Exception as e:
        return f"An error occurred: {str(e)}"


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False,
        ),
        gr.Slider(minimum=1, maximum=2048, value=500, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)
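# ---------------------------------------------------------------------------
# Missing pieces, sketched under assumptions.
# `DEFAULT_SYSTEM_PROMPT`, `db`, and `prompt_template` are referenced above but
# never defined in this file. The definitions below are a minimal sketch only:
# the embedding model, persist directory, and prompt wording are assumptions,
# not the project's actual values, and the block belongs *before* the
# gr.ChatInterface(...) call. Adjust it to match the real setup.
# ---------------------------------------------------------------------------
DEFAULT_SYSTEM_PROMPT = (
    "You are a ROS2 expert assistant. Answer concisely, using only the "
    "retrieved documentation, and say when the context does not contain the answer."
)

# Assumed: an Instructor embedding model and a Chroma index persisted in "db/".
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": DEVICE},
)
db = Chroma(persist_directory="db", embedding_function=embeddings)

# Prompt used by the "stuff" chain; it must expose {context} and {question}.
prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        DEFAULT_SYSTEM_PROMPT
        + "\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    ),
)

if __name__ == "__main__":
    demo.launch()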