Support hallucination score #218
Open

WenjingKangIntel wants to merge 5 commits into openvinotoolkit:master from jnzw:issue-199-dev-hallucination
Changes from all commits (5 commits)
* 9b2bf55 Support hallucination score (WenjingKangIntel)
* 8c5c91f updated the action.yml (xinpengzz)
* 6ee01c0 Merge branch 'master' into issue-199-dev-hallucination (xinpengzz)
* 4fe39f2 modified some files name (xinpengzz)
* 6897df5 import `load_chat_model` from `main.py` instead of copying it. (xinpengzz)
@@ -0,0 +1,50 @@
# How to select a hallucination computation algorithm

Currently, two methods are available: [deepeval](#use-deepeval-to-compute-hallucination-score) and [selfcheckgpt](#use-selfcheckgpt-to-compute-hallucination-score).

If you have an evaluation dataset (i.e. both questions and correct answers), choose [deepeval](#use-deepeval-to-compute-hallucination-score). If you do not have a labeled dataset, choose [selfcheckgpt](#use-selfcheckgpt-to-compute-hallucination-score), which computes a hallucination score from the consistency of the model's outputs.
# Use deepeval to compute hallucination score
## Prerequisite libraries
1. [deepeval](https://github.com/confident-ai/deepeval)
2. [Ollama](https://github.com/ollama/ollama/blob/main/README.md)

## How to set up
1. Install deepeval:
   ```
   pip install -U deepeval
   ```
2. Install Ollama: refer to the [Ollama README](https://github.com/ollama/ollama/blob/main/README.md#ollama).
3. Run Ollama, taking `deepseek-r1` as an example:
   ```
   ollama run deepseek-r1
   ```
4. Set deepeval to use Ollama for evaluation:
   ```
   deepeval set-ollama deepseek-r1
   ```

## How to run the test
```
python test.py --personality /path/to/personality.yaml --check_type deepeval
```
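
For reference, this is roughly what the test does per question under the hood, mirroring `compute_deepeval_hallucination` in `test.py`. The question, answer, and context strings below are placeholders; in the test the question doubles as the context because the datasets provide no reference passages:

```
from deepeval.metrics import HallucinationMetric
from deepeval.test_case import LLMTestCase

# Placeholder test case; test.py averages this score over the first
# --selection_num questions of the personality's dataset.
test_case = LLMTestCase(
    input="What are the ingredients in a classic Negroni?",
    actual_output="Equal parts gin, Campari, and sweet vermouth.",
    context=["What are the ingredients in a classic Negroni?"],
)
metric = HallucinationMetric(threshold=0.5)  # scores above 0.5 flag hallucination
metric.measure(test_case)
print(metric.score, metric.reason)
```

A higher average score means the model's answers contradict their contexts more often.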

## More to read
[deepeval hallucination](https://docs.confident-ai.com/docs/metrics-hallucination)

# Use selfcheckgpt to compute hallucination score
## Prerequisite libraries
1. [selfcheckgpt](https://github.com/potsawee/selfcheckgpt)

## How to set up and run the test
1. Install selfcheckgpt:
   ```
   pip install selfcheckgpt==0.1.7
   ```
2. Run the test:
   ```
   python test.py --personality /path/to/personality.yaml --check_type selfcheckgpt
   ```
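
Conceptually, selfcheckgpt samples several answers to the same prompt and scores how strongly each answer is supported by the others; inconsistent answers suggest hallucination. A minimal sketch using the `OVSelfCheckLLMPrompt` wrapper defined in `test.py` (model name and personality path are placeholders):

```
from pathlib import Path
from test import OVSelfCheckLLMPrompt, prepare_dataset_and_model

# Build the chat engine the same way test.py does (paths are placeholders).
_, ov_chat_engine = prepare_dataset_and_model(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    Path("../bartender_personality.yaml"),
    auth_token=None,
)
checker = OVSelfCheckLLMPrompt(ov_chat_engine)
# Sample several (by default 3) responses to the same prompt...
samples = checker.generate_outputs(["What is in a classic Daiquiri?"])[0]
# ...then cross-check each response against the others. The result is in
# [0, 1]; higher means less mutual support, i.e. more likely hallucination.
score = checker.predict(samples)
print(score)
```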
demos/virtual_ai_assistant_demo/test/bartender_questions.txt (new file, 20 additions)
@@ -0,0 +1,20 @@
Can you suggest some popular fruit-based drinks that are healthy and refreshing? | ||
Can you suggest some recipes using your favorite juices or ingredients? | ||
Can you suggest some refreshing drinks with watermelon or lime? | ||
Can you suggest some tropical juices or smoothies with kiwi or banana? | ||
What are the ingredients in a classic Martini? | ||
What are some popular drinks that use pomegranate juice? | ||
Can you suggest a cocktail that uses honey? | ||
What are the ingredients in a classic Daiquiri? | ||
Can you recommend a cocktail that uses apple cider? | ||
What are some popular drinks that use cranberry juice? | ||
Can you suggest a cocktail that uses chocolate? | ||
What are the ingredients in a classic Negroni? | ||
Can you recommend a cocktail that uses almond milk? | ||
What are some popular drinks that use grapefruit juice? | ||
Can you suggest a cocktail that uses lavender? | ||
What are the ingredients in a classic Pina Colada? | ||
Can you recommend a cocktail that uses maple syrup? | ||
What are some popular drinks that use lemon or lime juice? | ||
Can you suggest a cocktail that uses cinnamon? | ||
What are the ingredients in a classic Bloody Mary?
demos/virtual_ai_assistant_demo/test/culinara_questions.txt (new file, 20 additions)
@@ -0,0 +1,20 @@
I'm planning to cook a classic spaghetti carbonara. What ingredients do I need? | ||
Can I substitute pancetta with bacon in my carbonara? | ||
I'm planning to make a vegan lasagna. What can I use instead of ricotta cheese? | ||
How long should I bake my lasagna for the best results? | ||
I'm making a chicken curry. What spices should I use for an authentic flavor? | ||
Can I use coconut milk instead of cream in my chicken curry? | ||
I'm planning to bake a chocolate cake. What type of cocoa powder is best? | ||
Can I use almond flour instead of all-purpose flour in my cake? | ||
I'm making a Caesar salad. What ingredients are essential for the dressing? | ||
Can I use Greek yogurt instead of mayonnaise in my Caesar dressing? | ||
I'm planning to cook a beef stew. What cut of beef is best for stewing? | ||
Can I use red wine instead of beef broth in my stew? | ||
I'm planning to cook a seafood paella. What types of seafood are best to use? | ||
Can I use brown rice instead of white rice in my paella? | ||
How do I achieve the perfect socarrat (crispy bottom) in my paella? | ||
I'm making a vegetarian chili. What beans are best to use? | ||
Can I add quinoa to my chili for extra protein? | ||
I'm planning to bake a batch of cookies. What type of sugar should I use? | ||
Can I substitute butter with coconut oil in my cookies? | ||
I'm making a Greek salad. What ingredients are essential?
test.py (new file, 194 additions)
@@ -0,0 +1,194 @@
import argparse
import logging as log
import os
import sys

from typing import List, Set
from pathlib import Path
from tqdm import tqdm

import numpy as np
import openvino as ov
import yaml

from datasets import load_dataset
from urllib.request import getproxies
from deepeval.metrics import HallucinationMetric
from deepeval.test_case import LLMTestCase
from llama_index.core.chat_engine import SimpleChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.llms.openvino import OpenVINOLLM
from selfcheckgpt.modeling_selfcheck import SelfCheckLLMPrompt
from transformers import AutoTokenizer

# Make the parent demo directory importable so load_chat_model can be reused.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from main import load_chat_model

# Propagate system proxy settings, if any, before importing optimum.intel.
proxies = getproxies()
if "http" in proxies:
    os.environ["http_proxy"] = proxies["http"]
if "https" in proxies:
    os.environ["https_proxy"] = proxies["https"]
os.environ["no_proxy"] = "localhost, 127.0.0.1/8, ::1"
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig


DATASET_MAPPING = {
    "agribot_personality.yaml": {"name": "KisanVaani/agriculture-qa-english-only", "split": "train", "col": "question"},
    "healthcare_personality.yaml": {"name": "medalpaca/medical_meadow_medical_flashcards", "split": "train", "col": "input"},
    "bartender_personality.yaml": {"name": str(Path(__file__).parent / "bartender_questions.txt"), "col": "text"},
    "culinara_personality.yaml": {"name": str(Path(__file__).parent / "culinara_questions.txt"), "col": "text"},
    "tutor_personality.yaml": {"name": str(Path(__file__).parent / "tutor_questions.txt"), "col": "text"}
}
MODEL_DIR = Path("model")


def get_available_devices() -> Set[str]:
    core = ov.Core()
    return {device.split(".")[0] for device in core.available_devices}


def compute_deepeval_hallucination(inputs, outputs, contexts) -> float:
    # Average the deepeval hallucination score over all (input, output, context) triples.
    avg_score = 0.
    for input, output, context in zip(inputs, outputs, contexts):
        test_case = LLMTestCase(
            input=input,
            actual_output=output,
            context=context
        )
        metric = HallucinationMetric(threshold=0.5)
        metric.measure(test_case)
        score = metric.score
        # reason = metric.reason
        avg_score += score / len(inputs)
    return avg_score


def prepare_dataset_and_model(chat_model_name: str, personality_file_path: Path, auth_token: str):
    dataset_info = DATASET_MAPPING.get(personality_file_path.name, "")
    assert dataset_info != "", f"No dataset mapping for {personality_file_path.name}"
    log.info("Loading dataset")
    if dataset_info["name"].endswith(".txt"):
        dataset = load_dataset("text", data_files={"data": dataset_info["name"]})["data"]
    else:
        dataset = load_dataset(dataset_info["name"])[dataset_info["split"]]
    log.info("Dataset loading is finished")

    with open(personality_file_path, "rb") as f:
        chatbot_config = yaml.safe_load(f)

    ov_llm = load_chat_model(chat_model_name, auth_token)
    ov_chat_engine = SimpleChatEngine.from_defaults(llm=ov_llm, system_prompt=chatbot_config["system_configuration"],
                                                    memory=ChatMemoryBuffer.from_defaults())
    return dataset[dataset_info["col"]], ov_chat_engine


def run_test_deepeval(chat_model_name: str, personality_file_path: Path, auth_token: str, selection_num: int = 10) -> float:
    """
    Args:
        chat_model_name (str): large language model path or name.
        personality_file_path (Path): personality file path.
        auth_token (str): auth token used for Hugging Face.
        selection_num (int): maximum number of prompts selected to compute the hallucination score.

    Returns:
        hallucination score: the higher the score, the more likely the model is hallucinating.
    """
    dataset_question, ov_chat_engine = prepare_dataset_and_model(chat_model_name, personality_file_path, auth_token)
    inputs = dataset_question
    # We use the question as context because the dataset lacks reference passages.
    contexts = dataset_question
    contexts_res = [[context] for context in contexts]

    outputs = []
    for input in tqdm(inputs[:selection_num]):
        output = ov_chat_engine.chat(input).response
        outputs.append(output)

    final_score = compute_deepeval_hallucination(inputs[:selection_num], outputs[:selection_num], contexts_res[:selection_num])
    return final_score


class OVSelfCheckLLMPrompt(SelfCheckLLMPrompt):
    def __init__(self, ov_chat_engine: SimpleChatEngine):
        self.ov_chat_engine = ov_chat_engine
        self.text_mapping = {'yes': 0.0, 'no': 1.0, 'n/a': 0.5}
        self.prompt_template = "Context: {context}\n\nSentence: {sentence}\n\nIs the sentence supported by the context above? Answer Yes or No.\n\nAnswer: "
        self.not_defined_text = set()
        self.generate_num = 3

    @staticmethod
    def _strip_think(response: str) -> str:
        # Drop the chain-of-thought part emitted by reasoning models;
        # keep the full response if no </think> marker is present.
        idx = response.rfind("</think>")
        return response[idx + len("</think>"):].strip() if idx != -1 else response.strip()

    def generate_outputs(self, prompt_list: List[str]) -> List[List[str]]:
        response_list = []
        for prompt in tqdm(prompt_list, desc="generating responses"):
            tmp_list = []
            for _ in range(self.generate_num):
                response = self.ov_chat_engine.chat(prompt).response
                tmp_list.append(self._strip_think(response))
            response_list.append(tmp_list)
        return response_list

    def predict(
        self,
        sampled_passages: List[str],
    ) -> float:
        num_samples = len(sampled_passages)
        scores = np.zeros((num_samples, num_samples))

        for sent_i in range(num_samples):
            sentence = sampled_passages[sent_i].replace("\n", " ")
            for sample_i, sample in enumerate(sampled_passages):
                if sent_i == sample_i:
                    continue

                # this seems to improve performance when using the simple prompt template
                sample = sample.replace("\n", " ")
                prompt = self.prompt_template.format(context=sample, sentence=sentence)
                generate_output = self.ov_chat_engine.chat(prompt).response

                score_ = self.text_postprocessing(self._strip_think(generate_output))
                scores[sent_i, sample_i] = score_

        avg_score = np.sum(scores) / num_samples / (num_samples - 1)
        return avg_score


def run_test_selfcheckgpt(chat_model_name: str, personality_file_path: Path, auth_token: str, selection_num: int = 10) -> float:
    """
    Args:
        chat_model_name (str): large language model path or name.
        personality_file_path (Path): personality file path.
        auth_token (str): auth token used for Hugging Face.
        selection_num (int): maximum number of prompts selected to compute the hallucination score.

    Returns:
        hallucination score: the higher the score, the more likely the model is hallucinating.
    """
    dataset_question, ov_chat_engine = prepare_dataset_and_model(chat_model_name, personality_file_path, auth_token)
    check_eng = OVSelfCheckLLMPrompt(ov_chat_engine)
    response_list = check_eng.generate_outputs(dataset_question[:selection_num])
    score_list = []
    for response_list_per_prompt in tqdm(response_list, desc="predict hallucination ratio"):
        score_list.append(check_eng.predict(response_list_per_prompt))
    final_score = float(np.mean(score_list))
    return final_score


if __name__ == "__main__":
    # set up logging
    log.getLogger().setLevel(log.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--chat_model", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", help="Path/name of the chat model")
    parser.add_argument("--personality", type=str, default="../healthcare_personality.yaml", help="Path to the YAML file with chatbot personality")
    parser.add_argument("--hf_token", type=str, help="HuggingFace access token to get Llama3")
    parser.add_argument("--check_type", type=str, choices=["deepeval", "selfcheckgpt"], default="deepeval", help="Hallucination check type")
    parser.add_argument("--selection_num", type=int, default=5, help="Maximum number of prompts selected to compute the hallucination score")

    args = parser.parse_args()
    if args.check_type == "deepeval":
        hallucination_score = run_test_deepeval(args.chat_model, Path(args.personality), args.hf_token, args.selection_num)
    else:
        hallucination_score = run_test_selfcheckgpt(args.chat_model, Path(args.personality), args.hf_token, args.selection_num)
    print(f"hallucination_score for personality {args.personality}: {hallucination_score}")
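
For programmatic use, the two entry points can also be called directly, mirroring the `__main__` block above (a sketch; it assumes the script is importable as `test` from the demo's test directory, and the model name and personality path are placeholders):

```
from pathlib import Path
from test import run_test_deepeval

score = run_test_deepeval(
    chat_model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    personality_file_path=Path("../bartender_personality.yaml"),
    auth_token=None,  # Hugging Face token, only needed for gated models
    selection_num=5,  # cap on how many prompts are evaluated
)
print(f"hallucination score: {score:.3f}")
```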
Using deepseek-r1 will take forever. Is it possible to use any of the distilled models? https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d
Actually, the command `ollama run deepseek-r1` will run the `deepseek-r1-distill-qwen-7B` model, not the 671B version. Within the deepseek-r1 series, the only smaller model available is `deepseek-r1-distill-qwen-1.5B` (e.g. `ollama run deepseek-r1:1.5b`), but its performance is not as good.
then ok