"""VisaBot: a small Gradio question-answering app.

For each question the app picks the most similar line from a text file of
country details (using a sentence-transformer model), feeds that line to a
GPT-2 text-generation pipeline as part of a fixed prompt, and shows a
cleaned-up version of the generated text.
"""

import os

import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer

# Define paths and models
filename = "output_country_details.txt"  # Adjust the filename as needed
retrieval_model_name = 'output/sentence-transformer-finetuned/'
gpt2_model_name = "gpt2"  # GPT-2 model
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Load models. Both names are bound up front so that a failed load leaves
# them as None (checked explicitly below) instead of undefined.
retrieval_model = None
gpt_model = None
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")


# Load and preprocess text from the country details file
def load_and_preprocess_text(filename):
    """Read *filename* and return its non-blank lines, stripped.

    Args:
        filename: Path of a UTF-8 text file.

    Returns:
        A list with one stripped string per non-blank line, or an empty list
        if the file cannot be read (the error is printed, not raised).
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            segments = [line.strip() for line in file if line.strip()]
        print("Text loaded and preprocessed successfully.")
        return segments
    except Exception as e:
        print(f"Failed to load or preprocess text: {e}")
        return []


segments = load_and_preprocess_text(filename)

# Single-entry cache of the most recently encoded corpus, so the segments are
# not re-encoded on every query.
_segment_embedding_cache = {"key": None, "embeddings": None}


def _embed_segments(segments):
    """Return embeddings for *segments*, re-encoding only when they change."""
    key = tuple(segments)
    if _segment_embedding_cache["key"] != key:
        _segment_embedding_cache["embeddings"] = retrieval_model.encode(segments)
        _segment_embedding_cache["key"] = key
    return _segment_embedding_cache["embeddings"]


def find_relevant_segment(user_query, segments):
    """Return the entry of *segments* most similar to *user_query*.

    Similarity is the cosine similarity between the retrieval model's
    embedding of the query and of each segment.

    Args:
        user_query: The question text.
        segments: List of candidate text segments.

    Returns:
        The best-matching segment, or "" when the retrieval model is not
        loaded, there are no segments, or any error occurs (errors are
        printed, not raised).
    """
    if retrieval_model is None:
        print("Error finding relevant segment: retrieval model is not loaded.")
        return ""
    if not segments:
        print("Error finding relevant segment: no segments available.")
        return ""
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = _embed_segments(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # argmax() yields a tensor; convert it to a plain int for list indexing.
        best_idx = int(similarities.argmax())
        print("Relevant segment found:", segments[best_idx])
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""


def generate_response(user_query, relevant_segment):
    """Generate a reply that builds on *relevant_segment*.

    Args:
        user_query: The question text. It is accepted for interface
            compatibility but is not inserted into the prompt.
        relevant_segment: The retrieved fact placed at the end of the prompt.

    Returns:
        The generated text after clean_up_response(), or "" when the
        generation pipeline is not loaded or any error occurs (errors are
        printed, not raised).
    """
    if gpt_model is None:
        print("Error generating response: text-generation model is not loaded.")
        return ""
    try:
        prompt = (
            "Thank you for your question! "
            f"this is an additional fact about your topic: {relevant_segment}"
        )

        # max_length is set to the prompt's token count plus 50.
        max_tokens = len(tokenizer(prompt)['input_ids']) + 50
        response = gpt_model(prompt, max_length=max_tokens, temperature=0.25)[0]['generated_text']

        # Clean and format the response
        return clean_up_response(response, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""


def clean_up_response(response, segments):
    """Drop empty, already-known and repeated sentences from *response*.

    The response is split on '.', each piece is stripped, and a piece is kept
    only if it is non-empty, is not contained in *segments*, and has not been
    kept already. Kept pieces are re-joined with '. ' and a final '.' is
    appended unless the text already ends in '.', '!' or '?'.

    Args:
        response: The raw generated text.
        segments: Text to filter against using the ``in`` operator. This file
            passes a single string, so the check is a substring test; a list
            would make it an exact-match test.

    Returns:
        The cleaned response, or "" if nothing is kept.
    """
    cleaned_sentences = []
    seen = set()
    for sentence in response.split('.'):
        sentence = sentence.strip()
        if sentence and sentence not in segments and sentence not in seen:
            seen.add(sentence)
            cleaned_sentences.append(sentence)

    cleaned_response = '. '.join(cleaned_sentences).strip()

    # Make sure the text ends with sentence-final punctuation.
    if cleaned_response and not cleaned_response.endswith((".", "!", "?")):
        cleaned_response += "."

    return cleaned_response


# Define the welcome message with markdown for formatting and larger fonts
welcome_message = """
# Welcome to VISABOT!
## Your AI-driven visa assistant for all travel-related queries.
"""

# Define topics and countries with flag emojis
topics = """
### Feel Free to ask me anything from the topics below!
- Visa issuance
- Documents needed
- Application process
- Processing time
- Recommended Vaccines
- Health Risks
- Healthcare Facilities
- Currency Information
- Embassy Information
- Allowed stay
"""

countries = """
### Our chatbot can currently answer questions for these countries!
- 🇨🇳 China
- 🇫🇷 France
- 🇬🇹 Guatemala
- 🇱🇧 Lebanon
- 🇲🇽 Mexico
- 🇵🇭 Philippines
- 🇷🇸 Serbia
- 🇸🇱 Sierra Leone
- 🇿🇦 South Africa
- 🇻🇳 Vietnam
"""


# Define the Gradio app interface
def query_model(question):
    """Answer *question*; return the welcome message for blank input.

    Args:
        question: Text entered by the user.

    Returns:
        welcome_message when the input is empty, whitespace-only or None;
        otherwise the generated response (which may be "" on failure).
    """
    if not question or not question.strip():
        # If there's no input, the bot will display the greeting message.
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    return generate_response(question, relevant_segment)


# Create Gradio Blocks interface for custom layout
with gr.Blocks() as demo:
    gr.Markdown(welcome_message)  # Display the welcome message with large fonts
    with gr.Row():
        with gr.Column():
            gr.Markdown(topics)  # Display the topics on the left
        with gr.Column():
            gr.Markdown(countries)  # Display the countries with flag emojis on the right
    with gr.Row():
        img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500)  # Adjust width as needed
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
            answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)
            submit_button = gr.Button("Submit")
            submit_button.click(fn=query_model, inputs=question, outputs=answer)

# Launch the app
demo.launch()