Auto-Refine The Prompt
How do we remove irrelevant information from the prompt?
The S2A (System 2 Attention) technique auto-refines a prompt by asking the model to rewrite the prompt to include only relevant information. We implement this in two steps:
- Ask the model to rewrite the prompt
- Pass the rewritten prompt back to the model
Implementation
import openai
import instructor
from pydantic import BaseModel, Field
# Shared client used by both steps below. The OpenAI client is wrapped with
# instructor; the calls in this file pass `response_model=` to it.
client = instructor.from_openai(openai.OpenAI())
class Step1(BaseModel):
    # Output of the rewrite step: the portion of the user's text that matters
    # for the question, and the question itself. Both fields are required.
    relevant_context: str = Field(description="Relevant context")
    user_query: str = Field(description="The question from the user")
class Step2(BaseModel):
    # Output of the answering step: a single integer answer to the
    # rewritten question.
    answer: int
def rewrite_prompt(query):
    """Ask the model to strip irrelevant material from ``query``.

    Sends a single user message asking the model to extract the part of
    the text that is relevant to the user's question, together with the
    question itself.

    Args:
        query: The raw text written by the user.

    Returns:
        The result of the completion call, requested with
        ``response_model=Step1`` (fields ``relevant_context`` and
        ``user_query``).
    """
    # The whitespace inside this literal is part of the prompt that is sent,
    # so it is kept exactly as written.
    extraction_request = f"""
                    Given the following text by a user, extract the part
                    that is actually relevant to their question. Please
                    include the actual question or query that the user
                    is asking.
                    Text by user:
                    {query}
                    """  # (1)!
    return client.chat.completions.create(
        model="gpt-4o",
        response_model=Step1,
        messages=[{"role": "user", "content": extraction_request}],
    )
def generate_final_response(rewritten_prompt):
    """Answer the question using only the rewritten prompt.

    Builds one user message from the rewritten context followed by a
    ``Question:`` line, and sends it to the model.

    Args:
        rewritten_prompt: Object exposing ``relevant_context`` and
            ``user_query`` attributes (the value returned by
            ``rewrite_prompt``).

    Returns:
        The result of the completion call, requested with
        ``response_model=Step2`` (field ``answer``).
    """
    # The indentation before "Question:" is part of the prompt that is sent,
    # so it is kept exactly as written.
    refined_message = f"""{rewritten_prompt.relevant_context}
                    Question: {rewritten_prompt.user_query}"""
    return client.chat.completions.create(
        model="gpt-4o",
        response_model=Step2,
        messages=[{"role": "user", "content": refined_message}],
    )
if __name__ == "__main__":
    # Demo of the two-step flow on a word problem. The sentence about Max's
    # age has no bearing on the candy question; it is absent from the example
    # rewritten context shown below.
    query = """Mary has 3 times as much candy as Megan.
        Mary then adds 10 more pieces of candy to her collection.
        Max is 5 years older than Mary.
        If Megan has 5 pieces of candy, how many does Mary have in total?
        """
    # Step 1: Rewrite the prompt
    rewritten_prompt = rewrite_prompt(query)
    # The bare string literal and the `#>` comments that follow each print
    # record example output; they have no effect when the script runs.
    print(rewritten_prompt.relevant_context)
    """
    Mary has 3 times as much candy as Megan. Mary then adds 10 more pieces of candy to her collection. If Megan has 5 pieces of candy, how many does Mary have in total?
    """
    print(rewritten_prompt.user_query)
    #> how many does Mary have in total?
    # Step 2: Generate the final response
    final_response = generate_final_response(rewritten_prompt)
    print(final_response.answer)
    #> 25
- (1) This prompt template comes from the paper that introduced System 2 Attention (S2A).