개발로그

Hugging Face x LangChain을 활용하여 이전 대화를 기억하는 챗봇 만들기

pizzalist 2024. 5. 30. 20:58

목적

LLaMA3를 이용해 챗봇 api를 만들고 싶었습니다. 

 

Hugging Face x LangChain

  1. Hugging Face x LangChain tool을 이용하여 LLaMA3 inference 합니다.
pip install langchain-huggingface

 

다음과 같이 HuggingFacePipeline를 불러올 수 있습니다.

from langchain_huggingface import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs={
        "max_new_tokens": 100,
        "top_k": 50,
        "temperature": 0.1,
    },
)
llm.invoke("Hugging Face is")​

이를 적용해 huggingface의 Meta-Llama-3-8B-Instruct을 불러올 수 있다. 모델에게 두개의 질문을 했을때 이전 대화를 기억하는지 알아 보기 위해 다음과 같이 코드를 작성했습니다.

  • 첫번째 질문: hi, my name is pizza. how are you?
  • 두번째 질문: What is my name?
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
import torch

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# CUDA 디바이스 설정
device = 0 if torch.cuda.is_available() else -1

llm = HuggingFacePipeline.from_model_id(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    task="text-generation",
    device=device,  
    
    # bfp16 사용
    model_kwargs={"torch_dtype": torch.bfloat16},
    
    pipeline_kwargs={
        "max_new_tokens": 256,
        "top_k": 50,
        "temperature": 0.5,
    },
)
print(llm.invoke("<|begin_of_text|><|start_header_id|>user<|end_header_id|>hi, my name is pizza. how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|"))

print(llm.invoke("<|begin_of_text|><|start_header_id|>user<|end_header_id|What is my name?<|eot_id|><|start_header_id|>assistant<|end_header_id|"))
  • 첫번째 질문: hi, my name is pizza. how are you?
    • 답변: Hi pizza! I'm doing well, thank you for asking! I'm a large language model, so I don't have feelings or emotions like humans do, but I'm always happy to chat with you and help with any questions or topics you'd like to discuss. What's on your mind today?
  • 두번째 질문: What is my name?
    • 답변: I'm happy to help! However, I'm a large language model, I don't have the ability to know your name or any personal information about you. Each time you interact with me, it's a new conversation and I don't retain any information from previous conversations. If you'd like to share your name with me, I'm happy to learn it!

위에서 알수 있듯이, 이전의 대화를 기억하지 못하는 것을 확인할 수 있습니다.

 

Add message history를 통한 이전 메세지를 기억하는 모델 구조

위 인퍼런스 코드에서 RunnableWithMessageHistory을 추가해 특정 유형의 체인에 메시지 기록을 추가할 수 있습니다. 다른 Runnable을 래핑하고 이에 대한 채팅 메시지 기록을 관리합니다.

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
import torch

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory


# CUDA 디바이스 설정
device = 0 if torch.cuda.is_available() else -1

llm = HuggingFacePipeline.from_model_id(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    task="text-generation",
    device=device,  
    
    # bfp16 사용
    model_kwargs={"torch_dtype": torch.bfloat16},
    
    pipeline_kwargs={
        "max_new_tokens": 256,
        "top_k": 50,
        "temperature": 0.5,
    },
)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|"),
    ]
)

runnable = prompt | llm
  • get_session_history를 통해 store라는 dict값에 ChatMessageHistory을 할당
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


with_message_history = RunnableWithMessageHistory(
    runnable,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)
# 첫 번째 대화
response1 = with_message_history.invoke(
    {"input": "Hi, my name is noah."},
    config={"configurable": {"session_id": "abc123"}},
)
print(response1)

# 두 번째 대화
response2 = with_message_history.invoke(
    {"input": "What is my name?"},
    config={"configurable": {"session_id": "abc123"}},
)
print(response2)
  • 첫번째 질문: hi, my name is pizza. how are you?
    • 답변: Nice to meet you, pizza! I'm LLaMA, your friendly AI assistant. It's great to have you here. What brings you to this conversation today?
  • 두번째 질문: What is my name?
    • 답변: Your name is pizza!

다음과 같이 이전 대화를 기억하는 것을 확인 할 수 있습니다.

챗봇의 경우 이전 대화를 기억해야한 사용자 만족도가 높은 기능을 유지할 수 있다. 지금부터는 dict에 할당하는 것이 아닌 Redis를 통해 대화 로그를 저장하는 방법을 설명하겠습니다.

1.

%pip install --upgrade --quiet redis

 

2. 기존 Redis가 없는 경우, docker를 통해 Redis-stack 서버를 시작합니다.

docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest

 

 

3. 아래와 같이 dict에 할당하는 get_session_history함수가 아닌, RedisChatMessageHistory 인스턴스를 반환하는 get_session_history를 정의해준다.
  • 기존 get_session_history
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

 

  • get_message_history로 변경
def get_message_history(session_id: str) -> RedisChatMessageHistory:
    return RedisChatMessageHistory(session_id, url=REDIS_URL)


with_message_history = RunnableWithMessageHistory(
    runnable,
    get_message_history,
    input_messages_key="input",
    history_messages_key="history",
)
 

위 내용을 이용해서 LLaMA3를 이용한 챗봇 api 코드 베이스를 공유드립니다.

본 챗봇의 목적은 뉴스 요약 챗봇이며 모델은 해당 task에 적합하게 fine tuning한 “letgoofthepizza/Llama-3-8B-Instruct-ko-news-summary” 모델을 활용하였습니다. 현재는 preview 버전이며 이 모델에 대한 소개는 추후 예정입니다.

해당 코드는 모델에 맞게 수정되어야 하며 기본적인 message history를 가진 챗봇 기본 코드입니다.

 

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from fastapi.middleware.cors import CORSMiddleware

from pydantic import BaseModel

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_huggingface import HuggingFacePipeline
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_community.chat_message_histories import RedisChatMessageHistory


import torch
import uuid

# CUDA 디바이스 설정
device = 0 if torch.cuda.is_available() else -1

REDIS_URL = "redis://localhost:6379/0"

app = FastAPI()

llm = HuggingFacePipeline.from_model_id(
    model_id="letgoofthepizza/Llama-3-8B-Instruct-ko-news-summary",
    task="text-generation",
    device=device,  
    
    # bfp16 사용
    model_kwargs={"torch_dtype": torch.bfloat16},
    
    pipeline_kwargs={
        "max_new_tokens": 500,
        "top_k": 50,
        "temperature": 0.5,
    },
)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "<|begin_of_text|><|start_header_id|>Please explain your question or request in Korean. The output must be answered in complete sentences.<|end_header_id|>",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"),
    ]
)

runnable = prompt | llm 


def get_message_history(session_id: str) -> RedisChatMessageHistory:
    return RedisChatMessageHistory(session_id, url=REDIS_URL)


with_message_history = RunnableWithMessageHistory(
    runnable,
    get_message_history,
    input_messages_key="input",
    history_messages_key="history",
)

class ChatRequest(BaseModel):
    input: str
    session_id: str = None


@app.post("/chat")
async def chat(request: ChatRequest):
    session_id = request.session_id

    response = with_message_history.invoke(
        {"input": request.input},
        config={"configurable": {"session_id": session_id}},
    )
    return JSONResponse({"response": response, "session_id": session_id})

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

 

구현 모습

Todo

  • 일정 message history만을 유지하도록 limit 장치 만들기
    • 대화가 길어지면 history가 많아져 OOM(Out of memory) 발생 가능성 막을 의도
  • llama3의 output이 불안정하게 나와 완벽한 질문의 답만 뽑기 어려움 

Todo또한 해결하면 바로 업로드 하겠습니다. 

감사합니다.

 

출처:

https://huggingface.co/blog/langchain

https://python.langchain.com/v0.1/docs/expression_language/how_to/message_history/