Nathan Slaughter committed on
Commit
4d17caa
1 Parent(s): b8a0d78

add pytorch manual method

Browse files
.github/workflows/python-app.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .github/workflows/python-app.yaml
2
+
3
+ name: Python application
4
+
5
+ on:
6
+ push:
7
+ branches: [ main ]
8
+ pull_request:
9
+ branches: [ main ]
10
+
11
+ jobs:
12
+ build:
13
+
14
+ runs-on: ubuntu-latest
15
+
16
+ steps:
17
+ - uses: actions/checkout@v2
18
+ - name: Set up Python
19
+ uses: actions/setup-python@v2
20
+ with:
21
+ python-version: '3.8'
22
+ - name: Install dependencies
23
+ run: |
24
+ python -m pip install --upgrade pip
25
+ pip install -r requirements.txt
26
+ pip install pytest pytest-mock
27
+ - name: Run tests
28
+ run: |
29
+ pytest
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from app.interface import create_interface
2
+
3
def main():
    """Build the Gradio interface and start serving it."""
    create_interface().launch()


# Only launch the UI when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
app/__init__.py ADDED
File without changes
app/interface.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .models import LanguageModel
3
+ from .processing import process_file, process_text_input
4
+
5
def create_interface():
    """
    Builds the Gradio Blocks UI for the flashcard extraction tool.

    The UI has two tabs, "Upload File" and "Input Text". Each tab pairs an
    input column (source + output-format selector + button) with an output
    textbox. One LanguageModel instance is created here and shared by both
    click handlers.

    Returns the gr.Blocks object; the caller decides when to launch it.
    """
    # Initialize the language model once; both handlers close over it.
    language_model = LanguageModel()

    # Click handler for the "Upload File" tab.
    def handle_file_upload(file_obj, output_format):
        try:
            return process_file(file_obj, output_format, language_model)
        except ValueError as ve:
            # Processing errors are shown to the user in the output textbox.
            return str(ve)

    # Click handler for the "Input Text" tab.
    def handle_text_input(input_text, output_format):
        try:
            return process_text_input(input_text, output_format, language_model)
        except ValueError as ve:
            return str(ve)

    # All components are created inside the Blocks context so that they are
    # rendered; components formerly created before this context were unused
    # (or immediately rebound below) and have been removed.
    with gr.Blocks() as interface:
        gr.Markdown("# Flashcard Extraction Tool")
        gr.Markdown(
            "Extract flashcards from uploaded files or directly input text. Choose your preferred output format."
        )
        with gr.Tab("Upload File"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(
                        label="Upload a File",
                        file_types=['.pdf', '.txt', '.md'],
                    )
                    format_selector = gr.Radio(
                        choices=["CSV", "JSON"],
                        label="Select Output Format",
                        value="JSON",
                        type="value",
                    )
                    submit_file = gr.Button("Extract Flashcards")
                with gr.Column():
                    flashcard_output_file = gr.Textbox(
                        label="Flashcards",
                        lines=20,
                        placeholder="Extracted flashcards will appear here...",
                    )
            submit_file.click(
                fn=handle_file_upload,
                inputs=[file_input, format_selector],
                outputs=flashcard_output_file,
            )

        with gr.Tab("Input Text"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Enter Text",
                        lines=20,
                        placeholder="Type or paste your text here...",
                    )
                    format_selector_text = gr.Radio(
                        choices=["CSV", "JSON"],
                        label="Select Output Format",
                        value="JSON",
                        type="value",
                    )
                    submit_text = gr.Button("Extract Flashcards")
                with gr.Column():
                    flashcard_output_text = gr.Textbox(
                        label="Flashcards",
                        lines=20,
                        placeholder="Extracted flashcards will appear here...",
                    )
            submit_text.click(
                fn=handle_text_input,
                inputs=[text_input, format_selector_text],
                outputs=flashcard_output_text,
            )

        # NOTE(review): the note text is written flush-left, as it appears in
        # the reviewed source; confirm against the original indentation.
        gr.Markdown(
            "\n"
            "---\n"
            "**Notes:**\n"
            "- Supported file types: `.pdf`, `.txt`, `.md`.\n"
            "- Ensure that the input text is clear and well-structured for optimal flashcard extraction.\n"
        )

    return interface
app/models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
class LanguageModel:
    """
    Thin wrapper around a Hugging Face causal language model and its tokenizer.

    Attributes:
        device: torch.device chosen by _determine_device(). The model itself
            is placed by ``device_map="auto"``, so inputs are moved to
            ``self.model.device`` rather than to this attribute.
        model: the loaded AutoModelForCausalLM.
        tokenizer: the matching AutoTokenizer.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
        """Load the model and tokenizer identified by ``model_name``."""
        self.device = self._determine_device()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _determine_device(self):
        """Return the preferred torch.device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Guard the attribute lookup: torch builds without the MPS backend
        # module would otherwise raise AttributeError here.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def generate_flashcards(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """
        Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: full prompt text sent to the model.
            max_new_tokens: upper bound on the number of generated tokens.

        Returns:
            The decoded continuation, without the prompt and without special
            tokens. Sampling is enabled, so output varies between calls.
        """
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
        with torch.no_grad():
            # Pass the whole encoding so the attention mask travels with the ids.
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
            )
        # generate() returns prompt tokens followed by the continuation; drop
        # the prompt so the caller does not get its own instructions echoed back.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[0][prompt_length:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
app/processing.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pymupdf4llm
3
+
4
def process_pdf(pdf_path: str) -> str:
    """
    Extracts the text of a PDF file as Markdown using pymupdf4llm.

    Args:
        pdf_path: path of the PDF file to read.

    Returns:
        The extracted document text.

    Raises:
        ValueError: if extraction fails for any reason; the underlying
            exception is attached as the cause.
    """
    try:
        # to_markdown is pymupdf4llm's documented extraction entry point.
        return pymupdf4llm.to_markdown(pdf_path)
    except Exception as e:
        # Broad on purpose: library failures are reported to the UI as ValueError.
        raise ValueError(f"Error processing PDF: {str(e)}") from e
13
+
14
def read_text_file(file_path: str) -> str:
    """
    Reads the full contents of a .txt or .md file as UTF-8 text.

    Args:
        file_path: path of the text file to read.

    Returns:
        The file contents as a single string.

    Raises:
        ValueError: if the file cannot be opened or decoded; the underlying
            exception is attached as the cause.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        # Keep the original error reachable via __cause__ for debugging.
        raise ValueError(f"Error reading text file: {str(e)}") from e
24
+
25
def format_prompt(output_format: str) -> str:
    """
    Returns the format-specific instruction block for the flashcard prompt.

    Args:
        output_format: "json" or "csv" (case-insensitive).

    Returns:
        Instruction text containing an example in the requested format.

    Raises:
        ValueError: if the format is not supported. Previously this case
            fell through and returned None, which was then embedded in the
            prompt as the literal text "None".
    """
    # NOTE(review): example lines are flush-left, as they appear in the
    # reviewed source; confirm against the original string's indentation.
    instructions = {
        "json": (
            "You only respond with cards in JSON format. Follow the example below.\n"
            "\n"
            "EXAMPLE:\n"
            "[\n"
            '{"question": "What is AI?", "answer": "Artificial Intelligence."},\n'
            '{"question": "What is ML?", "answer": "Machine Learning."}\n'
            "...\n"
            "]\n"
        ),
        "csv": (
            "You only respond with cards in CSV format. Follow the example below.\n"
            "\n"
            "EXAMPLE:\n"
            '"What is AI?", "Artificial Intelligence."\n'
            '"What is ML?", "Machine Learning."\n'
            "...\n"
        ),
    }
    # str() so that a non-string value is reported as unsupported, not a crash.
    key = str(output_format).lower()
    if key not in instructions:
        raise ValueError(f"Unsupported output format: {output_format!r}")
    return instructions[key]
47
+
48
def extract_flashcards(text: str, output_format: str, language_model) -> str:
    """
    Extracts flashcards from the input text using the LLM and formats them in CSV or JSON.

    Args:
        text: source material to turn into flashcards.
        output_format: format name passed to format_prompt().
        language_model: object exposing generate_flashcards(prompt) -> str
            (it was annotated as ``str`` before, which did not match its use).

    Returns:
        Whatever language_model.generate_flashcards returns for the prompt.
    """
    prompt = (
        "You are an expert flashcard creator. You always include a single knowledge item per flashcard.\n"
        "\n"
        f"{format_prompt(output_format)}\n"
        "\n"
        "\n"
        "Extract flashcards from the user's text:\n"
        "\n"
        f"{text}\n"
        "\n"
        "Do not include the prompt or any other unnecessary information in the flashcards.\n"
        "Do not include triple ticks (```) or any other code blocks in the flashcards.\n"
    )
    # TODO:
    # see https://qwen.readthedocs.io/en/latest/inference/chat.html
    # e.g. pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-7B-Instruct")
    return language_model.generate_flashcards(prompt)
69
+
70
def process_file(file_obj, output_format: str, language_model) -> str:
    """
    Processes the uploaded file based on its type and extracts flashcards.

    Args:
        file_obj: the upload; either a path (str or os.PathLike) or an object
            whose ``name`` attribute holds the path.
        output_format: format name forwarded to extract_flashcards().
        language_model: model object forwarded to extract_flashcards().

    Returns:
        The flashcards text produced by extract_flashcards().

    Raises:
        ValueError: if no file was provided, the extension is unsupported,
            or reading the file fails.
    """
    if file_obj is None:
        # Button clicked without an upload: report it instead of crashing
        # with AttributeError, which the UI handler does not catch.
        raise ValueError("No file provided.")

    if isinstance(file_obj, (str, os.PathLike)):
        # For a pathlib.Path, ``.name`` is only the basename, so use the full path.
        file_path = os.fspath(file_obj)
    else:
        file_path = file_obj.name
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.pdf':
        text = process_pdf(file_path)
    elif file_ext in ('.txt', '.md'):
        text = read_text_file(file_path)
    else:
        raise ValueError("Unsupported file type.")

    return extract_flashcards(text, output_format, language_model)
86
+
87
def process_text_input(input_text: str, output_format: str, language_model) -> str:
    """
    Processes the input text and extracts flashcards.

    Args:
        input_text: text typed or pasted by the user.
        output_format: format name forwarded to extract_flashcards().
        language_model: model object forwarded to extract_flashcards().

    Returns:
        The flashcards text produced by extract_flashcards().

    Raises:
        ValueError: if the input is None, empty, or only whitespace.
    """
    # ``not input_text`` also covers None, which would otherwise fail on .strip().
    if not input_text or not input_text.strip():
        raise ValueError("No text provided.")

    return extract_flashcards(input_text, output_format, language_model)
environment.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: flashcard-maker
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ - defaults
6
+ dependencies:
7
+ - python=3.12
8
+ - pytorch
9
+ - torchvision
10
+ - torchaudio
11
+ - cudatoolkit=11.7 # Remove or adjust if installing CPU-only
12
+ - transformers
13
+ - gradio
14
+ - librosa
15
+ - pytest
16
+ - pytest-mock
17
+ - pip
18
+ - pip:
19
+ - pymupdf4llm
pytest.ini ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # pytest.ini
2
+
3
+ [pytest]
4
+ filterwarnings =
5
+ ignore::DeprecationWarning
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ librosa
5
+ pymupdf4llm
6
+ pytest
7
+ pytest-mock # Added for mocking capabilities
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from unittest.mock import Mock
3
+ from app.models import LanguageModel
4
+
5
@pytest.fixture
def language_model():
    """
    Supplies a Mock constrained to the LanguageModel interface.

    Its generate_flashcards method returns an empty flashcards JSON string
    unless a test overrides the return value.
    """
    fake_model = Mock(spec=LanguageModel)
    fake_model.generate_flashcards.return_value = '{"flashcards": []}'
    return fake_model
tests/test_models.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_models.py
2
+
3
+ import pytest
4
+
5
def test_generate_flashcards(language_model, mocker):
    """
    Check that the mocked generate_flashcards returns its configured value
    and records the call with the prompt it was given.
    """
    sample_prompt = "Sample prompt for flashcard generation."
    canned = '{"flashcards": [{"Question": "What is AI?", "Answer": "Artificial Intelligence."}]}'
    language_model.generate_flashcards.return_value = canned

    actual = language_model.generate_flashcards(sample_prompt)

    language_model.generate_flashcards.assert_called_once_with(sample_prompt)
    assert actual == canned
tests/test_processing.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_processing.py
2
+
3
+ import pytest
4
+ from app.processing import process_text_input, process_file
5
+
6
def test_process_text_input_success(language_model):
    """
    Valid text is passed to the model once and the model output is returned.
    """
    outcome = process_text_input(
        "This is a sample text for flashcard extraction.",
        "JSON",
        language_model,
    )

    assert outcome == '{"flashcards": []}'
    language_model.generate_flashcards.assert_called_once()
17
+
18
def test_process_text_input_empty(language_model):
    """
    Whitespace-only input is rejected with a ValueError.
    """
    blank_text = " "

    with pytest.raises(ValueError, match=r"No text provided\."):
        process_text_input(blank_text, "JSON", language_model)
28
+
29
def test_process_file_unsupported_type(language_model, tmp_path):
    """
    A file with an unrecognised extension is rejected with a ValueError.
    """
    # Real file on disk whose extension is not one of the supported ones.
    unsupported = tmp_path / "dummy.unsupported"
    unsupported.write_text("Unsupported content")

    with pytest.raises(ValueError, match=r"Unsupported file type\."):
        process_file(unsupported, "JSON", language_model)
40
+
41
def test_process_file_pdf(language_model, tmp_path, mocker):
    """
    A .pdf upload goes through process_pdf (patched here) and then the model.
    """
    # Stub out PDF extraction so no real PDF parsing happens.
    mocker.patch('app.processing.process_pdf', return_value="Extracted PDF text.")

    pdf_file = tmp_path / "test.pdf"
    pdf_file.write_text("PDF content")

    outcome = process_file(pdf_file, "JSON", language_model)

    assert outcome == '{"flashcards": []}'
    language_model.generate_flashcards.assert_called_once()
57
+
58
def test_process_file_txt(language_model, tmp_path, mocker):
    """
    A .txt upload goes through read_text_file (patched here) and then the model.
    """
    # Stub out file reading so the test does not depend on file contents.
    mocker.patch('app.processing.read_text_file', return_value="Extracted TXT text.")

    txt_file = tmp_path / "test.txt"
    txt_file.write_text("TXT content")

    outcome = process_file(txt_file, "JSON", language_model)

    assert outcome == '{"flashcards": []}'
    language_model.generate_flashcards.assert_called_once()