pablorocg committed
Commit ff770dd
1 Parent(s): 02fa3b7

Update app.py

Files changed (1)
  1. app.py +55 -56
app.py CHANGED
@@ -343,67 +343,66 @@ def answer_query(query_text, index, documents, llm_model, llm_tokenizer, embeddi
 
 
 
-if __name__ == '__main__':
-    import os
-    from faiss import write_index
-    import gradio as gr
-    import numpy as np
-    import torch
-    from tqdm import tqdm
-    from torch.utils.data import DataLoader, Dataset
-    from datasets import load_dataset
-    import pandas as pd
-    import faiss
-    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
-    from transformers import TextIteratorStreamer
-    from threading import Thread
-
-    torch.set_num_threads(2)
-    HF_TOKEN = os.environ.get("SECRET_TOKEN")
-
-    class CFG:
-        embedding_model = 'TimKond/S-PubMedBert-MedQuAD'
-        batch_size = 128
-        device = ('cuda' if torch.cuda.is_available() else 'cpu')
-        llm = 'google/gemma-2b-it'
-        n_samples = 3
-
-    # Show config
-    config = CFG()
-    # config_items = {k: v for k, v in vars(CFG).items() if not k.startswith('__')}
-    # print(tabulate(config_items.items(), headers=['Parameter', 'Value'], tablefmt='fancy_grid'))
-
-    # Fetch the data and load or generate the index
-    df = get_all_data()
-    documents = TextDataset(df)
-    if not os.path.exists('./storage/faiss_index.faiss'):
-        embeddings = get_bert_embeddings(documents, CFG.batch_size, CFG.embedding_model, CFG.device)
-        index = create_faiss_index(embeddings)
-        write_index(index, './storage/faiss_index.faiss')
-    else:
-        index = faiss.read_index('./storage/faiss_index.faiss')
-
-    # Load the model
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
-    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", quantization_config=quantization_config, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=HF_TOKEN)
-
-    def make_inference(query, hist):
-        return answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
-
-    demo = gr.ChatInterface(fn=make_inference,
-        examples=["What is diabetes?", "Is ginseng good for diabetes?", "What are the symptoms of diabetes?", "What is Celiac disease?"],
-        title="Gemma 2b MedicalQA Chatbot",
-        description="Gemma 2b Medical Chatbot is a chatbot that can help you with your medical queries. It is not a replacement for a doctor. Please consult a doctor for any medical advice.",
-    )
-    demo.launch()
+# import os
+# from faiss import write_index
+# import gradio as gr
+# import numpy as np
+# import torch
+# from tqdm import tqdm
+# from torch.utils.data import DataLoader, Dataset
+# from datasets import load_dataset
+# import pandas as pd
+# import faiss
+# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
+# from transformers import TextIteratorStreamer
+# from threading import Thread
+
+# torch.set_num_threads(2)
+# HF_TOKEN = os.environ.get("SECRET_TOKEN")
+
+class CFG:
+    embedding_model = 'TimKond/S-PubMedBert-MedQuAD'
+    batch_size = 128
+    device = ('cuda' if torch.cuda.is_available() else 'cpu')
+    llm = 'google/gemma-2b-it'
+    n_samples = 3
+
+# Show config
+config = CFG()
+# config_items = {k: v for k, v in vars(CFG).items() if not k.startswith('__')}
+# print(tabulate(config_items.items(), headers=['Parameter', 'Value'], tablefmt='fancy_grid'))
+
+# Fetch the data and load or generate the index
+df = get_all_data()
+documents = TextDataset(df)
+if not os.path.exists('./storage/faiss_index.faiss'):
+    embeddings = get_bert_embeddings(documents, CFG.batch_size, CFG.embedding_model, CFG.device)
+    index = create_faiss_index(embeddings)
+    write_index(index, './storage/faiss_index.faiss')
+else:
+    index = faiss.read_index('./storage/faiss_index.faiss')
+
+# Load the model
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", quantization_config=quantization_config, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=HF_TOKEN)
+
+def make_inference(query, hist):
+    return answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
+
+demo = gr.ChatInterface(fn=make_inference,
+    examples=["What is diabetes?", "Is ginseng good for diabetes?", "What are the symptoms of diabetes?", "What is Celiac disease?"],
+    title="Gemma 2b MedicalQA Chatbot",
+    description="Gemma 2b Medical Chatbot is a chatbot that can help you with your medical queries. It is not a replacement for a doctor. Please consult a doctor for any medical advice.",
+)
+demo.launch()
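
The cache-or-build logic around ./storage/faiss_index.faiss is the core of the retrieval setup: embed the corpus once, persist the index, and reuse it on later startups. Below is a self-contained sketch of that pattern. The real get_bert_embeddings and create_faiss_index helpers are defined elsewhere in app.py, so the flat inner-product index and the 768-dimensional random vectors here are assumptions for illustration, not the app's actual code.

```python
import os
import numpy as np
import faiss
from faiss import read_index, write_index

INDEX_PATH = './storage/faiss_index.faiss'  # same path the commit uses

def load_or_build_index(embeddings: np.ndarray) -> faiss.Index:
    """Reuse the on-disk index when present; otherwise build and cache it."""
    if os.path.exists(INDEX_PATH):
        return read_index(INDEX_PATH)
    index = faiss.IndexFlatIP(embeddings.shape[1])  # assumed index type
    index.add(embeddings.astype(np.float32))        # FAISS expects float32
    os.makedirs(os.path.dirname(INDEX_PATH), exist_ok=True)
    write_index(index, INDEX_PATH)
    return index

# Stand-in for get_bert_embeddings(...): ten random 768-dim vectors.
vectors = np.random.rand(10, 768).astype(np.float32)
index = load_or_build_index(vectors)
scores, ids = index.search(vectors[:1], 3)  # top-3 matches for the first vector
```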
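
answer_query itself sits above this hunk, but the TextIteratorStreamer and Thread imports point at the usual non-blocking generation pattern: run model.generate in a background thread and consume decoded text from the streamer as it arrives. The sketch below shows that pattern under those assumptions; the function name and parameters are hypothetical, and model/tokenizer are the objects loaded in the diff.

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_answer(prompt, model, tokenizer, max_new_tokens=256):
    """Illustrative streaming loop; not the app's actual answer_query."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer,
                             max_new_tokens=max_new_tokens)
    # Generation runs off the main thread so the streamer can be consumed here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for chunk in streamer:   # decoded text arrives chunk by chunk
        partial += chunk
        yield partial        # gr.ChatInterface renders each yielded string
```

A generator like this is what would let gr.ChatInterface display tokens as they are produced instead of waiting for the full completion.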