PyTorch Question Answering
Introduction
Question Answering (QA) is an important task in Natural Language Processing (NLP) that involves designing systems capable of automatically answering questions posed in natural language. In recent years, transformer-based models have revolutionized this field, enabling highly accurate responses to a wide range of questions.
In this tutorial, we'll learn how to implement question answering systems using PyTorch and the Transformers library. We'll cover:
- Understanding the question answering task
- Types of question answering systems
- Implementing extractive QA with PyTorch
- Fine-tuning pre-trained models for QA
- Evaluating QA models
- Practical applications and use cases
Understanding Question Answering
Question Answering systems take a question as input and provide a relevant answer. There are two main approaches:
- Extractive QA: The answer is extracted as a span of text from a provided context/passage.
- Generative QA: The answer is generated from scratch, without necessarily being constrained to text in a context.
In this tutorial, we'll focus mainly on extractive QA, which is more common and typically more accurate for factual questions.
Setting Up the Environment
Let's start by installing the necessary libraries:
pip install torch transformers datasets evaluate
Now, let's import the libraries we'll need:
import torch
import numpy as np  # used later when post-processing model logits
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from datasets import load_dataset
import evaluate
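If you have a GPU available, models and pipelines can run on it; this is entirely optional, and everything in this tutorial also works on CPU. A minimal sketch:
# Optional: pick a device index once; for pipelines, 0 means the first GPU and -1 means CPU.
# You can pass it later as pipeline(..., device=device).
device = 0 if torch.cuda.is_available() else -1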
Extractive Question Answering with Pre-trained Models
The Hugging Face Transformers library provides pre-trained models specifically designed for question answering tasks. Let's see how to use them:
Using the QA Pipeline
The simplest way to get started with question answering is to use the pipeline function:
# Load a pre-trained QA pipeline
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
# Context and question
context = """
PyTorch is an open source machine learning library based on the Torch library,
used for applications such as computer vision and natural language processing.
It is primarily developed by Facebook's AI Research lab (FAIR).
PyTorch provides two high-level features: tensor computations with strong
GPU acceleration support and deep neural networks built on a tape-based
autograd system.
"""
question = "Who develops PyTorch?"
# Get the answer
result = qa_pipeline(question=question, context=context)
print(f"Answer: {result['answer']}")
print(f"Score: {result['score']:.4f}")
print(f"Start position: {result['start']}")
print(f"End position: {result['end']}")
Output:
Answer: Facebook's AI Research lab (FAIR)
Score: 0.8764
Start position: 161
End position: 193
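The pipeline can also return several candidate spans instead of only the best one. Its top_k argument is handy when you want to inspect runner-up answers:
# Ask for the three highest-scoring answer spans
candidates = qa_pipeline(question=question, context=context, top_k=3)
for candidate in candidates:
    print(f"{candidate['answer']} (score: {candidate['score']:.4f})")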
How Extractive QA Works
In extractive QA, models predict:
- The start position of the answer in the context
- The end position of the answer in the context
The model gives each token in the context a "start score" and an "end score," then selects the span with the highest combined score as the answer.
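To make that selection step concrete, here is a minimal sketch of picking the best valid span from the two score vectors. It assumes start_logits and end_logits are 1-D tensors for a single example and ignores details a real pipeline adds, such as excluding question tokens and capping the answer length:
# The combined score of every (start, end) pair is start_logits[i] + end_logits[j]
scores = start_logits[:, None] + end_logits[None, :]
# Disallow spans whose end comes before their start
valid = torch.ones_like(scores).triu().bool()
scores = scores.masked_fill(~valid, float("-inf"))
best_start, best_end = divmod(torch.argmax(scores).item(), scores.shape[1])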
Building a Custom QA System with PyTorch
Let's build a more customized QA system using a pre-trained model:
# Load a model and tokenizer that are already fine-tuned for QA; a plain
# "bert-base-uncased" checkpoint has a randomly initialized QA head and
# would give meaningless answers until it is fine-tuned
model_name = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
def answer_question(question, context):
    # Tokenize the question/context pair
    inputs = tokenizer(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt"
    )

    # Get model predictions without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Most likely start and end token positions
    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits)

    # Convert the predicted token span back to a string
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    answer = tokenizer.convert_tokens_to_string(tokens[answer_start:answer_end + 1])

    # Clean up special tokens that may fall inside the span
    answer = answer.replace("[CLS]", "").replace("[SEP]", "").strip()

    return answer
# Example usage
question = "What is PyTorch used for?"
context = """
PyTorch is a machine learning framework based on the Torch library.
It is primarily used for applications such as natural language processing
and computer vision. PyTorch has gained popularity among researchers
due to its flexibility and ease of use.
"""
answer = answer_question(question, context)
print(f"Q: {question}")
print(f"A: {answer}")
Output:
Q: What is PyTorch used for?
A: applications such as natural language processing and computer vision
Fine-tuning a QA Model on Custom Data
Now let's see how to fine-tune a pre-trained model on a custom dataset for question answering:
Preparing the Dataset
We'll use the SQuAD (Stanford Question Answering Dataset) for this example:
# Load SQuAD dataset
squad = load_dataset("squad", split="train[:1000]") # Using a subset for demonstration
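Each SQuAD record contains an id, a title, the context passage, the question, and the gold answers (the answer text together with its character offset into the context). It is worth printing one record before writing the preprocessing code:
# Inspect a single training example
print(squad[0]["question"])
print(squad[0]["answers"])  # {'text': ['...'], 'answer_start': [...]}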
# Preprocess function for tokenization
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    contexts = [c.strip() for c in examples["context"]]

    # Tokenize inputs; long contexts are split into overlapping features
    inputs = tokenizer(
        questions,
        contexts,
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )

    # Map each feature back to the example it came from
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")

    # Compute start and end token positions for each feature
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # Only match context tokens (sequence id 1); question tokens have
        # offsets into the question string and must be skipped
        sequence_ids = inputs.sequence_ids(i)

        # Default to (0, 0), i.e. the [CLS] token, when the answer is not in this feature
        start_position = 0
        end_position = 0
        for j, (start, end) in enumerate(offsets):
            if sequence_ids[j] != 1:
                continue
            if start <= start_char < end:
                start_position = j
            if start < end_char <= end:
                end_position = j

        start_positions.append(start_position)
        end_positions.append(end_position)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
# Tokenize and preprocess the dataset
tokenized_squad = squad.map(
preprocess_function,
batched=True,
remove_columns=squad.column_names
)
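As a quick sanity check, you can decode the labeled span of the first feature and compare it with the gold answer text; the two should match up to tokenization artifacts:
# Verify that the start/end labels point at the answer
first = tokenized_squad[0]
span_ids = first["input_ids"][first["start_positions"]:first["end_positions"] + 1]
print(tokenizer.decode(span_ids))
print(squad[0]["answers"]["text"][0])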
Fine-tuning the Model
Now we can fine-tune our model:
from transformers import Trainer, TrainingArguments
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    # No evaluation strategy is set here because we do not pass an eval_dataset to the Trainer
)
# Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad,
    tokenizer=tokenizer,
)
# Start training
trainer.train()
# Save fine-tuned model
model.save_pretrained("./my-qa-model")
tokenizer.save_pretrained("./my-qa-model")
Evaluating the Model
Let's evaluate our fine-tuned model using common QA metrics:
# Load evaluation dataset
squad_eval = load_dataset("squad", split="validation[:500]") # Subset for demonstration
# Evaluation metric
metric = evaluate.load("squad")
# Prediction function
def compute_predictions(examples):
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        padding="max_length",
        truncation="only_second",
        max_length=384,
        return_tensors="pt"
    )

    # Get predictions
    with torch.no_grad():
        outputs = model(**{k: v.to(model.device) for k, v in inputs.items()})

    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    # Convert the predicted spans back to answer text
    predictions = []
    for i in range(len(examples["question"])):
        start_idx = np.argmax(start_logits[i])
        end_idx = np.argmax(end_logits[i])

        # Decode the token ids between the predicted start and end positions
        answer = tokenizer.decode(
            inputs.input_ids[i][start_idx:end_idx + 1],
            skip_special_tokens=True
        )
        predictions.append({"prediction_text": answer, "id": examples["id"][i]})

    return predictions
# Prepare references
references = [{"answers": {"answer_start": ex["answers"]["answer_start"], "text": ex["answers"]["text"]}, "id": ex["id"]} for ex in squad_eval]
# Get predictions
predictions = compute_predictions(squad_eval)
# Calculate metrics
result = metric.compute(predictions=predictions, references=references)
print(f"Exact Match: {result['exact_match']:.2f}")
print(f"F1 Score: {result['f1']:.2f}")
Note: The evaluation code above is simplified and may need adjustments based on the actual data structure.
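If you want to see the metric in isolation first, it is easy to call on toy data; the squad metric expects predictions and references shaped like this:
# Minimal illustration of the input format expected by the "squad" metric
toy_predictions = [{"id": "1", "prediction_text": "Facebook's AI Research lab"}]
toy_references = [{"id": "1", "answers": {"text": ["Facebook's AI Research lab (FAIR)"], "answer_start": [161]}}]
print(metric.compute(predictions=toy_predictions, references=toy_references))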
Real-World Applications
Question answering systems have numerous practical applications:
1. Customer Support Automation
# Example: Customer Support QA System
support_qa = pipeline("question-answering", model="./my-qa-model")
support_docs = """
Our return policy allows customers to return items within 30 days of purchase.
To initiate a return, log into your account, go to Order History, and select the Return option.
You will receive a refund within 5-7 business days after we receive the returned item.
For damaged items, please contact our support team at [email protected].
"""
customer_question = "How many days do I have to return an item?"
answer = support_qa(question=customer_question, context=support_docs)
print(f"Customer: {customer_question}")
print(f"Support Bot: {answer['answer']}")
2. Information Extraction from Documents
# Example: Document QA for information extraction
document_qa = pipeline("question-answering", model="./my-qa-model")
research_paper = """
The study conducted in 2022 showed that the new treatment reduced symptoms
by 65% in the treatment group compared to 23% in the control group.
Side effects were minimal, with only 7% of participants reporting mild headaches.
The study included 450 participants across 12 medical centers in North America.
"""
questions = [
"What was the symptom reduction percentage?",
"How many participants were in the study?",
"What side effects were reported?"
]
for question in questions:
    answer = document_qa(question=question, context=research_paper)
    print(f"Q: {question}")
    print(f"A: {answer['answer']}")
    print()
3. Educational Assistant
QA systems can help students understand complex topics:
# Example: Educational QA Assistant
education_qa = pipeline("question-answering", model="./my-qa-model")
lesson_content = """
Photosynthesis is the process by which green plants and certain other organisms
transform light energy into chemical energy. During photosynthesis in green plants,
light energy is captured and used to convert water, carbon dioxide, and minerals
into oxygen and energy-rich organic compounds. The process occurs in the chloroplasts,
specifically using chlorophyll, the green pigment involved in photosynthesis.
"""
student_questions = [
"What is photosynthesis?",
"Where does photosynthesis take place?",
"What are the inputs of photosynthesis?"
]
for question in student_questions:
    answer = education_qa(question=question, context=lesson_content)
    print(f"Student: {question}")
    print(f"Assistant: {answer['answer']}")
    print()
Beyond Extractive QA: Generative Question Answering
While we've focused on extractive QA, generative QA is becoming increasingly popular with the advancement of language models:
from transformers import AutoModelForSeq2SeqLM
# Load a T5-style model for generative QA; "google/flan-t5-small" is one
# publicly available checkpoint that handles this prompt format
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def generate_answer(question, context):
    # T5-style models expect the task spelled out in the input text
    input_text = f"question: {question} context: {context}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # Generate the answer with beam search
    outputs = model.generate(
        inputs.input_ids,
        max_length=64,
        num_beams=4,
        early_stopping=True
    )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer
# Example usage
question = "Who invented the light bulb?"
context = """
The invention of the light bulb is often attributed to Thomas Edison,
who patented his design in 1879. However, the history of electric lighting
began long before Edison. In 1802, Humphry Davy invented the first electric
light, which was an electric arc lamp. Many other inventors like Joseph Swan
also worked on incandescent lamps around the same time as Edison.
"""
answer = generate_answer(question, context)
print(f"Q: {question}")
print(f"A: {answer}")
Summary
In this tutorial, we've covered:
- The fundamentals of question answering in NLP
- How to use pre-trained QA models with PyTorch and Transformers
- The process of fine-tuning QA models on custom datasets
- Evaluating QA systems using standard metrics
- Practical applications of QA systems in various domains
- An introduction to generative QA as an alternative to extractive QA
Question answering is a rapidly evolving area of NLP with numerous practical applications. By combining the power of PyTorch with pre-trained transformer models, you can build sophisticated QA systems tailored to your specific needs.
Additional Resources
- SQuAD Dataset - Stanford Question Answering Dataset
- HuggingFace QA Documentation
- PyTorch Documentation
- BERT Paper - The original BERT paper that revolutionized QA
Exercises
- Basic QA: Fine-tune a QA model on a small subset of SQuAD and evaluate its performance.
- Domain-Specific QA: Create a custom QA dataset in a specific domain (e.g., medical, legal) and fine-tune a model on it.
- Error Analysis: Analyze the errors made by your QA model. What types of questions are most challenging?
- Cross-Lingual QA: Experiment with multilingual QA models to answer questions in languages other than English.
- QA Web App: Build a simple web application that allows users to ask questions about uploaded documents.
Happy coding!