"""Wrapper for the Reddit API"""
from typing import Any, Dict, List, Optional
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env
[docs]class RedditSearchAPIWrapper(BaseModel):
"""Wrapper for Reddit API
To use, set the environment variables ``REDDIT_CLIENT_ID``,
``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` to set the client ID,
client secret, and user agent, respectively, as given by Reddit's API.
Alternatively, all three can be supplied as named parameters in the
constructor: ``reddit_client_id``, ``reddit_client_secret``, and
``reddit_user_agent``, respectively.
Example:
.. code-block:: python
from langchain_community.utilities import RedditSearchAPIWrapper
reddit_search = RedditSearchAPIWrapper()
"""
reddit_client: Any
# Values required to access Reddit API via praw
reddit_client_id: Optional[str]
reddit_client_secret: Optional[str]
reddit_user_agent: Optional[str]
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the API ID, secret and user agent exists in environment
and check that praw module is present.
"""
reddit_client_id = get_from_dict_or_env(
values, "reddit_client_id", "REDDIT_CLIENT_ID"
)
values["reddit_client_id"] = reddit_client_id
reddit_client_secret = get_from_dict_or_env(
values, "reddit_client_secret", "REDDIT_CLIENT_SECRET"
)
values["reddit_client_secret"] = reddit_client_secret
reddit_user_agent = get_from_dict_or_env(
values, "reddit_user_agent", "REDDIT_USER_AGENT"
)
values["reddit_user_agent"] = reddit_user_agent
try:
import praw
except ImportError:
raise ImportError(
"praw package not found, please install it with pip install praw"
)
reddit_client = praw.Reddit(
client_id=reddit_client_id,
client_secret=reddit_client_secret,
user_agent=reddit_user_agent,
)
values["reddit_client"] = reddit_client
return values
[docs] def run(
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
) -> str:
"""Search Reddit and return posts as a single string."""
results: List[Dict] = self.results(
query=query,
sort=sort,
time_filter=time_filter,
subreddit=subreddit,
limit=limit,
)
if len(results) > 0:
output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
for r in results:
category = "N/A" if r["post_category"] is None else r["post_category"]
p = f"Post Title: '{r['post_title']}'\n\
User: {r['post_author']}\n\
Subreddit: {r['post_subreddit']}:\n\
Text body: {r['post_text']}\n\
Post URL: {r['post_url']}\n\
Post Category: {category}.\n\
Score: {r['post_score']}\n"
output.append(p)
return "\n".join(output)
else:
return f"Searching r/{subreddit} did not find any posts:"
[docs] def results(
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
) -> List[Dict]:
"""Use praw to search Reddit and return a list of dictionaries,
one for each post.
"""
subredditObject = self.reddit_client.subreddit(subreddit)
search_results = subredditObject.search(
query=query, sort=sort, time_filter=time_filter, limit=limit
)
search_results = [r for r in search_results]
results_object = []
for submission in search_results:
results_object.append(
{
"post_subreddit": submission.subreddit_name_prefixed,
"post_category": submission.category,
"post_title": submission.title,
"post_text": submission.selftext,
"post_score": submission.score,
"post_id": submission.id,
"post_url": submission.url,
"post_author": submission.author,
}
)
return results_object